#! /usr/bin/env python
""" Script to tune demixing parameters by sampling data """
import math,sys,uuid
import subprocess as sb
import time,glob
import numpy as np
import argparse
import itertools
import casacore.tables as ctab
from casacore.measures import measures
from casacore.quanta import quantity
from multiprocessing import Pool
from multiprocessing import shared_memory
import astropy.time as atime
import lsmtool
import logging

# default DP3
default_DP3='DP3'
# LINC script to download target sky
default_LINC_GET_TARGET='download_skymodel_target.py'


def globalize(func):
    # give the wrapped function a unique module-level name so that nested
    # functions can be pickled and used with multiprocessing.Pool
    def result(*args, **kwargs):
        return func(*args, **kwargs)
    result.__name__ = result.__qualname__ = uuid.uuid4().hex
    setattr(sys.modules[result.__module__], result.__name__, result)
    return result


class DemixingTuner:
    DP3=default_DP3
    LINC_GET_TARGET=default_LINC_GET_TARGET
    INF=1e15

    def __init__(self,mslist,skymodel,target_skymodel,timesec,Nf,sblist,Nparallel=4,path_DP3=None,path_LINC=None,patch_list=[]):
        self.mslist=glob.glob(mslist)
        self.Nf=Nf
        if len(sblist)>0:
            # extract subbands given by input
            new_mslist=self.select_subbands(self.mslist,sblist)
            self.mslist=new_mslist
            self.Nf=len(self.mslist)
        self.timesec=timesec
        self.skymodel=skymodel
        self.target_skymodel_in=target_skymodel
        # max iterations to consider
        self.maxiter=[20, 30, 40, 50]
        # robust degrees of freedom to consider
        self.dof=[5, 20, 200]
        if path_DP3:
            self.DP3=path_DP3
        if path_LINC:
            self.LINC_GET_TARGET=path_LINC
        self.extracted_ms=[]
        # stations
        self.N=0
        # outlier directions
        self.K=0
        # baselines
        self.B=0
        # frequencies
        self.freqlist=None
        # bandwidth (of each subband)
        self.bandwidth=0
        # best bandwidth (freq resolution) to use
        self.bandwidth_to_use=0
        # best value for maximum iterations to use
        self.maxiter_to_use=0
        # best value for robust DOF
        self.robust_dof_to_use=0
        # number of channels (of each subband)
        self.nchan=0
        # target direction
        self.ra0=0
        self.dec0=0
        # measures object
        self.mydm=None
        # outlier directions (arrays)
        self.ateam_ra=None
        self.ateam_dec=None
        # outlier separations, azimuths, elevations
        self.separation=None
        self.azimuth=None
        self.elevation=None
        # lowest elevation (degrees): ignore any outlier source below this
        self.low_elevation_limit=5
        # lowest separation (degrees): forcefully include sources closer than this
        self.low_separation_limit=20
        # outlier (Patch) names
        self.ateam_names=None
        self.chosen_patches=None
        self.given_patches=patch_list
        # 1: only select the minimum number of sources, increase to include more
        # 2: for HBA, 3: for LBA; this is set automatically
        self.inclusiveness=2
        self.target_skymodel='./target.sky.txt'
        self.final_skymodel='./combined.sky.txt'
        self.parset_demix='tunedemix.parset'
        self.Nparallel=Nparallel
        # RSS : sqrt(residual sum of squares)
        self.RSS=None
        self.AIC=None
        self.AIC_probs=None
        self.RSS_freq=None
        self.AIC_freq=None
        self.RSS_iter=None
        self.AIC_iter=None
        self.RSS_dof=None
        self.AIC_dof=None

    def __del__(self):
        for ms in self.extracted_ms:
            sb.run('rm -rf '+ms,shell=True)

    def select_subbands(self,mslist,subband_ids):
        sbidx=['SB{:03d}'.format(int(x)) for x in subband_ids]
        selected_ms=[]
        for ms in mslist:
            for sbname in sbidx:
                if sbname in ms:
                    selected_ms.append(ms)
        return selected_ms

    def extract_dataset(self):
        self.mslist.sort()
        msname=self.mslist[0]
        # parset for extracting and averaging
        parset_sample='extract_sample.parset'
        parset=open(parset_sample,'w+')
        # in order to overcome storage manager issues, make a
        # copy of the data, keeping only antenna 1
        parset.write('steps=[fil]\n'
            +'fil.type=filter\n'
            +'fil.baseline=\"0,0 &&\"\n'
            +'fil.remove=True\n'
            +'msin.datacolumn=DATA\n')
        parset.close()
        # remove old files
        sb.run('rm -rf L_SB0.MS',shell=True)
        # now process the first selected MS
        MS='L_SB0.MS'
        proc1=sb.Popen(self.DP3+' '+parset_sample+' msin='+msname+' msout='+MS, shell=True)
        proc1.wait()
        msname=MS
        tt=ctab.table(msname,readonly=True)
        starttime=tt[0]['TIME']
        endtime=tt[tt.nrows()-1]['TIME']
        self.N=tt.nrows()
        tt.close()
        Nms=len(self.mslist)
        # need to have at least Nf MS
        assert(Nms>=self.Nf)
        parset=open(parset_sample,'w+')
        # sample a random time interval of length timesec
        t_start=np.random.rand()*(endtime-starttime)+starttime
        t_end=t_start+self.timesec
        t0=atime.Time(t_start/(24*60*60),format='mjd',scale='utc')
        dt=t0.to_datetime()
        str_tstart=str(dt.year)+'/'+str(dt.month)+'/'+str(dt.day)+'/'+str(dt.hour)+':'+str(dt.minute)+':'+str(dt.second)
        t0=atime.Time(t_end/(24*60*60),format='mjd',scale='utc')
        dt=t0.to_datetime()
        str_tend=str(dt.year)+'/'+str(dt.month)+'/'+str(dt.day)+'/'+str(dt.hour)+':'+str(dt.minute)+':'+str(dt.second)
        parset.write('steps=[fil,avg]\n'
            +'fil.type=filter\n'
            +'fil.baseline=[CR]S*&\n'
            +'fil.remove=True\n'
            +'avg.type=average\n'
            +'avg.timestep=1\n'
            +'avg.freqstep=1\n'
            +'msin.datacolumn=DATA\n'
            +'msin.starttime='+str_tstart+'\n'
            +'msin.endtime='+str_tend+'\n')
        parset.close()
        # process a subset of MS from mslist: always keep the first and the last,
        # and randomly pick the rest
        submslist=list()
        submslist.append(self.mslist[0])
        aa=list(np.random.choice(np.arange(1,Nms-1),self.Nf-2,replace=False))
        aa.sort()
        for ms_id in aa:
            submslist.append(self.mslist[ms_id])
        submslist.append(self.mslist[-1])
        # remove old files
        sb.run('rm -rf L_SB*.MS',shell=True)
        # now process each of the selected MS
        self.extracted_ms=list()
        for ci in range(self.Nf):
            MS='L_SB'+str(ci)+'.MS'
            proc1=sb.Popen(self.DP3+' '+parset_sample+' msin='+submslist[ci]+' msout='+MS, shell=True)
            proc1.wait()
            self.extracted_ms.append(MS)
            # also update the weights
            self.reset_weights(MS)
        # get metadata from the extracted MS
        self.freqlist=np.zeros(self.Nf)
        ci=0
        for ms in self.extracted_ms:
            tf=ctab.table(ms+'/SPECTRAL_WINDOW',readonly=True)
            ch0=tf.getcol('CHAN_FREQ')
            reffreq=tf.getcol('REF_FREQUENCY')
            self.freqlist[ci]=ch0[0,0]
            tf.close()
            ci+=1
        # update inclusiveness based on the observing band (LBA below 90 MHz)
        if np.min(self.freqlist)<90e6:
            self.inclusiveness=3
        else:
            self.inclusiveness=2
        tf=ctab.table(self.extracted_ms[0]+'/SPECTRAL_WINDOW',readonly=True)
        self.bandwidth=tf.getcol('TOTAL_BANDWIDTH')[0]
        self.nchan=tf.getcol('NUM_CHAN')[0]
        tf.close()
        # get antennas
        ant=ctab.table(self.extracted_ms[0]+'/ANTENNA',readonly=True)
        self.N=ant.nrows()
        # baselines
        self.B=self.N*(self.N-1)//2
        # antenna location
        xyz=ant.getcol('POSITION')
        ant.close()
        # get target coordinates
        field=ctab.table(self.extracted_ms[0]+'/FIELD',readonly=True)
        phase_dir=field.getcol('PHASE_DIR')
        self.ra0=phase_dir[0][0][0]
        self.dec0=phase_dir[0][0][1]
        field.close()
        # get integration time
        tt=ctab.table(self.extracted_ms[0],readonly=True)
        tt1=tt.getcol('INTERVAL')
        Tdelta=tt[0]['INTERVAL']
        t0=tt[0]['TIME']
        tt.close()
        Tslots=math.ceil(self.timesec/Tdelta)
        # set up the measures frame: observatory position and epoch (UTC)
        self.mydm=measures()
        mypos=self.mydm.position('ITRF',str(xyz[0][0])+'m',str(xyz[0][1])+'m',str(xyz[0][2])+'m')
        mytime=self.mydm.epoch('UTC',str(t0)+'s')
        self.mydm.doframe(mytime)
        self.mydm.doframe(mypos)

    def read_skymodel(self):
        lsm=lsmtool.load(self.skymodel)
        skymodel_patches=lsm.getPatchNames()
        # only select patches already given by the user
        if len(self.given_patches)>0:
            patches=list(set(self.given_patches).intersection(skymodel_patches))
        else:
            patches=skymodel_patches
        assert(len(patches)>0)
        self.ateam_ra=np.zeros(len(patches))
        self.ateam_dec=np.zeros(len(patches))
        self.K=len(patches)
        self.ateam_names=patches
        for ci in range(self.K):
            patch=patches[ci]
            lsm=lsmtool.load(self.skymodel)
            lsm.select('Patch=='+patch)
            t=lsm.table
            self.ateam_ra[ci]=np.array(t['Ra'])[0]
            self.ateam_dec[ci]=np.array(t['Dec'])[0]
        self.separation=np.zeros(self.K)
        self.azimuth=np.zeros(self.K)
        self.elevation=np.zeros(self.K)
        ra0_q=quantity(self.ra0,'rad')
        dec0_q=quantity(self.dec0,'rad')
        target=self.mydm.direction('j2000',ra0_q,dec0_q)
        for ci in range(self.K):
            mra_q=quantity(self.ateam_ra[ci],'deg')
            mdec_q=quantity(self.ateam_dec[ci],'deg')
            cluster_dir=self.mydm.direction('j2000',mra_q,mdec_q)
            sep=self.mydm.separation(target,cluster_dir)
            self.separation[ci]=sep.get_value()
            azel=self.mydm.measure(cluster_dir,'AZEL')
            self.azimuth[ci]=azel['m0']['value']/math.pi*180
            self.elevation[ci]=azel['m1']['value']/math.pi*180

    def get_target_skymodel(self,overwrite=True):
        # if a target sky model is given as input, use it
        if self.target_skymodel_in:
            self.target_skymodel=self.target_skymodel_in
        else:
            if overwrite:
                sb.run('rm -rvf '+self.target_skymodel,shell=True)
            sb.run('python3 '+self.LINC_GET_TARGET+' --Radius 5 --targetname CENTER '
                +self.extracted_ms[0]+' '+self.target_skymodel,shell=True)

    # convert an integer to its binary bits, return an array of size K
    def scalar_to_kvec(self,n):
        ll=[1 if digit=='1' else 0 for digit in bin(n)[2:]]
        a=np.zeros(self.K)
        a[-len(ll):]=ll
        return a

    # merge the target sky model with the outlier (A-team) sky model
    def create_final_skymodel(self):
        lsm0=lsmtool.load(self.target_skymodel)
        lsm=lsmtool.load(self.skymodel)
        lsm0.concatenate(lsm,keep='all')
        lsm0.write(self.final_skymodel,clobber=True)

    # create a demixing parset for the outlier combination encoded by 'index'
    def create_parset_dir(self,index):
        chosen_dirs=self.scalar_to_kvec(index)
        chosen_patches=itertools.compress(self.ateam_names,chosen_dirs)
        parset_name=str(index)+'_'+self.parset_demix
        bbsdem=open(parset_name,'w+')
        bbsdem.write('steps=[demix]\n'
            +'demix.type=demixer\n'
            +'demix.baseline=[CR]S*&\n'
            +'demix.ignoretarget=False\n'
            +'demix.targetsource=\"CENTER\"\n'
            +'demix.demixtimeresolution=10\n'
            +'demix.demixfreqresolution=50kHz\n'
            +'demix.uselbfgssolver=true\n'
            +'demix.lbfgs.historysize=10\n'
            +'demix.lbfgs.robustdof=200\n'
            +'demix.freqstep=1\n'
            +'demix.timestep=1\n'
            +'demix.instrumentmodel=instrument_'+str(index)+'\n'
            +'demix.skymodel='+self.final_skymodel+'\n'
            )
        bbsdem.write('demix.subtractsources=[')
        firstcomma=False
        if sum(chosen_dirs)>0:
            for patch in chosen_patches:
                if not firstcomma:
                    firstcomma=True
                else:
                    bbsdem.write(',')
                bbsdem.write('\"'+patch+'\"')
        bbsdem.write(']\n')
        bbsdem.close()
        return parset_name

    # create a demixing parset using the frequency resolution given by channels[channel_index]
    def create_parset_freq(self,channel_index,channels):
        bandwidth=self.bandwidth/self.nchan*channels[channel_index]/1e3
        parset_name=str(channel_index)+'_'+self.parset_demix
        bbsdem=open(parset_name,'w+')
        bbsdem.write('steps=[demix]\n'
            +'demix.type=demixer\n'
            +'demix.baseline=[CR]S*&\n'
            +'demix.ignoretarget=False\n'
            +'demix.targetsource=\"CENTER\"\n'
            +'demix.demixtimeresolution=10\n'
            +'demix.demixfreqresolution='+str(bandwidth)+'kHz\n'
            +'demix.uselbfgssolver=true\n'
            +'demix.lbfgs.historysize=10\n'
            +'demix.lbfgs.robustdof=200\n'
            +'demix.freqstep=1\n'
            +'demix.timestep=1\n'
            +'demix.instrumentmodel=instrument_'+str(channel_index)+'\n'
            +'demix.skymodel='+self.final_skymodel+'\n'
            )
        bbsdem.write('demix.subtractsources=[')
        firstcomma=False
        for patch in self.chosen_patches:
            if not firstcomma:
                firstcomma=True
            else:
                bbsdem.write(',')
            bbsdem.write('\"'+patch+'\"')
        bbsdem.write(']\n')
        bbsdem.close()
        return parset_name

    # create a demixing parset using the iteration limit given by maxiter[iter_index]
    def create_parset_iter(self,iter_index):
        parset_name=str(iter_index)+'_'+self.parset_demix
        bbsdem=open(parset_name,'w+')
        bbsdem.write('steps=[demix]\n'
            +'demix.type=demixer\n'
            +'demix.baseline=[CR]S*&\n'
            +'demix.ignoretarget=False\n'
            +'demix.targetsource=\"CENTER\"\n'
            +'demix.demixtimeresolution=10\n'
            +'demix.demixfreqresolution='+str(self.bandwidth_to_use)+'\n'
            +'demix.uselbfgssolver=true\n'
            +'demix.lbfgs.historysize=10\n'
            +'demix.lbfgs.robustdof='+str(self.robust_dof_to_use)+'\n'
            +'demix.maxiter='+str(self.maxiter[iter_index])+'\n'
            +'demix.freqstep=1\n'
            +'demix.timestep=1\n'
            +'demix.instrumentmodel=instrument_'+str(iter_index)+'\n'
            +'demix.skymodel='+self.final_skymodel+'\n'
            )
        bbsdem.write('demix.subtractsources=[')
        firstcomma=False
        for patch in self.chosen_patches:
            if not firstcomma:
                firstcomma=True
            else:
                bbsdem.write(',')
            bbsdem.write('\"'+patch+'\"')
        bbsdem.write(']\n')
        bbsdem.close()
        return parset_name

    # create a demixing parset using the robust DOF given by dof[dof_index]
    def create_parset_dof(self,dof_index):
        parset_name=str(dof_index)+'_'+self.parset_demix
        bbsdem=open(parset_name,'w+')
        bbsdem.write('steps=[demix]\n'
            +'demix.type=demixer\n'
            +'demix.baseline=[CR]S*&\n'
            +'demix.ignoretarget=False\n'
            +'demix.targetsource=\"CENTER\"\n'
            +'demix.demixtimeresolution=10\n'
            +'demix.demixfreqresolution='+str(self.bandwidth_to_use)+'\n'
            +'demix.uselbfgssolver=true\n'
            +'demix.lbfgs.historysize=10\n'
            +'demix.lbfgs.robustdof='+str(self.dof[dof_index])+'\n'
            +'demix.maxiter=50\n'
            +'demix.freqstep=1\n'
            +'demix.timestep=1\n'
            +'demix.instrumentmodel=instrument_'+str(dof_index)+'\n'
            +'demix.skymodel='+self.final_skymodel+'\n'
            )
        bbsdem.write('demix.subtractsources=[')
        firstcomma=False
        for patch in self.chosen_patches:
            if not firstcomma:
                firstcomma=True
            else:
                bbsdem.write(',')
            bbsdem.write('\"'+patch+'\"')
        bbsdem.write(']\n')
        bbsdem.close()
        return parset_name

    def iter_over_directions(self):
        self.RSS=np.zeros((self.Nf,2**self.K))
        shmRSS=shared_memory.SharedMemory(create=True,size=self.RSS.nbytes)
        RSSsh=np.ndarray(self.RSS.shape,dtype=self.RSS.dtype,buffer=shmRSS.buf)

        @globalize
        def process_scenario(index):
            # if any of the sources in this combination is at or below the
            # elevation limit, skip it and assign a large error
            chosen_dirs=self.scalar_to_kvec(index)
            low_elevations=any(ele <= self.low_elevation_limit for ele in itertools.compress(self.elevation,chosen_dirs))
            if not low_elevations:
                parset_name=self.create_parset_dir(index)
                # copy data
                ci=0
                for ms in self.extracted_ms:
                    out_ms='out_'+str(index)+'_'+ms
                    sb.run('rm -rf '+out_ms,shell=True)
                    proc1=sb.Popen('rsync -a '+ms+'/ '+out_ms,shell=True)
                    proc1.wait()
                    demixout_ms='residual_'+out_ms
                    sb.run('rm -rf '+demixout_ms,shell=True)
                    proc1=sb.Popen(self.DP3+' '+parset_name+' msin='+out_ms+' msout='+demixout_ms,shell=True)
                    proc1.wait()
                    res_sigmaI,res_sigmaQ,res_sigmaU,res_sigmaV=self.get_noise_var(demixout_ms)
                    print('%s %d %e %e %e %e'%(out_ms,index,res_sigmaI,res_sigmaQ,res_sigmaU,res_sigmaV))
                    RSSsh[ci,index]=res_sigmaI
                    ci+=1
            else:
                RSSsh[:,index]=self.INF

        pool=Pool(self.Nparallel)
        pool.map(process_scenario,range(1,2**self.K))
        pool.close()
        pool.join()
        self.RSS[:]=RSSsh[:]
        shmRSS.close()
        shmRSS.unlink()
        # column 0: noise of the original (undemixed) data
        ci=0
        for ms in self.extracted_ms:
            sigmaI,sigmaQ,sigmaU,sigmaV=self.get_noise_var(ms)
            self.RSS[ci,0]=sigmaI
            ci+=1

    def iter_over_frequency(self):
        assert(len(self.chosen_patches)>0)
        # channels to consider (powers of two up to nchan)
        channels=[2**x for x in range(int(np.ceil(np.log(self.nchan)/np.log(2)))+1)]
        self.RSS_freq=np.zeros((self.Nf,len(channels)))
        shmRSS=shared_memory.SharedMemory(create=True,size=self.RSS_freq.nbytes)
        RSSsh=np.ndarray(self.RSS_freq.shape,dtype=self.RSS_freq.dtype,buffer=shmRSS.buf)

        @globalize
        def process_scenario_freq(channel_index):
            parset_name=self.create_parset_freq(channel_index,channels)
            # copy data
            ci=0
            for ms in self.extracted_ms:
                out_ms='out_'+str(channel_index)+'_'+ms
                sb.run('rm -rf '+out_ms,shell=True)
                proc1=sb.Popen('rsync -a '+ms+'/ '+out_ms,shell=True)
                proc1.wait()
                demixout_ms='residual_'+out_ms
                sb.run('rm -rf '+demixout_ms,shell=True)
                proc1=sb.Popen(self.DP3+' '+parset_name+' msin='+out_ms+' msout='+demixout_ms,shell=True)
                proc1.wait()
                res_sigmaI,res_sigmaQ,res_sigmaU,res_sigmaV=self.get_noise_var(demixout_ms)
                print('%s %d %e %e %e %e'%(out_ms,channel_index,res_sigmaI,res_sigmaQ,res_sigmaU,res_sigmaV))
                RSSsh[ci,channel_index]=res_sigmaI
                ci+=1

        pool=Pool(self.Nparallel)
        pool.map(process_scenario_freq,range(len(channels)))
        pool.close()
        pool.join()
        self.RSS_freq[:]=RSSsh[:]
        shmRSS.close()
        shmRSS.unlink()

    def iter_over_maxiter(self):
        assert(len(self.chosen_patches)>0)
        assert(len(self.maxiter)>0)
        self.RSS_iter=np.zeros((self.Nf,len(self.maxiter)))
        shmRSS=shared_memory.SharedMemory(create=True,size=self.RSS_iter.nbytes)
        RSSsh=np.ndarray(self.RSS_iter.shape,dtype=self.RSS_iter.dtype,buffer=shmRSS.buf)

        @globalize
        def process_scenario_iter(iter_index):
            parset_name=self.create_parset_iter(iter_index)
            # copy data
            ci=0
            for ms in self.extracted_ms:
                out_ms='out_'+str(iter_index)+'_'+ms
                sb.run('rm -rf '+out_ms,shell=True)
                proc1=sb.Popen('rsync -a '+ms+'/ '+out_ms,shell=True)
                proc1.wait()
                demixout_ms='residual_'+out_ms
                sb.run('rm -rf '+demixout_ms,shell=True)
                proc1=sb.Popen(self.DP3+' '+parset_name+' msin='+out_ms+' msout='+demixout_ms,shell=True)
                proc1.wait()
                res_sigmaI,res_sigmaQ,res_sigmaU,res_sigmaV=self.get_noise_var(demixout_ms)
                print('%s %d %e %e %e %e'%(out_ms,iter_index,res_sigmaI,res_sigmaQ,res_sigmaU,res_sigmaV))
                RSSsh[ci,iter_index]=res_sigmaI
                ci+=1

        pool=Pool(self.Nparallel)
        pool.map(process_scenario_iter,range(len(self.maxiter)))
        pool.close()
        pool.join()
        self.RSS_iter[:]=RSSsh[:]
        shmRSS.close()
        shmRSS.unlink()

    def iter_over_robust_dof(self):
        assert(len(self.chosen_patches)>0)
        assert(len(self.dof)>0)
        self.RSS_dof=np.zeros((self.Nf,len(self.dof)))
        shmRSS=shared_memory.SharedMemory(create=True,size=self.RSS_dof.nbytes)
        RSSsh=np.ndarray(self.RSS_dof.shape,dtype=self.RSS_dof.dtype,buffer=shmRSS.buf)

        @globalize
        def process_scenario_dof(iter_index):
            parset_name=self.create_parset_dof(iter_index)
            # copy data
            ci=0
            for ms in self.extracted_ms:
                out_ms='out_'+str(iter_index)+'_'+ms
                sb.run('rm -rf '+out_ms,shell=True)
                proc1=sb.Popen('rsync -a '+ms+'/ '+out_ms,shell=True)
                proc1.wait()
                demixout_ms='residual_'+out_ms
                sb.run('rm -rf '+demixout_ms,shell=True)
                proc1=sb.Popen(self.DP3+' '+parset_name+' msin='+out_ms+' msout='+demixout_ms,shell=True)
                proc1.wait()
                res_sigmaI,res_sigmaQ,res_sigmaU,res_sigmaV=self.get_noise_var(demixout_ms)
                print('%s %d %e %e %e %e'%(out_ms,iter_index,res_sigmaI,res_sigmaQ,res_sigmaU,res_sigmaV))
                RSSsh[ci,iter_index]=res_sigmaI
                ci+=1

        pool=Pool(self.Nparallel)
        pool.map(process_scenario_dof,range(len(self.dof)))
        pool.close()
        pool.join()
        self.RSS_dof[:]=RSSsh[:]
        shmRSS.close()
        shmRSS.unlink()

    # apply the weights to the data and rescale the weights accordingly
    def reset_weights(self,msname):
        tt=ctab.table(msname,readonly=False)
        data=tt.getcol('DATA')
        if 'WEIGHT_SPECTRUM' in tt.colnames():
            weight_col='WEIGHT_SPECTRUM'
        else:
            weight_col='IMAGING_WEIGHT'
        weight=tt.getcol(weight_col)
        # weight ~sigma^2, std(weight) ~ sigma^2, so take sqrt()
        weight_std=np.sqrt(weight.std())
        data *=weight_std
        weight /=weight_std
        tt.putcol('DATA',data)
        # write back to the same weight column that was read
        tt.putcol(weight_col,weight)
        tt.close()

    # extract noise info (per Stokes parameter) from the cross correlations
    def get_noise_var(self,msname):
        tt=ctab.table(msname,readonly=True)
        t1=tt.query('ANTENNA1 != ANTENNA2',columns='DATA,FLAG')
        data0=t1.getcol('DATA')
        flag=t1.getcol('FLAG')
        # also flag NaN data so that they are zeroed below
        flag[np.isnan(data0)]=1
        data=data0*(1-flag)
        tt.close()
        # set any remaining nans to 0
        data[np.isnan(data)]=0.
        # form IQUV
        sI=(data[:,:,0]+data[:,:,3])*0.5
        sQ=(data[:,:,0]-data[:,:,3])*0.5
        sU=(data[:,:,1]-data[:,:,2])*0.5
        sV=(data[:,:,1]+data[:,:,2])*0.5
        return sI.std(),sQ.std(),sU.std(),sV.std()

    # calculate AIC = (residual RSS / data RSS)^2 * N_stat^2 + (directions subtracted) * N_stat
    def calc_AIC(self):
        self.AIC=np.zeros(2**self.K)
        for index in range(2**self.K):
            chosen_dirs=self.scalar_to_kvec(index)
            rss=(np.mean(self.RSS[:,index])/np.mean(self.RSS[:,0]))**2
            self.AIC[index]=rss*self.N*self.N+np.sum(chosen_dirs)*self.N
        indx=np.argsort(self.AIC[1:])+1
        probs=np.exp(-self.AIC/self.AIC[indx[0]])/np.sum(np.exp(-self.AIC/self.AIC[indx[0]]))
        self.AIC_probs=np.zeros(self.K)
        for ci in range(2**self.K):
            self.AIC_probs+=probs[ci]*self.scalar_to_kvec(ci)
        AIC_probs=self.AIC_probs.copy()
        chosen_dirs=np.zeros(self.K)
        for index in range(self.inclusiveness):
            max_id=np.argmax(AIC_probs)
            chosen_dirs[max_id]=1
            AIC_probs[max_id]=0
        # remove low-elevation directions again, to be safe
        chosen_dirs[self.elevation<=self.low_elevation_limit]=0
        # also add sources that are close to the target
        close_sources=[index for index in range(self.K) if self.separation[index]<=self.low_separation_limit]
        chosen_dirs[self.separation<=self.low_separation_limit]=1
        chosen_patches=list(itertools.compress(self.ateam_names,chosen_dirs))
        chosen_separation=list(itertools.compress(self.separation,chosen_dirs))
        # sort the chosen patches by separation from the target
        sorted_patches=[patch for _,patch in sorted(zip(chosen_separation,chosen_patches))]
        if len(chosen_patches)>self.inclusiveness and len(close_sources)==0:
            self.chosen_patches=sorted_patches[:-1]
        else:
            self.chosen_patches=sorted_patches

    def calc_bandwidth_to_use(self):
        channels=[2**x for x in range(int(np.ceil(np.log(self.nchan)/np.log(2)))+1)]
        for ci in range(self.Nf):
            self.RSS_freq[ci,:]/=self.RSS[ci,0]
        rss=np.mean(self.RSS_freq,axis=0)**2
        self.AIC_freq=rss*self.N*self.N+self.nchan/np.array(channels)*self.N*len(self.chosen_patches)
        index=np.argmin(self.AIC_freq)
        self.bandwidth_to_use=self.bandwidth*channels[index]/self.nchan

    def calc_maxiter_to_use(self):
        for ci in range(self.Nf):
            self.RSS_iter[ci,:]/=self.RSS[ci,0]
        rss=np.mean(self.RSS_iter,axis=0)**2
        self.AIC_iter=rss*self.N*self.N+self.N*len(self.chosen_patches)*np.log(np.array(self.maxiter))
        index=np.argmin(self.AIC_iter)
        self.maxiter_to_use=self.maxiter[index]

    def calc_dof_to_use(self):
        for ci in range(self.Nf):
            self.RSS_dof[ci,:]/=self.RSS[ci,0]
        rss=np.mean(self.RSS_dof,axis=0)**2
        self.AIC_dof=rss*self.N*self.N+np.array(self.dof)*self.N
        index=np.argmin(self.AIC_dof)
        self.robust_dof_to_use=self.dof[index]

    def report(self):
        info={}
        print('*************************')
        print('Patches')
        print(self.ateam_names)
        info['patches']=self.ateam_names
        print('Probabilities of being selected')
        print(self.AIC_probs)
        info['probabilities']=self.AIC_probs
        print('Separation')
        print(self.separation)
        info['separation']=self.separation
        print('Elevation')
        print(self.elevation)
        info['elevation']=self.elevation
        print('Direction selection')
        print(self.RSS)
        info['RSS_dir']=self.RSS
        print(self.AIC)
        info['AIC_dir']=self.AIC
        print('Frequency resolution selection')
        print(self.RSS_freq)
        info['RSS_freq']=self.RSS_freq
        print(self.AIC_freq)
        info['AIC_freq']=self.AIC_freq
        print('Robust DOF selection')
        print(self.RSS_dof)
        info['RSS_dof']=self.RSS_dof
        print(self.AIC_dof)
        info['AIC_dof']=self.AIC_dof
        print('Maxiter selection')
        print(self.RSS_iter)
        info['RSS_iter']=self.RSS_iter
        print(self.AIC_iter)
        info['AIC_iter']=self.AIC_iter
        print('Expected reduction in noise is %f %%'%((1-np.mean(self.RSS_iter[:,np.argmin(self.AIC_iter)]))*100.0))
        print('*************************')
        return info

    def cleanup(self):
        for ms in self.extracted_ms:
            out_ms='out_*_'+ms
            sb.run('rm -rf '+out_ms,shell=True)
            demixout_ms='residual_'+out_ms
            sb.run('rm -rf '+demixout_ms,shell=True)
        parset_name='*_'+self.parset_demix
        sb.run('rm -rf '+parset_name,shell=True)
        sb.run('rm -rf instrument_*',shell=True)

    def run(self):
        # Note: the order in which the following steps are run is important
        self.extract_dataset()
        self.read_skymodel()
        self.get_target_skymodel(overwrite=True)
        self.create_final_skymodel()
        self.iter_over_directions()
        self.cleanup()
        self.calc_AIC()
        self.iter_over_frequency()
        self.calc_bandwidth_to_use()
        self.cleanup()
        self.iter_over_robust_dof()
        self.calc_dof_to_use()
        self.cleanup()
        self.iter_over_maxiter()
        self.calc_maxiter_to_use()
        self.cleanup()
        info=self.report()
        return info


def main(args):
    tuner=DemixingTuner(args.MS,args.sky_model,args.target_sky_model,args.time_interval,
        args.subbands,args.subband_list,args.parallel_jobs,path_DP3=args.DP3,
        path_LINC=args.LINC,patch_list=args.patches)
    info=tuner.run()
    result={}
    result['maxiter']=tuner.maxiter_to_use
    result['demixfreqresolution']=tuner.bandwidth_to_use
    result['subtractsources']=tuner.chosen_patches
    result['lbfgs.robustdof']=tuner.robust_dof_to_use
    result['demixtimeresolution']=10
    result['lbfgs.historysize']=10
    return result,info


if __name__=='__main__':
    parser=argparse.ArgumentParser(
        description='Tune demixing parameters',
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    parser.add_argument('--MS',type=str,metavar='\'*.MS\'',
        help='absolute path of MS pattern to use')
    parser.add_argument('--sky_model',type=str,metavar='s',
        help='A-Team sky model (text)')
    parser.add_argument('--target_sky_model',type=str,default=None,metavar='st',
        help='target sky model (text)')
    parser.add_argument('--subbands',default=3,type=int,metavar='f',
        help='number of randomly selected subbands to process (minimum 3)')
    parser.add_argument('--subband_list',nargs='+',default=[],type=int,metavar='sb',
        help='list of subband numbers [0,1,..] to process (overrides the number of subbands)')
    parser.add_argument('--time_interval',type=float,default=30.,metavar='t',
        help='total time interval to sample in seconds')
    parser.add_argument('--parallel_jobs',type=int,default=4,metavar='p',
        help='number of parallel jobs')
    parser.add_argument('--patches',nargs='+',default=[],metavar='Patch1',
        help='names of patches to subtract if found in the sky model; if not given, all patches are considered')
    parser.add_argument('--DP3',type=str,metavar='DP3',default=default_DP3,
        help='DP3 command')
    parser.add_argument('--LINC',type=str,metavar='LINC',default=default_LINC_GET_TARGET,
        help='path to download_skymodel_target.py script (not used if a target sky model is given)')
    args=parser.parse_args()

    format_stream = logging.Formatter("%(asctime)s\033[1m %(levelname)s:\033[0m %(message)s","%Y-%m-%d %H:%M:%S")
    format_file = logging.Formatter("%(asctime)s %(levelname)s: %(message)s","%Y-%m-%d %H:%M:%S")
    logging.root.setLevel(logging.INFO)
    log = logging.StreamHandler()
    log.setFormatter(format_stream)
    logging.root.addHandler(log)

    if args.MS and args.sky_model:
        res,info=main(args)
        print(res)
    else:
        parser.print_help()
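
# Example invocation (a minimal sketch; the script filename, data paths and
# patch names below are hypothetical and only illustrate the interface):
#
#   python3 demix_tuner.py --MS '/data/L123456_SB*.MS' \
#       --sky_model Ateam.sky.txt --subbands 3 --time_interval 30 \
#       --parallel_jobs 4 --patches CasA CygA
#
# The printed result dictionary maps onto DP3 demixer settings, e.g.
# demix.maxiter, demix.demixfreqresolution, demix.subtractsources and
# demix.lbfgs.robustdof.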