
add script for tuning demixing

Merged Sarod Yatawatta requested to merge RAP-90 into master
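The script draws a short, randomly placed time slice from a few of the input subbands, runs DP3 demixing for every subset of the A-Team directions in the given sky model, and ranks the subsets with an AIC-like criterion so that only the directions worth demixing are reported. A typical invocation (with a hypothetical script name and illustrative paths) would be:

python tune_demixing.py --MS '/data/L123456_SB*.MS' --sky_model Ateam.skymodel --time_interval 30 --parallel_jobs 4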
#! /usr/bin/env python
"""
Script to tune demixing parameters by sampling data
"""
import math,sys,uuid
import subprocess as sb
import time,glob
import numpy as np
import argparse
import itertools
import casacore.tables as ctab
from casacore.measures import measures
from casacore.quanta import quantity
from multiprocessing import Pool
from multiprocessing import shared_memory
import astropy.time as atime
import lsmtool
# default DP3
default_DP3='DP3'
# LINC script to download target sky
default_LINC_GET_TARGET='download_skymodel_target.py'
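# decorator that registers the wrapped function under a unique module-level name,
# so that functions defined inside a method can be pickled by multiprocessing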
def globalize(func):
def result(*args, **kwargs):
return func(*args, **kwargs)
result.__name__ = result.__qualname__ = uuid.uuid4().hex
setattr(sys.modules[result.__module__], result.__name__, result)
return result
class DemixingTuner:
DP3=default_DP3
LINC_GET_TARGET=default_LINC_GET_TARGET
def __init__(self,mslist,skymodel,timesec,Nf=3,Nparallel=4,path_DP3=None,path_LINC=None):
self.mslist=glob.glob(mslist)
self.timesec=timesec
self.Nf=Nf
self.skymodel=skymodel
if path_DP3:
self.DP3=path_DP3
if path_LINC:
self.LINC_GET_TARGET=path_LINC
self.extracted_ms=None
# stations
self.N=0
# outlier directions
self.K=0
# baselines
self.B=0
# frequencies
self.freqlist=None
# target direction
self.ra0=0
self.dec0=0
# measures object
self.mydm=None
# outlier directions (arrays)
self.ateam_ra=None
self.ateam_dec=None
# outlier separations,azimuth,elevations
self.separation=None
self.azimuth=None
self.elevation=None
# elevation limit (degrees): outlier sources at or below this elevation are ignored
self.low_elevation_limit=5
# outlier (Patch) names
self.ateam_names=None
self.target_skymodel='./target.sky.txt'
self.final_skymodel='./combined.sky.txt'
self.parset_demix='tunedemix.parset'
self.Nparallel=Nparallel
# sqrt(residual sum of squares)
self.RSS=None
self.AIC=None
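# extract a randomly chosen time interval of length timesec from Nf of the input MS
# (first, last and Nf-2 random subbands), filtering baselines to [CR]S* stations with DP3,
# then collect metadata: frequencies, station count, target direction and epoch frame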
def extract_dataset(self):
self.mslist.sort()
msname=self.mslist[0]
tt=ctab.table(msname,readonly=True)
starttime=tt[0]['TIME']
endtime=tt[tt.nrows()-1]['TIME']
tt.close()
Nms=len(self.mslist)
# need to have at least Nf MS
assert(Nms>=self.Nf)
# Parset for extracting and averaging
parset_sample='extract_sample.parset'
parset=open(parset_sample,'w+')
# sample time interval
t_start=np.random.rand()*max(endtime-starttime-self.timesec,0.)+starttime
t_end=t_start+self.timesec
t0=atime.Time(t_start/(24*60*60),format='mjd',scale='utc')
dt=t0.to_datetime()
str_tstart=str(dt.year)+'/'+str(dt.month)+'/'+str(dt.day)+'/'+str(dt.hour)+':'+str(dt.minute)+':'+str(dt.second)
t0=atime.Time(t_end/(24*60*60),format='mjd',scale='utc')
dt=t0.to_datetime()
str_tend=str(dt.year)+'/'+str(dt.month)+'/'+str(dt.day)+'/'+str(dt.hour)+':'+str(dt.minute)+':'+str(dt.second)
parset.write('steps=[fil,avg]\n'
+'fil.type=filter\n'
+'fil.baseline=[CR]S*&\n'
+'fil.remove=True\n'
+'avg.type=average\n'
+'avg.timestep=1\n'
+'avg.freqstep=1\n'
+'msin.datacolumn=DATA\n'
+'msin.starttime='+str_tstart+'\n'
+'msin.endtime='+str_tend+'\n')
parset.close()
# process subset of MS from mslist
submslist=list()
submslist.append(self.mslist[0])
aa=list(np.random.choice(np.arange(1,Nms-1),self.Nf-2,replace=False))
aa.sort()
for ms_id in aa:
submslist.append(self.mslist[ms_id])
submslist.append(self.mslist[-1])
# remove old files
sb.run('rm -rf L_SB*.MS',shell=True)
# now process each of selected MS
self.extracted_ms=list()
for ci in range(self.Nf):
MS='L_SB'+str(ci)+'.MS'
proc1=sb.Popen(self.DP3+' '+parset_sample+' msin='+submslist[ci]+' msout='+MS, shell=True)
proc1.wait()
self.extracted_ms.append(MS)
# also update the weights
self.reset_weights(MS)
# get metadata from extracted MS
self.freqlist=np.zeros(self.Nf)
for ci,ms in enumerate(self.extracted_ms):
tf=ctab.table(ms+'/SPECTRAL_WINDOW',readonly=True)
ch0=tf.getcol('CHAN_FREQ')
self.freqlist[ci]=ch0[0,0]
tf.close()
# get antennas
ant=ctab.table(self.extracted_ms[0]+'/ANTENNA',readonly=True)
self.N=ant.nrows()
# baselines
self.B=self.N*(self.N-1)//2
# Antenna location
xyz=ant.getcol('POSITION')
ant.close()
# Get target coords
field=ctab.table(self.extracted_ms[0]+'/FIELD',readonly=True)
phase_dir=field.getcol('PHASE_DIR')
self.ra0=phase_dir[0][0][0]
self.dec0=phase_dir[0][0][1]
field.close()
# get integration time
tt=ctab.table(self.extracted_ms[0],readonly=True)
tt1=tt.getcol('INTERVAL')
Tdelta=tt[0]['INTERVAL']
t0=tt[0]['TIME']
tt.close()
Tslots=math.ceil(self.timesec/Tdelta)
# epoch coordinate UTC
self.mydm=measures()
mypos=self.mydm.position('ITRF',str(xyz[0][0])+'m',str(xyz[0][1])+'m',str(xyz[0][2])+'m')
mytime=self.mydm.epoch('UTC',str(t0)+'s')
self.mydm.doframe(mytime)
self.mydm.doframe(mypos)
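# read the outlier (A-Team) sky model and, for each patch, compute the angular
# separation from the target and the azimuth/elevation at the observing epoch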
def read_skymodel(self):
lsm=lsmtool.load(self.skymodel)
patches=lsm.getPatchNames()
print(patches)
self.ateam_ra=np.zeros(len(patches))
self.ateam_dec=np.zeros(len(patches))
self.K=len(patches)
self.ateam_names=patches
for ci in range(self.K):
patch=patches[ci]
lsm=lsmtool.load(self.skymodel)
lsm.select('Patch=='+patch)
t=lsm.table
self.ateam_ra[ci]=np.array(t['Ra'])[0]
self.ateam_dec[ci]=np.array(t['Dec'])[0]
self.separation=np.zeros(self.K)
self.azimuth=np.zeros(self.K)
self.elevation=np.zeros(self.K)
ra0_q=quantity(self.ra0,'rad')
dec0_q=quantity(self.dec0,'rad')
target=self.mydm.direction('j2000',ra0_q,dec0_q)
for ci in range(self.K):
mra_q=quantity(self.ateam_ra[ci],'deg')
mdec_q=quantity(self.ateam_dec[ci],'deg')
cluster_dir=self.mydm.direction('j2000',mra_q,mdec_q)
sep=self.mydm.separation(target,cluster_dir)
self.separation[ci]=sep.get_value()
azel=self.mydm.measure(cluster_dir,'AZEL')
self.azimuth[ci]=azel['m0']['value']/math.pi*180
self.elevation[ci]=azel['m1']['value']/math.pi*180
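# download the target field sky model (5 deg radius, patch name CENTER) via the LINC script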
def get_target_skymodel(self,overwrite=True):
if overwrite:
sb.run('rm -rvf '+self.target_skymodel,shell=True)
sb.run('python '+self.LINC_GET_TARGET+' --Radius 5 --targetname CENTER '+self.extracted_ms[0]+' '+self.target_skymodel,shell=True)
# convert integer to binary bits, return array of size K
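# e.g. with K=4, index 5 -> binary 0101 -> mask [0,1,0,1], selecting the 2nd and 4th outlier patches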
def scalar_to_kvec(self,n):
ll=[1 if digit=='1' else 0 for digit in bin(n)[2:]]
a=np.zeros(self.K)
a[-len(ll):]=ll
return a
# merge the target sky model with the outlier (A-Team) sky model
def create_final_skymodel(self):
lsm0=lsmtool.load(self.target_skymodel)
lsm=lsmtool.load(self.skymodel)
lsm0.concatenate(lsm,keep='all')
lsm0.write(self.final_skymodel,clobber=True)
# create a demixing parset for the subset of outlier directions
# encoded by the integer 'index'
def create_parset(self,index):
chosen_dirs=self.scalar_to_kvec(index)
chosen_patches=itertools.compress(self.ateam_names,chosen_dirs)
parset_name=str(index)+'_'+self.parset_demix
bbsdem=open(parset_name,'w+')
bbsdem.write('steps=[demix]\n'
+'demix.type=demixer\n'
+'demix.baseline=[CR]S*&\n'
+'demix.ignoretarget=False\n'
+'demix.targetsource=\"CENTER\"\n'
#+'demix.demixtimestep=10\n'
#+'demix.demixfreqstep=1\n'
+'demix.demixtimeresolution=10\n'
+'demix.demixfreqresolution=50kHz\n'
+'demix.uselbfgssolver=true\n'
+'demix.lbfgs.historysize=10\n'
+'demix.lbfgs.robustdof=200\n'
+'demix.freqstep=1\n'
+'demix.timestep=1\n'
+'demix.instrumentmodel=instrument_'+str(index)+'\n'
+'demix.skymodel='+self.final_skymodel+'\n'
)
bbsdem.write('demix.subtractsources=[')
firstcomma=False
if sum(chosen_dirs)>0:
for patch in chosen_patches:
if not firstcomma:
firstcomma=True
else:
bbsdem.write(',')
bbsdem.write('\"'+patch+'\"')
bbsdem.write(']\n')
bbsdem.close()
return parset_name
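# demix each non-empty subset of outlier directions (indices 1..2^K-1) in parallel,
# storing the residual Stokes I noise per extracted MS in shared memory;
# column 0 holds the noise of the un-demixed data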
def iter_over_directions(self):
self.RSS=np.zeros((self.Nf,2**self.K))
shmRSS=shared_memory.SharedMemory(create=True,size=self.RSS.nbytes)
RSSsh=np.ndarray(self.RSS.shape,dtype=self.RSS.dtype,buffer=shmRSS.buf)
@globalize
def process_scenario(index):
# if any source selected by this index is at or below the elevation limit,
# skip demixing and assign a large penalty instead
chosen_dirs=self.scalar_to_kvec(index)
low_elevations=any(ele <= self.low_elevation_limit for ele in itertools.compress(self.elevation,chosen_dirs))
if not low_elevations:
parset_name=self.create_parset(index)
# copy data
ci=0
for ms in self.extracted_ms:
out_ms='out_'+str(index)+'_'+ms
sb.run('rm -rf '+out_ms,shell=True)
proc1=sb.Popen('rsync -a '+ms+'/ '+out_ms,shell=True)
proc1.wait()
demixout_ms='residual_'+out_ms
sb.run('rm -rf '+demixout_ms,shell=True)
proc1=sb.Popen(self.DP3+' '+parset_name+' msin='+out_ms+' msout='+demixout_ms,shell=True)
proc1.wait()
res_sigmaI,res_sigmaQ,res_sigmaU,res_sigmaV=self.get_noise_var(demixout_ms)
print('%s %d %f %f %f %f'%(out_ms,index,res_sigmaI,res_sigmaQ,res_sigmaU,res_sigmaV))
#RSSsh[ci,index]=(res_sigmaI+res_sigmaQ+res_sigmaU+res_sigmaV)/4
RSSsh[ci,index]=res_sigmaI
ci+=1
else:
RSSsh[:,index]=1e9
pool=Pool(self.Nparallel)
pool.map(process_scenario,range(1,2**self.K))
pool.close()
pool.join()
self.RSS[:]=RSSsh[:]
shmRSS.close()
shmRSS.unlink()
ci=0
for ms in self.extracted_ms:
sigmaI,sigmaQ,sigmaU,sigmaV=self.get_noise_var(ms)
self.RSS[ci,0]=sigmaI
ci+=1
# fold the overall weight scale into the data and rescale the weights by the same factor
def reset_weights(self,msname):
tt=ctab.table(msname,readonly=False)
data=tt.getcol('DATA')
if 'WEIGHT_SPECTRUM' in tt.colnames():
weight_col='WEIGHT_SPECTRUM'
else:
weight_col='IMAGING_WEIGHT'
weight=tt.getcol(weight_col)
# weight ~sigma^2, std(weight) ~ sigma^2, so sqrt()
weight_std=np.sqrt(weight.std())
data*=weight_std
weight/=weight_std
tt.putcol('DATA',data)
tt.putcol(weight_col,weight)
tt.close()
# extract noise info
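# returns the standard deviation of Stokes I,Q,U,V formed from the four correlations
# of the cross-correlations (auto-correlations excluded), with flagged samples zeroed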
def get_noise_var(self,msname):
tt=ctab.table(msname,readonly=True)
t1=tt.query('ANTENNA1 != ANTENNA2',columns='DATA,FLAG')
data0=t1.getcol('DATA')
flag=t1.getcol('FLAG')
data=data0*(1-flag)
tt.close()
# set nans to 0
data[np.isnan(data)]=0.
# form IQUV
sI=(data[:,:,0]+data[:,:,3])*0.5
sQ=(data[:,:,0]-data[:,:,3])*0.5
sU=(data[:,:,1]-data[:,:,2])*0.5
sV=(data[:,:,1]+data[:,:,2])*0.5
return sI.std(),sQ.std(),sU.std(),sV.std()
# calculate AIC for each demixing scenario:
# AIC = (residual RSS / un-demixed RSS)^2 * N_stat^2 + K_subtracted * N_stat
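# model selection: a lower AIC favours scenarios that reduce the residual noise
# without subtracting (and hence solving for) more directions than necessary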
def calc_AIC(self):
self.AIC=np.zeros(2**self.K)
for index in range(2**self.K):
chosen_dirs=self.scalar_to_kvec(index)
rss=(np.mean(self.RSS[:,index])/np.mean(self.RSS[:,0]))**2
self.AIC[index]=rss*self.N*self.N+np.sum(chosen_dirs)*self.N
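# report the union of outlier directions appearing in the two lowest-AIC scenarios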
def report(self):
indx=np.argsort(self.AIC[1:])+1
chosen_dirs=np.zeros(self.K)
for index in indx[:2]:
chosen_dirs+=self.scalar_to_kvec(index)
chosen_patches=itertools.compress(self.ateam_names,chosen_dirs)
print('***************** Result:')
for patch in chosen_patches:
print(patch)
print('*************************')
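# remove per-scenario MS copies, demixed residuals, parsets and instrument tables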
def cleanup(self):
for ms in self.extracted_ms:
out_ms='out_*_'+ms
sb.run('rm -rf '+out_ms,shell=True)
demixout_ms='residual_'+out_ms
sb.run('rm -rf '+demixout_ms,shell=True)
parset_name='*_'+self.parset_demix
sb.run('rm -rf '+parset_name,shell=True)
sb.run('rm -rf instrument_*',shell=True)
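# full tuning run: extract samples, read sky models, demix all scenarios, clean up, compute AIC and report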
def run(self):
self.extract_dataset()
self.read_skymodel()
self.get_target_skymodel(overwrite=True)
self.create_final_skymodel()
self.iter_over_directions()
self.cleanup()
self.calc_AIC()
self.report()
if __name__=='__main__':
parser=argparse.ArgumentParser(
description='Tune demixing parameters',
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
parser.add_argument('--MS',type=str,metavar='\'*.MS\'',
help='absolute path of MS pattern to use')
parser.add_argument('--sky_model',type=str,metavar='s',
help='A-Team sky model (text)')
parser.add_argument('--subbands',default=3,type=int,metavar='f',
help='number of subbands to process (minimum 3)')
parser.add_argument('--time_interval',type=float,default=30.,metavar='t',
help='total time interval to sample in seconds')
parser.add_argument('--parallel_jobs',type=int,default=4,metavar='p',
help='number of parallel jobs')
parser.add_argument('--DP3',type=str,metavar='DP3',
help='path to DP3 command')
parser.add_argument('--LINC',type=str,metavar='LINC',
help='path to download_skymodel_target.py script')
args=parser.parse_args()
if args.MS and args.sky_model:
tuner=DemixingTuner(args.MS,args.sky_model,args.time_interval,args.subbands,args.parallel_jobs,path_DP3=args.DP3, path_LINC=args.LINC)
tuner.run()
else:
parser.print_help()