#! /usr/bin/env python
"""
Script to plot demixing solutions (norm of solutions).
"""
import math
import sys
import uuid
import os
import glob
import argparse
import numpy as np
import casacore.tables as ctab
import subprocess as sb
from multiprocessing import Pool
from multiprocessing import shared_memory
import matplotlib as mpl
mpl.use('Agg')
import matplotlib.pyplot as plt


def globalize(func):
    """Return a wrapper of ``func`` registered as a uniquely named module attribute.

    Pickle serialises a function by looking up its qualified name in its
    module, so registering a closure this way lets it be sent to worker
    processes -- but only to workers forked from this process after the call,
    because the attribute is created at run time.  Kept for backward
    compatibility; ``PlotGenerator`` no longer relies on it.
    """
    def result(*args, **kwargs):
        return func(*args, **kwargs)
    result.__name__ = result.__qualname__ = uuid.uuid4().hex
    setattr(sys.modules[result.__module__], result.__name__, result)
    return result


def _table_norms(task):
    """Worker: write the per-station solution norms of one table into shared memory.

    ``task`` is a tuple ``(path, index, shm_name, shape, dtype, K, N, T, F)``:
    the table to read, its position in the sorted table list, the name of the
    shared-memory segment holding the ``(K, N, T, F*B)`` output array, that
    array's shape and dtype string, and the expected dimensions of VALUES.

    Defined at module level so it can be pickled under any multiprocessing
    start method (fork, spawn or forkserver).

    Returns True when the table was processed, False when its VALUES column
    did not have the expected shape ``(8*K*N, T, F)`` and it was skipped.
    """
    path, index, shm_name, shape, dtype, K, N, T, F = task
    tt = ctab.table(path, readonly=True)
    try:
        vl = tt.getcol('VALUES')
    finally:
        tt.close()
    if vl.shape != (8 * K * N, T, F):
        return False
    # Rows are laid out direction-major, then station, then 8 values per
    # station (each group of 8 values forms a 2x2 matrix); the norm is the
    # square root of the sum of squares over those 8 values.
    norms = np.sqrt(np.sum(np.square(vl.reshape(K, N, 8, T, F)), axis=2))
    shm = shared_memory.SharedMemory(name=shm_name)
    try:
        # Write through a temporary view so no array keeps the buffer
        # exported when shm.close() runs.
        np.ndarray(shape, dtype=dtype, buffer=shm.buf)[
            :, :, :, index * F:(index + 1) * F] = norms
    finally:
        shm.close()
    return True


