diff --git a/scripts/plot_demixing_solutions.py b/scripts/plot_demixing_solutions.py new file mode 100755 index 0000000000000000000000000000000000000000..71d6aa88bf4596cbaf9f3737211441c45ae70c1c --- /dev/null +++ b/scripts/plot_demixing_solutions.py @@ -0,0 +1,155 @@ +#! /usr/bin/env python +""" +Script to plot demixing solutions, (norm of solutions) +""" + +import math,sys,uuid,os +import glob +import argparse +import numpy as np +import casacore.tables as ctab +import subprocess as sb +from multiprocessing import Pool +from multiprocessing import shared_memory + +import matplotlib as mpl +mpl.use('Agg') +import matplotlib.pyplot as plt + +def globalize(func): + def result(*args, **kwargs): + return func(*args, **kwargs) + result.__name__ = result.__qualname__ = uuid.uuid4().hex + setattr(sys.modules[result.__module__], result.__name__, result) + return result + + +class PlotGenerator: + def __init__(self,in_soltables,clipval): + self.in_soltab=sorted(glob.glob(in_soltables)) + self.clipval=clipval + # initialize values + self.N=0 + self.K=0 + self.B=0 + self.T=0 + self.F=0 + self.directions=None + + self.Nparallel=4 + self.Jnorm=None + + def read_solutions(self): + self.B=len(self.in_soltab) + assert(self.B>0) + # open one table + tt=ctab.table(self.in_soltab[0],readonly=True) + vl=tt.getcol('VALUES') + n_prod=vl.shape[0] + self.T=vl.shape[1] + self.F=vl.shape[2] + sol_names_tab=ctab.table(tt.getkeyword('NAMES'),readonly=True) + sol_names=sol_names_tab.getcol('NAME') + self.directions=self.get_directions(sol_names) + self.directions.reverse() + self.K=len(self.directions) + self.N=n_prod//(8*self.K) + print(f'Processing {self.B} subbands, {self.N} stations, {self.K} directions, time {self.T} freq {self.F}') + tt.close() + + self.Jnorm=np.zeros((self.K,self.N,self.T,self.F*self.B)) + shmJnorm=shared_memory.SharedMemory(create=True,size=self.Jnorm.nbytes) + Jnorm_sh=np.ndarray(self.Jnorm.shape,dtype=self.Jnorm.dtype,buffer=shmJnorm.buf) + + @globalize + def process_table(index): + tt=ctab.table(self.in_soltab[index],readonly=True) + vl=tt.getcol('VALUES') + # vl has shape : 8*self.K*self.N x self.T x self.F + # from each 8 values, form a 2x2 matrix + if (vl.shape==(8*self.K*self.N,self.T,self.F)): + # process + for ci in range(self.K): + for cj in range(self.N): + Jnorm_sh[ci,cj,:,index*self.F:(index+1)*self.F]=np.sqrt(np.sum(np.square(vl[ci*8*self.N+cj*8:ci*8*self.N+(cj+1)*8]),axis=0)) + tt.close() + + + pool=Pool(self.Nparallel) + pool.map(process_table,range(self.B)) + pool.close() + pool.join() + + self.Jnorm[:]=Jnorm_sh[:] + shmJnorm.close() + shmJnorm.unlink() + # replace Nan/Inf, also clip + self.Jnorm[np.where(~np.isfinite(self.Jnorm))]=0 + # determine standard deviation for clipping + J_std=self.Jnorm.std() + self.Jnorm[np.where(self.Jnorm>self.clipval*J_std)]=self.clipval*J_std + + + def plot_solutions(self): + # determine shape of subplots + nrows=7 + ncols=(self.N+nrows-1)//nrows + + # iterate over directions + for ndir in range(self.K): + fig = plt.figure() + gs = fig.add_gridspec(nrows, ncols, hspace=0, wspace=0) + axs = gs.subplots(sharex='col', sharey='row') + + for ci in range(nrows): + for cj in range(ncols): + ck=ci*ncols+cj + if ck<self.N: + im=axs[ci,cj].imshow(self.Jnorm[ndir,ck],interpolation=None,aspect='auto') + else: + axs[ci,cj].axis('off') + + for ax in fig.get_axes(): + ax.label_outer() + cb_ax=fig.add_axes([0.9, 0.1, 0.02, 0.8]) + cbar=fig.colorbar(im,cax=cb_ax) + fig.suptitle('direction '+str(self.directions[ndir])) + fig.supxlabel('subbands') + fig.supylabel('time') + plt.savefig(os.path.basename(self.in_soltab[0])+'_dir_'+str(ndir)+'.png') + plt.close() + + def get_directions(self,solution_names): + """ + solution_names will include all solutions, + like 'DirectionGain:x:y:Real/Imag:Station:Direction' + parse this and find unique direction names + """ + sourcenames=[] + for solname in solution_names: + sourcenames.append(solname.split(':')[-1]) + return list(set(sourcenames)) + +def main(args): + pg=PlotGenerator(args.instrument_tables,args.clip) + pg.read_solutions() + pg.plot_solutions() + + +if __name__=='__main__': + parser=argparse.ArgumentParser( + description='Plot demixing solutions', + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + ) + + parser.add_argument('--instrument_tables',type=str,metavar='\'*instrument\'', + help='absolute path (pattern) to match all instrument tables') + parser.add_argument('--clip',type=float,default=10, + help='clip values (multiplied by standard deviation) above this value before plotting') + + args=parser.parse_args() + if args.instrument_tables: + main(args) + else: + parser.print_help() +