diff --git a/scripts/tune_demixing_parameters.py b/scripts/tune_demixing_parameters.py new file mode 100755 index 0000000000000000000000000000000000000000..268593b85651a4c57c51d84691d12cc83e1584a3 --- /dev/null +++ b/scripts/tune_demixing_parameters.py @@ -0,0 +1,731 @@ +#! /usr/bin/env python +""" +Script to tune demixing parameters by sampling data +""" + +import math,sys,uuid +import subprocess as sb +import time,glob +import numpy as np +import argparse +import itertools +import casacore.tables as ctab +from casacore.measures import measures +from casacore.quanta import quantity +from multiprocessing import Pool +from multiprocessing import shared_memory +import astropy.time as atime +import lsmtool +import logging + + + +# default DP3 +default_DP3='DP3' +# LINC script to download target sky +default_LINC_GET_TARGET='download_skymodel_target.py' + +def globalize(func): + def result(*args, **kwargs): + return func(*args, **kwargs) + result.__name__ = result.__qualname__ = uuid.uuid4().hex + setattr(sys.modules[result.__module__], result.__name__, result) + return result + +class DemixingTuner: + DP3=default_DP3 + LINC_GET_TARGET=default_LINC_GET_TARGET + def __init__(self,mslist,skymodel,timesec,Nf=3,Nparallel=4,path_DP3=None,path_LINC=None, patch_list=[]): + self.mslist=glob.glob(mslist) + self.timesec=timesec + self.Nf=Nf + self.skymodel=skymodel + # max iterations to consider + self.maxiter=[20, 30, 40, 50] + # robust degrees of freedom to consider + self.dof=[5, 20, 200] + if path_DP3: + self.DP3=path_DP3 + if path_LINC: + self.LINC_GET_TARGET=path_LINC + self.extracted_ms=[] + # stations + self.N=0 + # outlier directions + self.K=0 + # baselines + self.B=0 + # frequencies + self.freqlist=None + # bandwidth (of each subband) + self.bandwidth=0 + # best bandwidth (freq resolution) to use + self.bandwidth_to_use=0 + # best value for maximum iterations to use + self.maxiter_to_use=0 + # best value for robust DOF + self.robust_dof_to_use=0 + # number of channels (of each subband) + self.nchan=0 + # target direction + self.ra0=0 + self.dec0=0 + # measures object + self.mydm=None + # outlier directions (arrays) + self.ateam_ra=None + self.ateam_dec=None + # outlier separations,azimuth,elevations + self.separation=None + self.azimuth=None + self.elevation=None + # lowest elevation (degrees) to ignore any outlier source + self.low_elevation_limit=5 + # outlier (Patch) names + self.ateam_names=None + self.chosen_patches=None + self.given_patches=patch_list + # 1: only select to minimum number of sources, increase to include more + self.inclusiveness=2 + + self.target_skymodel='./target.sky.txt' + self.final_skymodel='./combined.sky.txt' + self.parset_demix='tunedemix.parset' + self.Nparallel=Nparallel + # sqrt(residual sum of squares) + self.RSS=None + self.AIC=None + self.RSS_freq=None + self.AIC_freq=None + self.RSS_iter=None + self.AIC_iter=None + self.RSS_dof=None + self.AIC_dof=None + + def __del__(self): + for ms in self.extracted_ms: + sb.run('rm -rf '+ms,shell=True) + + def extract_dataset(self): + self.mslist.sort() + msname=self.mslist[0] + tt=ctab.table(msname,readonly=True) + starttime= tt[0]['TIME'] + endtime=tt[tt.nrows()-1]['TIME'] + self.N=tt.nrows() + tt.close() + Nms=len(self.mslist) + + # need to have at least Nf MS + assert(Nms>=self.Nf) + + # Parset for extracting and averaging + parset_sample='extract_sample.parset' + parset=open(parset_sample,'w+') + + # sample time interval + t_start=np.random.rand()*(endtime-starttime)+starttime + t_end=t_start+self.timesec + t0=atime.Time(t_start/(24*60*60),format='mjd',scale='utc') + dt=t0.to_datetime() + str_tstart=str(dt.year)+'/'+str(dt.month)+'/'+str(dt.day)+'/'+str(dt.hour)+':'+str(dt.minute)+':'+str(dt.second) + t0=atime.Time(t_end/(24*60*60),format='mjd',scale='utc') + dt=t0.to_datetime() + str_tend=str(dt.year)+'/'+str(dt.month)+'/'+str(dt.day)+'/'+str(dt.hour)+':'+str(dt.minute)+':'+str(dt.second) + + parset.write('steps=[fil,avg]\n' + +'fil.type=filter\n' + +'fil.baseline=[CR]S*&\n' + +'fil.remove=True\n' + +'avg.type=average\n' + +'avg.timestep=1\n' + +'avg.freqstep=1\n' + +'msin.datacolumn=DATA\n' + +'msin.starttime='+str_tstart+'\n' + +'msin.endtime='+str_tend+'\n') + parset.close() + + # process subset of MS from mslist + submslist=list() + submslist.append(self.mslist[0]) + aa=list(np.random.choice(np.arange(1,Nms-1),self.Nf-2,replace=False)) + aa.sort() + for ms_id in aa: + submslist.append(self.mslist[ms_id]) + submslist.append(self.mslist[-1]) + + # remove old files + sb.run('rm -rf L_SB*.MS',shell=True) + # now process each of selected MS + self.extracted_ms=list() + for ci in range(self.Nf): + MS='L_SB'+str(ci)+'.MS' + proc1=sb.Popen(self.DP3+' '+parset_sample+' msin='+submslist[ci]+' msout='+MS, shell=True) + proc1.wait() + self.extracted_ms.append(MS) + # also update the weights + self.reset_weights(MS) + + # get metadata from extracted MS + self.freqlist=np.zeros(self.Nf) + for ms in self.extracted_ms: + tf=ctab.table(msname+'/SPECTRAL_WINDOW',readonly=True) + ch0=tf.getcol('CHAN_FREQ') + reffreq=tf.getcol('REF_FREQUENCY') + self.freqlist[ci]=ch0[0,0] + tf.close() + + tf=ctab.table(self.extracted_ms[0]+'/SPECTRAL_WINDOW',readonly=True) + self.bandwidth=tf.getcol('TOTAL_BANDWIDTH')[0] + self.nchan=tf.getcol('NUM_CHAN')[0] + tf.close() + + # get antennas + ant=ctab.table(self.extracted_ms[0]+'/ANTENNA',readonly=True) + self.N=ant.nrows() + # baselines + self.B=self.N*(self.N-1)//2 + # Antenna location + xyz=ant.getcol('POSITION') + ant.close() + # Get target coords + field=ctab.table(self.extracted_ms[0]+'/FIELD',readonly=True) + phase_dir=field.getcol('PHASE_DIR') + self.ra0=phase_dir[0][0][0] + self.dec0=phase_dir[0][0][1] + field.close() + + # get integration time + tt=ctab.table(self.extracted_ms[0],readonly=True) + tt1=tt.getcol('INTERVAL') + Tdelta=tt[0]['INTERVAL'] + t0=tt[0]['TIME'] + tt.close() + Tslots=math.ceil(self.timesec/Tdelta) + + # epoch coordinate UTC + self.mydm=measures() + mypos=self.mydm.position('ITRF',str(xyz[0][0])+'m',str(xyz[0][1])+'m',str(xyz[0][2])+'m') + mytime=self.mydm.epoch('UTC',str(t0)+'s') + self.mydm.doframe(mytime) + self.mydm.doframe(mypos) + + + def read_skymodel(self): + lsm=lsmtool.load(self.skymodel) + skymodel_patches=lsm.getPatchNames() + # only select patches already given by user + if len(self.given_patches)>0: + patches=list(set(self.given_patches).intersection(skymodel_patches)) + else: + patches=skymodel_patches + assert(len(patches)>0) + self.ateam_ra=np.zeros(len(patches)) + self.ateam_dec=np.zeros(len(patches)) + self.K=len(patches) + self.ateam_names=patches + for ci in range(self.K): + patch=patches[ci] + lsm=lsmtool.load(self.skymodel) + lsm.select('Patch=='+patch) + t=lsm.table + self.ateam_ra[ci]=np.array(t['Ra'])[0] + self.ateam_dec[ci]=np.array(t['Dec'])[0] + self.separation=np.zeros(self.K) + self.azimuth=np.zeros(self.K) + self.elevation=np.zeros(self.K) + ra0_q=quantity(self.ra0,'rad') + dec0_q=quantity(self.dec0,'rad') + target=self.mydm.direction('j2000',ra0_q,dec0_q) + for ci in range(self.K): + mra_q=quantity(self.ateam_ra[ci],'deg') + mdec_q=quantity(self.ateam_dec[ci],'deg') + cluster_dir=self.mydm.direction('j2000',mra_q,mdec_q) + sep=self.mydm.separation(target,cluster_dir) + self.separation[ci]=sep.get_value() + azel=self.mydm.measure(cluster_dir,'AZEL') + self.azimuth[ci]=azel['m0']['value']/math.pi*180 + self.elevation[ci]=azel['m1']['value']/math.pi*180 + + def get_target_skymodel(self,overwrite=True): + if overwrite: + sb.run('rm -rvf '+self.target_skymodel,shell=True) + sb.run('python '+self.LINC_GET_TARGET+' --Radius 5 --targetname CENTER '+self.extracted_ms[0]+' '+self.target_skymodel,shell=True) + + # convert integer to binary bits, return array of size K + def scalar_to_kvec(self,n): + ll=[1 if digit=='1' else 0 for digit in bin(n)[2:]] + a=np.zeros(self.K) + a[-len(ll):]=ll + return a + + # merge target with outliers given by + def create_final_skymodel(self): + lsm0=lsmtool.load(self.target_skymodel) + lsm=lsmtool.load(self.skymodel) + lsm0.concatenate(lsm,keep='all') + lsm0.write(self.final_skymodel,clobber=True) + + # merge target with outliers given by + # category index 'index' + def create_parset_dir(self,index): + chosen_dirs=self.scalar_to_kvec(index) + chosen_patches=itertools.compress(self.ateam_names,chosen_dirs) + parset_name=str(index)+'_'+self.parset_demix + bbsdem=open(parset_name,'w+') + bbsdem.write('steps=[demix]\n' + +'demix.type=demixer\n' + +'demix.baseline=[CR]S*&\n' + +'demix.ignoretarget=False\n' + +'demix.targetsource=\"CENTER\"\n' + +'demix.demixtimeresolution=10\n' + +'demix.demixfreqresolution=50kHz\n' + +'demix.uselbfgssolver=true\n' + +'demix.lbfgs.historysize=10\n' + +'demix.lbfgs.robustdof=200\n' + +'demix.freqstep=1\n' + +'demix.timestep=1\n' + +'demix.instrumentmodel=instrument_'+str(index)+'\n' + +'demix.skymodel='+self.final_skymodel+'\n' + ) + + bbsdem.write('demix.subtractsources=[') + firstcomma=False + if sum(chosen_dirs)>0: + for patch in chosen_patches: + if not firstcomma: + firstcomma=True + else: + bbsdem.write(',') + bbsdem.write('\"'+patch+'\"') + bbsdem.write(']\n') + bbsdem.close() + return parset_name + + def create_parset_freq(self,channel_index,channels): + bandwidth=self.bandwidth/self.nchan*channels[channel_index]/1e3 + parset_name=str(channel_index)+'_'+self.parset_demix + bbsdem=open(parset_name,'w+') + bbsdem.write('steps=[demix]\n' + +'demix.type=demixer\n' + +'demix.baseline=[CR]S*&\n' + +'demix.ignoretarget=False\n' + +'demix.targetsource=\"CENTER\"\n' + +'demix.demixtimeresolution=10\n' + +'demix.demixfreqresolution='+str(bandwidth)+'kHz\n' + +'demix.uselbfgssolver=true\n' + +'demix.lbfgs.historysize=10\n' + +'demix.lbfgs.robustdof=200\n' + +'demix.freqstep=1\n' + +'demix.timestep=1\n' + +'demix.instrumentmodel=instrument_'+str(channel_index)+'\n' + +'demix.skymodel='+self.final_skymodel+'\n' + ) + + bbsdem.write('demix.subtractsources=[') + firstcomma=False + for patch in self.chosen_patches: + if not firstcomma: + firstcomma=True + else: + bbsdem.write(',') + bbsdem.write('\"'+patch+'\"') + bbsdem.write(']\n') + bbsdem.close() + return parset_name + + def create_parset_iter(self,iter_index): + parset_name=str(iter_index)+'_'+self.parset_demix + bbsdem=open(parset_name,'w+') + bbsdem.write('steps=[demix]\n' + +'demix.type=demixer\n' + +'demix.baseline=[CR]S*&\n' + +'demix.ignoretarget=False\n' + +'demix.targetsource=\"CENTER\"\n' + +'demix.demixtimeresolution=10\n' + +'demix.demixfreqresolution='+str(self.bandwidth_to_use)+'\n' + +'demix.uselbfgssolver=true\n' + +'demix.lbfgs.historysize=10\n' + +'demix.lbfgs.robustdof='+str(self.robust_dof_to_use)+'\n' + +'demix.maxiter='+str(self.maxiter[iter_index])+'\n' + +'demix.freqstep=1\n' + +'demix.timestep=1\n' + +'demix.instrumentmodel=instrument_'+str(iter_index)+'\n' + +'demix.skymodel='+self.final_skymodel+'\n' + ) + + bbsdem.write('demix.subtractsources=[') + firstcomma=False + for patch in self.chosen_patches: + if not firstcomma: + firstcomma=True + else: + bbsdem.write(',') + bbsdem.write('\"'+patch+'\"') + bbsdem.write(']\n') + bbsdem.close() + return parset_name + + def create_parset_dof(self,dof_index): + parset_name=str(dof_index)+'_'+self.parset_demix + bbsdem=open(parset_name,'w+') + bbsdem.write('steps=[demix]\n' + +'demix.type=demixer\n' + +'demix.baseline=[CR]S*&\n' + +'demix.ignoretarget=False\n' + +'demix.targetsource=\"CENTER\"\n' + +'demix.demixtimeresolution=10\n' + +'demix.demixfreqresolution='+str(self.bandwidth_to_use)+'\n' + +'demix.uselbfgssolver=true\n' + +'demix.lbfgs.historysize=10\n' + +'demix.lbfgs.robustdof='+str(self.dof[dof_index])+'\n' + +'demix.maxiter=50\n' + +'demix.freqstep=1\n' + +'demix.timestep=1\n' + +'demix.instrumentmodel=instrument_'+str(dof_index)+'\n' + +'demix.skymodel='+self.final_skymodel+'\n' + ) + + bbsdem.write('demix.subtractsources=[') + firstcomma=False + for patch in self.chosen_patches: + if not firstcomma: + firstcomma=True + else: + bbsdem.write(',') + bbsdem.write('\"'+patch+'\"') + bbsdem.write(']\n') + bbsdem.close() + return parset_name + + + def iter_over_directions(self): + self.RSS=np.zeros((self.Nf,2**self.K)) + shmRSS=shared_memory.SharedMemory(create=True,size=self.RSS.nbytes) + RSSsh=np.ndarray(self.RSS.shape,dtype=self.RSS.dtype,buffer=shmRSS.buf) + + @globalize + def process_scenario(index): + # if any of the sources in this index has -ve elevation, + # stop and return a higher error + chosen_dirs=self.scalar_to_kvec(index) + low_elevations=any(ele <= self.low_elevation_limit for ele in itertools.compress(self.elevation,chosen_dirs)) + if not low_elevations: + parset_name=self.create_parset_dir(index) + # copy data + ci=0 + for ms in self.extracted_ms: + out_ms='out_'+str(index)+'_'+ms + sb.run('rm -rf '+out_ms,shell=True) + proc1=sb.Popen('rsync -a '+ms+'/ '+out_ms,shell=True) + proc1.wait() + demixout_ms='residual_'+out_ms + sb.run('rm -rf '+demixout_ms,shell=True) + proc1=sb.Popen(self.DP3+' '+parset_name+' msin='+out_ms+' msout='+demixout_ms,shell=True) + proc1.wait() + res_sigmaI,res_sigmaQ,res_sigmaU,res_sigmaV=self.get_noise_var(demixout_ms) + print('%s %d %f %f %f %f'%(out_ms,index,res_sigmaI,res_sigmaQ,res_sigmaU,res_sigmaV)) + RSSsh[ci,index]=res_sigmaI + ci+=1 + else: + RSSsh[:,index]=1e9 + + pool=Pool(self.Nparallel) + pool.map(process_scenario,range(1,2**self.K)) + pool.close() + pool.join() + + self.RSS[:]=RSSsh[:] + shmRSS.close() + shmRSS.unlink() + ci=0 + for ms in self.extracted_ms: + sigmaI,sigmaQ,sigmaU,sigmaV=self.get_noise_var(ms) + self.RSS[ci,0]=sigmaI + ci+=1 + + def iter_over_frequency(self): + assert(len(self.chosen_patches)>0) + # channels to consider + channels=[2**x for x in range(int(np.ceil(np.log(self.nchan)/np.log(2)))+1)] + self.RSS_freq=np.zeros((self.Nf,len(channels))) + shmRSS=shared_memory.SharedMemory(create=True,size=self.RSS_freq.nbytes) + RSSsh=np.ndarray(self.RSS_freq.shape,dtype=self.RSS_freq.dtype,buffer=shmRSS.buf) + + @globalize + def process_scenario_freq(channel_index): + parset_name=self.create_parset_freq(channel_index,channels) + # copy data + ci=0 + for ms in self.extracted_ms: + out_ms='out_'+str(channel_index)+'_'+ms + sb.run('rm -rf '+out_ms,shell=True) + proc1=sb.Popen('rsync -a '+ms+'/ '+out_ms,shell=True) + proc1.wait() + demixout_ms='residual_'+out_ms + sb.run('rm -rf '+demixout_ms,shell=True) + proc1=sb.Popen(self.DP3+' '+parset_name+' msin='+out_ms+' msout='+demixout_ms,shell=True) + proc1.wait() + res_sigmaI,res_sigmaQ,res_sigmaU,res_sigmaV=self.get_noise_var(demixout_ms) + print('%s %d %f %f %f %f'%(out_ms,channel_index,res_sigmaI,res_sigmaQ,res_sigmaU,res_sigmaV)) + RSSsh[ci,channel_index]=res_sigmaI + ci+=1 + + pool=Pool(self.Nparallel) + pool.map(process_scenario_freq,range(len(channels))) + pool.close() + pool.join() + + self.RSS_freq[:]=RSSsh[:] + shmRSS.close() + shmRSS.unlink() + + def iter_over_maxiter(self): + assert(len(self.chosen_patches)>0) + assert(len(self.maxiter)>0) + self.RSS_iter=np.zeros((self.Nf,len(self.maxiter))) + shmRSS=shared_memory.SharedMemory(create=True,size=self.RSS_iter.nbytes) + RSSsh=np.ndarray(self.RSS_iter.shape,dtype=self.RSS_iter.dtype,buffer=shmRSS.buf) + + @globalize + def process_scenario_iter(iter_index): + parset_name=self.create_parset_iter(iter_index) + # copy data + ci=0 + for ms in self.extracted_ms: + out_ms='out_'+str(iter_index)+'_'+ms + sb.run('rm -rf '+out_ms,shell=True) + proc1=sb.Popen('rsync -a '+ms+'/ '+out_ms,shell=True) + proc1.wait() + demixout_ms='residual_'+out_ms + sb.run('rm -rf '+demixout_ms,shell=True) + proc1=sb.Popen(self.DP3+' '+parset_name+' msin='+out_ms+' msout='+demixout_ms,shell=True) + proc1.wait() + res_sigmaI,res_sigmaQ,res_sigmaU,res_sigmaV=self.get_noise_var(demixout_ms) + print('%s %d %f %f %f %f'%(out_ms,iter_index,res_sigmaI,res_sigmaQ,res_sigmaU,res_sigmaV)) + RSSsh[ci,iter_index]=res_sigmaI + ci+=1 + + pool=Pool(self.Nparallel) + pool.map(process_scenario_iter,range(len(self.maxiter))) + pool.close() + pool.join() + + self.RSS_iter[:]=RSSsh[:] + shmRSS.close() + shmRSS.unlink() + + + def iter_over_robust_dof(self): + assert(len(self.chosen_patches)>0) + assert(len(self.dof)>0) + self.RSS_dof=np.zeros((self.Nf,len(self.dof))) + shmRSS=shared_memory.SharedMemory(create=True,size=self.RSS_dof.nbytes) + RSSsh=np.ndarray(self.RSS_dof.shape,dtype=self.RSS_dof.dtype,buffer=shmRSS.buf) + @globalize + def process_scenario_dof(iter_index): + parset_name=self.create_parset_dof(iter_index) + # copy data + ci=0 + for ms in self.extracted_ms: + out_ms='out_'+str(iter_index)+'_'+ms + sb.run('rm -rf '+out_ms,shell=True) + proc1=sb.Popen('rsync -a '+ms+'/ '+out_ms,shell=True) + proc1.wait() + demixout_ms='residual_'+out_ms + sb.run('rm -rf '+demixout_ms,shell=True) + proc1=sb.Popen(self.DP3+' '+parset_name+' msin='+out_ms+' msout='+demixout_ms,shell=True) + proc1.wait() + res_sigmaI,res_sigmaQ,res_sigmaU,res_sigmaV=self.get_noise_var(demixout_ms) + print('%s %d %f %f %f %f'%(out_ms,iter_index,res_sigmaI,res_sigmaQ,res_sigmaU,res_sigmaV)) + RSSsh[ci,iter_index]=res_sigmaI + ci+=1 + + pool=Pool(self.Nparallel) + pool.map(process_scenario_dof,range(len(self.dof))) + pool.close() + pool.join() + + self.RSS_dof[:]=RSSsh[:] + shmRSS.close() + shmRSS.unlink() + + # reset weights to 1, after applying weight to datum + def reset_weights(self,msname): + tt=ctab.table(msname,readonly=False) + data=tt.getcol('DATA') + if 'WEIGHT_SPECTRUM' in tt.colnames(): + weight=tt.getcol('WEIGHT_SPECTRUM') + else: + weight=tt.getcol('IMAGING_WEIGHT') + # weight ~sigma^2, std(weight) ~ sigma^2, so sqrt() + weight_std=np.sqrt(weight.std()) + data *=weight_std + weight /=weight_std + tt.putcol('DATA',data) + tt.putcol('WEIGHT_SPECTRUM',weight) + tt.close() + + # extract noise info + def get_noise_var(self,msname): + tt=ctab.table(msname,readonly=True) + t1=tt.query('ANTENNA1 != ANTENNA2',columns='DATA,FLAG') + data0=t1.getcol('DATA') + flag=t1.getcol('FLAG') + data=data0*(1-flag) + tt.close() + # set nans to 0 + data[np.isnan(data)]=0. + # form IQUV + sI=(data[:,:,0]+data[:,:,3])*0.5 + sQ=(data[:,:,0]-data[:,:,3])*0.5 + sU=(data[:,:,1]-data[:,:,2])*0.5 + sV=(data[:,:,1]+data[:,:,2])*0.5 + return sI.std(),sQ.std(),sU.std(),sV.std() + + # calculate AIC = RSS * N_stat *N_stat + K N_stat + # AIC = (noise RSS/ data RSS) * N_stat *N_stat + K * N_stat + def calc_AIC(self): + self.AIC=np.zeros(2**self.K) + for index in range(2**self.K): + chosen_dirs=self.scalar_to_kvec(index) + rss=(np.mean(self.RSS[:,index])/np.mean(self.RSS[:,0]))**2 + self.AIC[index]=rss*self.N*self.N+np.sum(chosen_dirs)*self.N + + indx=np.argsort(self.AIC[1:])+1 + chosen_dirs=np.zeros(self.K) + for index in indx[:self.inclusiveness]: + chosen_dirs+=self.scalar_to_kvec(index) + # remove -ve elevation dirs, again, to be safe + chosen_dirs[self.elevation<=self.low_elevation_limit]=0 + self.chosen_patches=list(itertools.compress(self.ateam_names,chosen_dirs)) + + def calc_bandwidth_to_use(self): + channels=[2**x for x in range(int(np.ceil(np.log(self.nchan)/np.log(2)))+1)] + for ci in range(self.Nf): + self.RSS_freq[ci,:]/=self.RSS[ci,0] + rss=np.mean(self.RSS_freq,axis=0)**2 + self.AIC_freq=rss*self.N*self.N+self.nchan/channels*self.N*len(self.chosen_patches) + index=np.argmin(self.AIC_freq) + self.bandwidth_to_use=self.bandwidth*channels[index]/self.nchan + + + def calc_maxiter_to_use(self): + for ci in range(self.Nf): + self.RSS_iter[ci,:]/=self.RSS[ci,0] + rss=np.mean(self.RSS_iter,axis=0)**2 + self.AIC_iter=rss*self.N*self.N+self.N*len(self.chosen_patches)*np.array(self.maxiter) + index=np.argmin(self.AIC_iter) + self.maxiter_to_use=self.maxiter[index] + + def calc_dof_to_use(self): + for ci in range(self.Nf): + self.RSS_dof[ci,:]/=self.RSS[ci,0] + rss=np.mean(self.RSS_dof,axis=0)**2 + self.AIC_dof=rss*self.N*self.N + index=np.argmin(self.AIC_dof) + self.robust_dof_to_use=self.dof[index] + + def report(self): + print('*************************') + print('separation') + print(self.separation) + print('elevation') + print(self.elevation) + print('direction selection') + print(self.RSS) + print(self.AIC) + print('frequency resolution selection') + print(self.RSS_freq) + print(self.AIC_freq) + print('robust DOF selection') + print(self.RSS_dof) + print(self.AIC_dof) + print('maxiter selection') + print(self.RSS_iter) + print(self.AIC_iter) + print('*************************') + + def cleanup(self): + for ms in self.extracted_ms: + out_ms='out_*_'+ms + sb.run('rm -rf '+out_ms,shell=True) + demixout_ms='residual_'+out_ms + sb.run('rm -rf '+demixout_ms,shell=True) + + parset_name='*_'+self.parset_demix + sb.run('rm -rf '+parset_name,shell=True) + sb.run('rm -rf instrument_*',shell=True) + + def run(self): + # Note: the order of following being run is important + self.extract_dataset() + self.read_skymodel() + self.get_target_skymodel(overwrite=True) + self.create_final_skymodel() + self.iter_over_directions() + self.cleanup() + self.calc_AIC() + self.iter_over_frequency() + self.calc_bandwidth_to_use() + self.cleanup() + self.iter_over_robust_dof() + self.calc_dof_to_use() + self.cleanup() + self.iter_over_maxiter() + self.calc_maxiter_to_use() + self.cleanup() + +def main(args): + tuner=DemixingTuner(args.MS,args.sky_model,args.time_interval,args.subbands,args.parallel_jobs,path_DP3=args.DP3, path_LINC=args.LINC, patch_list=args.patches) + tuner.run() + + result={} + result['maxiter']=tuner.maxiter_to_use + result['demixfreqresolution']=tuner.bandwidth_to_use + result['subtractsources']=tuner.chosen_patches + result['lbfgs.robustdof']=tuner.robust_dof_to_use + result['demixtimeresolution']=10 + result['lbfgs.historysize']=10 + + return result + + +if __name__=='__main__': + parser=argparse.ArgumentParser( + description='Tune demixing parameters', + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + ) + + parser.add_argument('--MS',type=str,metavar='\'*.MS\'', + help='absolute path of MS pattern to use') + parser.add_argument('--sky_model',type=str,metavar='s', + help='A-Team sky model (text)') + parser.add_argument('--subbands',default=3,type=int,metavar='f', + help='number of subbands to process (minimum 3)') + parser.add_argument('--time_interval',type=float,default=30.,metavar='t', + help='total time interval to sample in seconds') + parser.add_argument('--parallel_jobs',type=int,default=4,metavar='p', + help='number of parallel jobs') + parser.add_argument('--patches',nargs='+',default=[], + metavar='Patch1', + help='names of patches to subtract if found in sky model, if None, all patches') + parser.add_argument('--DP3',type=str,metavar='DP3',default=default_DP3, + help='path to DP3 command') + parser.add_argument('--LINC',type=str,metavar='LINC',default=default_LINC_GET_TARGET, + help='path to download_skymodel_target.py script') + + + args=parser.parse_args() + format_stream = logging.Formatter("%(asctime)s\033[1m %(levelname)s:\033[0m %(message)s","%Y-%m-%d %H:%M:%S") + format_file = logging.Formatter("%(asctime)s %(levelname)s: %(message)s","%Y-%m-%d %H:%M:%S") + logging.root.setLevel(logging.INFO) + + log = logging.StreamHandler() + log.setFormatter(format_stream) + logging.root.addHandler(log) + + if args.MS and args.sky_model: + res=main(args) + else: + parser.print_help()