Skip to content
Snippets Groups Projects
Commit b1b9894c authored by Sarod Yatawatta's avatar Sarod Yatawatta
Browse files

Script to plot demixing solutions

parent 727fca3a
No related branches found
No related tags found
1 merge request!149Script to plot demixing solutions
Pipeline #52492 passed
#! /usr/bin/env python
"""
Script to plot demixing solutions, (norm of solutions)
"""
import math,sys,uuid,os
import glob
import argparse
import numpy as np
import casacore.tables as ctab
import subprocess as sb
from multiprocessing import Pool
from multiprocessing import shared_memory
import matplotlib as mpl
mpl.use('Agg')
import matplotlib.pyplot as plt
def globalize(func):
def result(*args, **kwargs):
return func(*args, **kwargs)
result.__name__ = result.__qualname__ = uuid.uuid4().hex
setattr(sys.modules[result.__module__], result.__name__, result)
return result
class PlotGenerator:
def __init__(self,in_soltables,clipval):
self.in_soltab=sorted(glob.glob(in_soltables))
self.clipval=clipval
# initialize values
self.N=0
self.K=0
self.B=0
self.T=0
self.F=0
self.directions=None
self.Nparallel=4
self.Jnorm=None
def read_solutions(self):
self.B=len(self.in_soltab)
assert(self.B>0)
# open one table
tt=ctab.table(self.in_soltab[0],readonly=True)
vl=tt.getcol('VALUES')
n_prod=vl.shape[0]
self.T=vl.shape[1]
self.F=vl.shape[2]
sol_names_tab=ctab.table(tt.getkeyword('NAMES'),readonly=True)
sol_names=sol_names_tab.getcol('NAME')
self.directions=self.get_directions(sol_names)
self.directions.reverse()
self.K=len(self.directions)
self.N=n_prod//(8*self.K)
print(f'Processing {self.B} subbands, {self.N} stations, {self.K} directions, time {self.T} freq {self.F}')
tt.close()
self.Jnorm=np.zeros((self.K,self.N,self.T,self.F*self.B))
shmJnorm=shared_memory.SharedMemory(create=True,size=self.Jnorm.nbytes)
Jnorm_sh=np.ndarray(self.Jnorm.shape,dtype=self.Jnorm.dtype,buffer=shmJnorm.buf)
@globalize
def process_table(index):
tt=ctab.table(self.in_soltab[index],readonly=True)
vl=tt.getcol('VALUES')
# vl has shape : 8*self.K*self.N x self.T x self.F
# from each 8 values, form a 2x2 matrix
if (vl.shape==(8*self.K*self.N,self.T,self.F)):
# process
for ci in range(self.K):
for cj in range(self.N):
Jnorm_sh[ci,cj,:,index*self.F:(index+1)*self.F]=np.sqrt(np.sum(np.square(vl[ci*8*self.N+cj*8:ci*8*self.N+(cj+1)*8]),axis=0))
tt.close()
pool=Pool(self.Nparallel)
pool.map(process_table,range(self.B))
pool.close()
pool.join()
self.Jnorm[:]=Jnorm_sh[:]
shmJnorm.close()
shmJnorm.unlink()
# replace Nan/Inf, also clip
self.Jnorm[np.where(~np.isfinite(self.Jnorm))]=0
# determine standard deviation for clipping
J_std=self.Jnorm.std()
self.Jnorm[np.where(self.Jnorm>self.clipval*J_std)]=self.clipval*J_std
def plot_solutions(self):
# determine shape of subplots
nrows=7
ncols=(self.N+nrows-1)//nrows
# iterate over directions
for ndir in range(self.K):
fig = plt.figure()
gs = fig.add_gridspec(nrows, ncols, hspace=0, wspace=0)
axs = gs.subplots(sharex='col', sharey='row')
for ci in range(nrows):
for cj in range(ncols):
ck=ci*ncols+cj
if ck<self.N:
im=axs[ci,cj].imshow(self.Jnorm[ndir,ck],interpolation=None,aspect='auto')
else:
axs[ci,cj].axis('off')
for ax in fig.get_axes():
ax.label_outer()
cb_ax=fig.add_axes([0.9, 0.1, 0.02, 0.8])
cbar=fig.colorbar(im,cax=cb_ax)
fig.suptitle('direction '+str(self.directions[ndir]))
fig.supxlabel('subbands')
fig.supylabel('time')
plt.savefig(os.path.basename(self.in_soltab[0])+'_dir_'+str(ndir)+'.png')
plt.close()
def get_directions(self,solution_names):
"""
solution_names will include all solutions,
like 'DirectionGain:x:y:Real/Imag:Station:Direction'
parse this and find unique direction names
"""
sourcenames=[]
for solname in solution_names:
sourcenames.append(solname.split(':')[-1])
return list(set(sourcenames))
def main(args):
pg=PlotGenerator(args.instrument_tables,args.clip)
pg.read_solutions()
pg.plot_solutions()
if __name__=='__main__':
parser=argparse.ArgumentParser(
description='Plot demixing solutions',
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
parser.add_argument('--instrument_tables',type=str,metavar='\'*instrument\'',
help='absolute path (pattern) to match all instrument tables')
parser.add_argument('--clip',type=float,default=10,
help='clip values (multiplied by standard deviation) above this value before plotting')
args=parser.parse_args()
if args.instrument_tables:
main(args)
else:
parser.print_help()
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment