Skip to content
Snippets Groups Projects

add script for tuning demixing

Merged Sarod Yatawatta requested to merge RAP-90 into master
@@ -38,6 +38,8 @@ class DemixingTuner:
self.timesec=timesec
self.Nf=Nf
self.skymodel=skymodel
# max iterations to consider
self.maxiter=[10, 20, 30, 40, 50]
if path_DP3:
self.DP3=path_DP3
if path_LINC:
@@ -51,6 +53,14 @@ class DemixingTuner:
self.B=0
# frequencies
self.freqlist=None
# bandwidth (of each subband)
self.bandwidth=0
# best bandwidth (freq resolution) to use
self.bandwidth_to_use=0
# best value for maximum iterations to use
self.maxiter_to_use=0
# number of channels (of each subband)
self.nchan=0
# target direction
self.ra0=0
self.dec0=0
@@ -67,6 +77,9 @@ class DemixingTuner:
self.low_elevation_limit=5
# outlier (Patch) names
self.ateam_names=None
self.chosen_patches=None
# 1: only select to minimum number of sources, increase to include more
self.inclusiveness=2
self.target_skymodel='./target.sky.txt'
self.final_skymodel='./combined.sky.txt'
@@ -75,6 +88,10 @@ class DemixingTuner:
# sqrt(residual sum of squares)
self.RSS=None
self.AIC=None
self.RSS_freq=None
self.AIC_freq=None
self.RSS_iter=None
self.AIC_iter=None
def extract_dataset(self):
self.mslist.sort()
@@ -145,6 +162,11 @@ class DemixingTuner:
self.freqlist[ci]=ch0[0,0]
tf.close()
tf=ctab.table(self.extracted_ms[0]+'/SPECTRAL_WINDOW',readonly=True)
self.bandwidth=tf.getcol('TOTAL_BANDWIDTH')[0]
self.nchan=tf.getcol('NUM_CHAN')[0]
tf.close()
# get antennas
ant=ctab.table(self.extracted_ms[0]+'/ANTENNA',readonly=True)
self.N=ant.nrows()
@@ -228,7 +250,7 @@ class DemixingTuner:
# merge target with outliers given by
# category index 'index'
def create_parset(self,index):
def create_parset_dir(self,index):
chosen_dirs=self.scalar_to_kvec(index)
chosen_patches=itertools.compress(self.ateam_names,chosen_dirs)
parset_name=str(index)+'_'+self.parset_demix
@@ -238,8 +260,6 @@ class DemixingTuner:
+'demix.baseline=[CR]S*&\n'
+'demix.ignoretarget=False\n'
+'demix.targetsource=\"CENTER\"\n'
#+'demix.demixtimestep=10\n'
#+'demix.demixfreqstep=1\n'
+'demix.demixtimeresolution=10\n'
+'demix.demixfreqresolution=50kHz\n'
+'demix.uselbfgssolver=true\n'
@@ -264,6 +284,71 @@ class DemixingTuner:
bbsdem.close()
return parset_name
def create_parset_freq(self,channel_index,channels):
bandwidth=self.bandwidth/self.nchan*channels[channel_index]/1e3
parset_name=str(channel_index)+'_'+self.parset_demix
bbsdem=open(parset_name,'w+')
bbsdem.write('steps=[demix]\n'
+'demix.type=demixer\n'
+'demix.baseline=[CR]S*&\n'
+'demix.ignoretarget=False\n'
+'demix.targetsource=\"CENTER\"\n'
+'demix.demixtimeresolution=10\n'
+'demix.demixfreqresolution='+str(bandwidth)+'kHz\n'
+'demix.uselbfgssolver=true\n'
+'demix.lbfgs.historysize=10\n'
+'demix.lbfgs.robustdof=200\n'
+'demix.freqstep=1\n'
+'demix.timestep=1\n'
+'demix.instrumentmodel=instrument_'+str(channel_index)+'\n'
+'demix.skymodel='+self.final_skymodel+'\n'
)
bbsdem.write('demix.subtractsources=[')
firstcomma=False
for patch in self.chosen_patches:
if not firstcomma:
firstcomma=True
else:
bbsdem.write(',')
bbsdem.write('\"'+patch+'\"')
bbsdem.write(']\n')
bbsdem.close()
return parset_name
def create_parset_iter(self,iter_index):
parset_name=str(iter_index)+'_'+self.parset_demix
bbsdem=open(parset_name,'w+')
bbsdem.write('steps=[demix]\n'
+'demix.type=demixer\n'
+'demix.baseline=[CR]S*&\n'
+'demix.ignoretarget=False\n'
+'demix.targetsource=\"CENTER\"\n'
+'demix.demixtimeresolution=10\n'
+'demix.demixfreqresolution='+str(self.bandwidth_to_use)+'\n'
+'demix.uselbfgssolver=true\n'
+'demix.lbfgs.historysize=10\n'
+'demix.lbfgs.robustdof=200\n'
+'demix.maxiter='+str(self.maxiter[iter_index])+'\n'
+'demix.freqstep=1\n'
+'demix.timestep=1\n'
+'demix.instrumentmodel=instrument_'+str(iter_index)+'\n'
+'demix.skymodel='+self.final_skymodel+'\n'
)
bbsdem.write('demix.subtractsources=[')
firstcomma=False
for patch in self.chosen_patches:
if not firstcomma:
firstcomma=True
else:
bbsdem.write(',')
bbsdem.write('\"'+patch+'\"')
bbsdem.write(']\n')
bbsdem.close()
return parset_name
def iter_over_directions(self):
self.RSS=np.zeros((self.Nf,2**self.K))
shmRSS=shared_memory.SharedMemory(create=True,size=self.RSS.nbytes)
@@ -277,7 +362,7 @@ class DemixingTuner:
chosen_dirs=self.scalar_to_kvec(index)
low_elevations=any(ele <= self.low_elevation_limit for ele in itertools.compress(self.elevation,chosen_dirs))
if not low_elevations:
parset_name=self.create_parset(index)
parset_name=self.create_parset_dir(index)
# copy data
ci=0
for ms in self.extracted_ms:
@@ -291,12 +376,11 @@ class DemixingTuner:
proc1.wait()
res_sigmaI,res_sigmaQ,res_sigmaU,res_sigmaV=self.get_noise_var(demixout_ms)
print('%s %d %f %f %f %f'%(out_ms,index,res_sigmaI,res_sigmaQ,res_sigmaU,res_sigmaV))
#RSSsh[ci,index]=(res_sigmaI+res_sigmaQ+res_sigmaU+res_sigmaV)/4
RSSsh[ci,index]=res_sigmaI
ci+=1
else:
RSSsh[:,index]=1e9
pool=Pool(self.Nparallel)
pool.map(process_scenario,range(1,2**self.K))
pool.close()
@@ -311,6 +395,77 @@ class DemixingTuner:
self.RSS[ci,0]=sigmaI
ci+=1
def iter_over_frequency(self):
assert(len(self.chosen_patches)>0)
# channels to consider
channels=[2**x for x in range(int(np.ceil(np.log(self.nchan)/np.log(2)))+1)]
self.RSS_freq=np.zeros((self.Nf,len(channels)))
shmRSS=shared_memory.SharedMemory(create=True,size=self.RSS_freq.nbytes)
RSSsh=np.ndarray(self.RSS_freq.shape,dtype=self.RSS_freq.dtype,buffer=shmRSS.buf)
@globalize
def process_scenario_freq(channel_index):
parset_name=self.create_parset_freq(channel_index,channels)
# copy data
ci=0
for ms in self.extracted_ms:
out_ms='out_'+str(channel_index)+'_'+ms
sb.run('rm -rf '+out_ms,shell=True)
proc1=sb.Popen('rsync -a '+ms+'/ '+out_ms,shell=True)
proc1.wait()
demixout_ms='residual_'+out_ms
sb.run('rm -rf '+demixout_ms,shell=True)
proc1=sb.Popen(self.DP3+' '+parset_name+' msin='+out_ms+' msout='+demixout_ms,shell=True)
proc1.wait()
res_sigmaI,res_sigmaQ,res_sigmaU,res_sigmaV=self.get_noise_var(demixout_ms)
print('%s %d %f %f %f %f'%(out_ms,channel_index,res_sigmaI,res_sigmaQ,res_sigmaU,res_sigmaV))
RSSsh[ci,channel_index]=res_sigmaI
ci+=1
pool=Pool(self.Nparallel)
pool.map(process_scenario_freq,range(len(channels)))
pool.close()
pool.join()
self.RSS_freq[:]=RSSsh[:]
shmRSS.close()
shmRSS.unlink()
def iter_over_maxiter(self):
assert(len(self.chosen_patches)>0)
assert(len(self.maxiter)>0)
self.RSS_iter=np.zeros((self.Nf,len(self.maxiter)))
shmRSS=shared_memory.SharedMemory(create=True,size=self.RSS_iter.nbytes)
RSSsh=np.ndarray(self.RSS_iter.shape,dtype=self.RSS_iter.dtype,buffer=shmRSS.buf)
@globalize
def process_scenario_iter(iter_index):
parset_name=self.create_parset_iter(iter_index)
# copy data
ci=0
for ms in self.extracted_ms:
out_ms='out_'+str(iter_index)+'_'+ms
sb.run('rm -rf '+out_ms,shell=True)
proc1=sb.Popen('rsync -a '+ms+'/ '+out_ms,shell=True)
proc1.wait()
demixout_ms='residual_'+out_ms
sb.run('rm -rf '+demixout_ms,shell=True)
proc1=sb.Popen(self.DP3+' '+parset_name+' msin='+out_ms+' msout='+demixout_ms,shell=True)
proc1.wait()
res_sigmaI,res_sigmaQ,res_sigmaU,res_sigmaV=self.get_noise_var(demixout_ms)
print('%s %d %f %f %f %f'%(out_ms,iter_index,res_sigmaI,res_sigmaQ,res_sigmaU,res_sigmaV))
RSSsh[ci,iter_index]=res_sigmaI
ci+=1
pool=Pool(self.Nparallel)
pool.map(process_scenario_iter,range(len(self.maxiter)))
pool.close()
pool.join()
self.RSS_iter[:]=RSSsh[:]
shmRSS.close()
shmRSS.unlink()
# reset weights to 1, after applying weight to datum
def reset_weights(self,msname):
@@ -354,17 +509,37 @@ class DemixingTuner:
rss=(np.mean(self.RSS[:,index])/np.mean(self.RSS[:,0]))**2
self.AIC[index]=rss*self.N*self.N+np.sum(chosen_dirs)*self.N
def report(self):
indx=np.argsort(self.AIC[1:])+1
chosen_dirs=np.zeros(self.K)
for index in indx[:2]:
for index in indx[:self.inclusiveness]:
chosen_dirs+=self.scalar_to_kvec(index)
# remove -ve elevation dirs, again, to be safe
chosen_dirs[self.elevation<=self.low_elevation_limit]=0
chosen_patches=itertools.compress(self.ateam_names,chosen_dirs)
self.chosen_patches=list(itertools.compress(self.ateam_names,chosen_dirs))
def calc_bandwidth_to_use(self):
channels=[2**x for x in range(int(np.ceil(np.log(self.nchan)/np.log(2)))+1)]
for ci in range(self.Nf):
self.RSS_freq[ci,:]/=self.RSS[ci,0]
rss=np.mean(self.RSS_freq,axis=0)**2
self.AIC_freq=rss*self.N*self.N+self.nchan/channels*self.N*len(self.chosen_patches)
index=np.argmin(self.AIC_freq)
self.bandwidth_to_use=self.bandwidth*channels[index]/self.nchan
def calc_maxiter_to_use(self):
for ci in range(self.Nf):
self.RSS_iter[ci,:]/=self.RSS[ci,0]
rss=np.mean(self.RSS_iter,axis=0)**2
self.AIC_iter=rss*self.N*self.N+self.N*len(self.chosen_patches)*np.array(self.maxiter)
index=np.argmin(self.AIC_iter)
self.maxiter_to_use=self.maxiter[index]
def report(self):
print('***************** Result:')
for patch in chosen_patches:
print(patch)
print(f'Bandwidth {self.bandwidth_to_use}')
print(f'Maxiter {self.maxiter_to_use}')
print(self.chosen_patches)
print('*************************')
def cleanup(self):
@@ -386,6 +561,12 @@ class DemixingTuner:
self.iter_over_directions()
self.cleanup()
self.calc_AIC()
self.iter_over_frequency()
self.calc_bandwidth_to_use()
self.cleanup()
self.iter_over_maxiter()
self.calc_maxiter_to_use()
self.cleanup()
self.report()
Loading