#! /usr/bin/env python
"""
Script to plot demixing solutions, (norm of solutions)
"""

import math,sys,uuid,os
import glob
import argparse
import numpy as np
import casacore.tables as ctab
import subprocess as sb
from multiprocessing import Pool
from multiprocessing import shared_memory

import matplotlib as mpl
mpl.use('Agg')
import matplotlib.pyplot as plt

def globalize(func):
    """
    Wrap ``func`` in a uniquely named function registered as a module-level
    attribute, so that it can be pickled by reference (e.g. by
    ``multiprocessing.Pool.map``) even when ``func`` is a nested closure.

    Returns the registered wrapper; calling it forwards all arguments to func.
    """
    def forwarder(*args, **kwargs):
        return func(*args, **kwargs)

    unique_name = uuid.uuid4().hex
    forwarder.__name__ = unique_name
    forwarder.__qualname__ = unique_name
    owning_module = sys.modules[forwarder.__module__]
    setattr(owning_module, unique_name, forwarder)
    return forwarder


class PlotGenerator:
    """
    Read a set of instrument (solution) tables and plot the norm of the
    solutions, one PNG file per direction.

    The tables matched by the glob pattern are processed in sorted filename
    order; each contributes self.F frequency columns to the concatenated
    frequency axis of self.Jnorm.
    """
    def __init__(self,in_soltables,clipval):
        """
        in_soltables: glob pattern matching all instrument tables
        clipval: values above clipval*std(norm) are clipped before plotting
        """
        self.in_soltab=sorted(glob.glob(in_soltables))
        self.clipval=clipval
        # initialize values (filled by read_solutions)
        self.N=0 # stations
        self.K=0 # directions
        self.B=0 # number of tables (subbands)
        self.T=0 # time samples per table
        self.F=0 # frequency samples per table
        self.directions=None

        # number of worker processes used to read the tables
        self.Nparallel=4
        # solution norm, shape (K, N, T, F*B), filled by read_solutions()
        self.Jnorm=None

    def read_solutions(self):
        """
        Read all tables and fill self.Jnorm with the norm of each group of
        8 solution values, then replace non-finite values by 0 and clip values
        above self.clipval times the standard deviation of the result.

        Raises:
            ValueError: if no table matched the pattern, or no direction
                names were found in the first table.
        """
        self.B=len(self.in_soltab)
        if self.B==0:
            raise ValueError('no instrument tables matched the given pattern')
        # read dimensions and solution names from the first table
        tt=ctab.table(self.in_soltab[0],readonly=True)
        try:
            vl=tt.getcol('VALUES')
            n_prod=vl.shape[0]
            self.T=vl.shape[1]
            self.F=vl.shape[2]
            sol_names_tab=ctab.table(tt.getkeyword('NAMES'),readonly=True)
            try:
                sol_names=sol_names_tab.getcol('NAME')
            finally:
                sol_names_tab.close()
        finally:
            tt.close()

        self.directions=self.get_directions(sol_names)
        # NOTE(review): the reversal is kept from the original code; the
        # direction order must match the direction-major row order of VALUES
        # used in process_table below -- confirm against the table layout.
        self.directions.reverse()
        self.K=len(self.directions)
        if self.K==0:
            raise ValueError('no direction names found in '+self.in_soltab[0])
        self.N=n_prod//(8*self.K)
        print(f'Processing {self.B} subbands, {self.N} stations, {self.K} directions, time {self.T} freq {self.F}')

        self.Jnorm=np.zeros((self.K,self.N,self.T,self.F*self.B))
        shmJnorm=shared_memory.SharedMemory(create=True,size=self.Jnorm.nbytes)
        Jnorm_sh=None
        try:
            Jnorm_sh=np.ndarray(self.Jnorm.shape,dtype=self.Jnorm.dtype,buffer=shmJnorm.buf)
            # tables with an unexpected shape are skipped, so start from zeros
            Jnorm_sh[:]=0

            K,N,T,F=self.K,self.N,self.T,self.F
            soltabs=self.in_soltab

            @globalize
            def process_table(index):
                tt=ctab.table(soltabs[index],readonly=True)
                try:
                    vl=tt.getcol('VALUES')
                finally:
                    tt.close()
                # vl has shape : 8*K*N x T x F, row index = ci*8*N + cj*8 + k
                # (ci: direction, cj: station, k: one of the 8 values of a 2x2 matrix)
                if vl.shape!=(8*K*N,T,F):
                    print(f'Skipping {soltabs[index]}: unexpected shape {vl.shape}')
                    return
                # norm over each group of 8 values -> shape (K, N, T, F)
                Jnorm_sh[:,:,:,index*F:(index+1)*F]=np.sqrt(np.sum(np.square(vl.reshape(K,N,8,T,F)),axis=2))

            # process_table is a closure over Jnorm_sh and is registered under a
            # runtime-generated global name, so workers must be forked (spawned
            # or forkserver workers would not see either of them)
            from multiprocessing import get_context
            with get_context('fork').Pool(self.Nparallel) as pool:
                pool.map(process_table,range(self.B))
                pool.close()
                pool.join()

            np.copyto(self.Jnorm,Jnorm_sh)
        finally:
            # drop the ndarray view first: closing the segment while an array
            # still exports its buffer raises BufferError
            Jnorm_sh=None
            shmJnorm.close()
            shmJnorm.unlink()

        # replace Nan/Inf by 0
        self.Jnorm[~np.isfinite(self.Jnorm)]=0
        # clip values above clipval times the standard deviation
        clip_level=self.clipval*self.Jnorm.std()
        np.minimum(self.Jnorm,clip_level,out=self.Jnorm)

    def plot_solutions(self):
        """
        Save one PNG file per direction showing the (time x frequency) norm
        of every station on a grid with 7 rows.

        Files are named '<basename of first table>_dir_<direction index>.png'
        and written relative to the current working directory.
        """
        # determine shape of subplots
        nrows=7
        ncols=(self.N+nrows-1)//nrows

        # iterate over directions
        for ndir in range(self.K):
            Jdir=self.Jnorm[ndir]
            # one color scale for all stations, so the single colorbar is
            # valid for every panel
            vmin,vmax=Jdir.min(),Jdir.max()

            fig=plt.figure()
            gs=fig.add_gridspec(nrows,ncols,hspace=0,wspace=0)
            # squeeze=False keeps axs 2-D even when ncols==1
            axs=gs.subplots(sharex='col',sharey='row',squeeze=False)

            for ci in range(nrows):
                for cj in range(ncols):
                    ck=ci*ncols+cj
                    if ck<self.N:
                        im=axs[ci,cj].imshow(Jdir[ck],interpolation=None,aspect='auto',vmin=vmin,vmax=vmax)
                    else:
                        axs[ci,cj].axis('off')

            for ax in fig.get_axes():
                ax.label_outer()
            cb_ax=fig.add_axes([0.9, 0.1, 0.02, 0.8])
            fig.colorbar(im,cax=cb_ax)
            fig.suptitle('direction '+str(self.directions[ndir]))
            fig.supxlabel('subbands')
            fig.supylabel('time')
            fig.savefig(os.path.basename(self.in_soltab[0])+'_dir_'+str(ndir)+'.png')
            plt.close(fig)

    def get_directions(self,solution_names):
        """
        solution_names will include all solutions,
        like 'DirectionGain:x:y:Real/Imag:Station:Direction'.
        Parse these and return the unique direction names (last ':' field)
        in order of first appearance. Unlike iterating a set of strings, this
        order does not depend on per-process hash randomization.
        """
        return list(dict.fromkeys(solname.split(':')[-1] for solname in solution_names))

def main(args):
    """
    Entry point: build a PlotGenerator from the parsed command-line
    arguments (table pattern and clip factor), read the solutions and
    write the plots.
    """
    generator = PlotGenerator(args.instrument_tables, args.clip)
    for step in (generator.read_solutions, generator.plot_solutions):
        step()


if __name__ == '__main__':
    # command-line interface: pattern of instrument tables plus clip factor
    parser = argparse.ArgumentParser(
        description='Plot demixing solutions',
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    parser.add_argument(
        '--instrument_tables',
        type=str,
        metavar='\'*instrument\'',
        help='absolute path (pattern) to match all instrument tables',
    )
    parser.add_argument(
        '--clip',
        type=float,
        default=10,
        help='clip values (multiplied by standard deviation) above this value before plotting',
    )

    args = parser.parse_args()
    # without a table pattern there is nothing to do: show usage instead
    if not args.instrument_tables:
        parser.print_help()
    else:
        main(args)