class PlotGenerator:
    """Read demixing solution tables and plot the per-station solution norms."""

    def __init__(self, in_soltables, clipval):
        """Collect the solution tables to process.

        in_soltables: glob pattern matching all instrument tables; matches
            are processed in sorted order, one table per subband.
        clipval: norms above ``clipval * std`` are clipped before plotting.
        """
        self.in_soltab = sorted(glob.glob(in_soltables))
        self.clipval = clipval
        # initialize values
        self.N = 0  # number of stations
        self.K = 0  # number of directions
        self.B = 0  # number of tables (subbands)
        self.T = 0  # time samples per table
        self.F = 0  # frequency channels per table
        self.directions = None
        self.Nparallel = 4  # worker processes used to read the tables
        self.Jnorm = None  # norms, shape (K, N, T, F*B) after read_solutions()

    def read_solutions(self):
        """Read all tables into ``self.Jnorm`` and clean it up for plotting.

        Dimensions are taken from the first table.  Non-finite norms are set
        to 0 and values above ``clipval`` times the standard deviation are
        clipped.

        Raises FileNotFoundError if no table matched the pattern and
        ValueError if no direction names could be found.
        """
        self.B = len(self.in_soltab)
        if self.B == 0:
            raise FileNotFoundError('no instrument tables matched the given pattern')

        # read dimensions and solution names from the first table
        tt = ctab.table(self.in_soltab[0], readonly=True)
        try:
            vl = tt.getcol('VALUES')
            n_prod = vl.shape[0]
            self.T = vl.shape[1]
            self.F = vl.shape[2]
            sol_names_tab = ctab.table(tt.getkeyword('NAMES'), readonly=True)
            try:
                sol_names = sol_names_tab.getcol('NAME')
            finally:
                sol_names_tab.close()
        finally:
            tt.close()
        del vl

        self.directions = self.get_directions(sol_names)
        # NOTE(review): the reversal maps label order onto the direction index
        # used in VALUES; confirm against the table layout.
        self.directions.reverse()
        self.K = len(self.directions)
        if self.K == 0:
            raise ValueError('no direction names found in the NAMES table')
        self.N = n_prod // (8 * self.K)
        if 8 * self.K * self.N != n_prod:
            print(f'Warning: {n_prod} solution rows is not a multiple of 8*{self.K} directions')
        print(f'Processing {self.B} subbands, {self.N} stations, {self.K} directions, time {self.T} freq {self.F}')

        self.Jnorm = np.zeros((self.K, self.N, self.T, self.F * self.B))
        shape = self.Jnorm.shape
        dtype = self.Jnorm.dtype.str
        # SharedMemory rejects size 0, so allocate at least one byte.
        shm = shared_memory.SharedMemory(create=True, size=max(self.Jnorm.nbytes, 1))
        try:
            # start from zeros so skipped tables show up as 0
            np.ndarray(shape, dtype=dtype, buffer=shm.buf).fill(0)
            tasks = [(path, index, shm.name, shape, dtype,
                      self.K, self.N, self.T, self.F)
                     for index, path in enumerate(self.in_soltab)]
            with Pool(self.Nparallel) as pool:
                processed = pool.map(_table_norms, tasks)
            self.Jnorm[:] = np.ndarray(shape, dtype=dtype, buffer=shm.buf)
        finally:
            shm.close()
            shm.unlink()

        skipped = [path for path, ok in zip(self.in_soltab, processed) if not ok]
        if skipped:
            print(f'Warning: skipped {len(skipped)} table(s) with unexpected shape: {skipped}')

        # replace Nan/Inf, then clip at clipval * standard deviation
        self.Jnorm[~np.isfinite(self.Jnorm)] = 0
        limit = self.clipval * self.Jnorm.std()
        self.Jnorm[self.Jnorm > limit] = limit

    def plot_solutions(self, nrows=7):
        """Save one PNG per direction with a grid of per-station norm images.

        nrows: maximum number of subplot rows (default 7); the number of
            columns is chosen so that every station gets a panel.

        Files are named ``<first table basename>_dir_<index>.png`` in the
        current directory.  Raises RuntimeError if read_solutions() has not
        been called.
        """
        if self.Jnorm is None:
            raise RuntimeError('read_solutions() must be called before plot_solutions()')
        if self.N == 0:
            return
        nrows = max(1, min(nrows, self.N))
        ncols = (self.N + nrows - 1) // nrows
        base = os.path.basename(self.in_soltab[0])

        for ndir in range(self.K):
            Jdir = self.Jnorm[ndir]
            # one colour scale for all panels, so the single colorbar applies to each
            vmin, vmax = Jdir.min(), Jdir.max()
            fig = plt.figure()
            gs = fig.add_gridspec(nrows, ncols, hspace=0, wspace=0)
            # squeeze=False keeps axs 2-D even when nrows or ncols is 1
            axs = gs.subplots(sharex='col', sharey='row', squeeze=False)
            im = None
            # axs.flat is row-major, so panel ck = row * ncols + col
            for ck, ax in enumerate(axs.flat):
                if ck < self.N:
                    im = ax.imshow(Jdir[ck], interpolation=None, aspect='auto',
                                   vmin=vmin, vmax=vmax)
                else:
                    ax.axis('off')
            for ax in fig.get_axes():
                ax.label_outer()
            cb_ax = fig.add_axes([0.9, 0.1, 0.02, 0.8])
            fig.colorbar(im, cax=cb_ax)
            fig.suptitle('direction ' + str(self.directions[ndir]))
            fig.supxlabel('subbands')
            fig.supylabel('time')
            fig.savefig(base + '_dir_' + str(ndir) + '.png')
            plt.close(fig)

    def get_directions(self, solution_names):
        """Return the unique direction names, in order of first appearance.

        solution_names will include all solutions, like
        'DirectionGain:x:y:Real/Imag:Station:Direction'; the direction is
        the last ':'-separated field.  Order is preserved (instead of set
        order) so the labels are deterministic from run to run.
        """
        return list(dict.fromkeys(solname.split(':')[-1] for solname in solution_names))


def main(args):
    """Read the tables selected by the command line and write the plots."""
    pg = PlotGenerator(args.instrument_tables, args.clip)
    pg.read_solutions()
    pg.plot_solutions()


if __name__ == '__main__':
    parser = argparse.ArgumentParser(
        description='Plot demixing solutions',
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    parser.add_argument('--instrument_tables', type=str, metavar='\'*instrument\'',
                        help='absolute path (pattern) to match all instrument tables')
    parser.add_argument('--clip', type=float, default=10,
                        help='clip values (multiplied by standard deviation) above this value before plotting')

    args = parser.parse_args()
    if args.instrument_tables:
        main(args)
    else:
        parser.print_help()