#! /usr/bin/env python
"""
Script to tune demixing parameters by sampling data
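
Example invocation (script name and paths are illustrative placeholders):
  python tune_demix.py --MS '/data/L123456_SB*.MS' --sky_model Ateam.skymodel \
      --subbands 3 --time_interval 30 --parallel_jobs 4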
"""

import math
import sys
import uuid
import glob
import subprocess as sb
import numpy as np
import argparse
import itertools
import casacore.tables as ctab
from casacore.measures import measures
from casacore.quanta import quantity
from multiprocessing import Pool
from multiprocessing import shared_memory
import astropy.time as atime
import lsmtool
import logging



# default DP3
default_DP3='DP3'
# LINC script to download target sky
default_LINC_GET_TARGET='download_skymodel_target.py'

def globalize(func):
  """Register a locally defined function at module level under a unique name.

  multiprocessing.Pool pickles worker functions by reference, so closures
  defined inside methods cannot be passed to Pool.map directly; registering
  them as module-level attributes makes them picklable.
  """
  def result(*args, **kwargs):
    return func(*args, **kwargs)
  result.__name__ = result.__qualname__ = uuid.uuid4().hex
  setattr(sys.modules[result.__module__], result.__name__, result)
  return result

class DemixingTuner:
    DP3=default_DP3
    LINC_GET_TARGET=default_LINC_GET_TARGET
    INF=1e15
    def __init__(self,mslist,skymodel,target_skymodel,timesec,Nf,sblist,Nparallel=4,path_DP3=None,path_LINC=None, patch_list=[]):
        self.mslist=glob.glob(mslist)
        self.Nf=Nf
        if len(sblist)>0:
            # extract subbands given by input
            new_mslist=self.select_subbands(self.mslist,sblist)
            self.mslist=new_mslist
            self.Nf=len(self.mslist)

        self.timesec=timesec
        self.skymodel=skymodel
        self.target_skymodel_in=target_skymodel
        # max iterations to consider
        self.maxiter=[20, 30, 40, 50]
        # robust degrees of freedom to consider
        self.dof=[5, 20, 200]
        if path_DP3:
            self.DP3=path_DP3
        if path_LINC:
            self.LINC_GET_TARGET=path_LINC
        self.extracted_ms=[]
        # stations
        self.N=0
        # outlier directions
        self.K=0
        # baselines
        self.B=0
        # frequencies
        self.freqlist=None
        # bandwidth (of each subband)
        self.bandwidth=0
        # best bandwidth (freq resolution) to use
        self.bandwidth_to_use=0
        # best value for maximum iterations to use
        self.maxiter_to_use=0
        # best value for robust DOF
        self.robust_dof_to_use=0
        # number of channels (of each subband)
        self.nchan=0
        # target direction
        self.ra0=0
        self.dec0=0
        # measures object
        self.mydm=None
        # outlier directions (arrays)
        self.ateam_ra=None
        self.ateam_dec=None
        # outlier separations,azimuth,elevations
        self.separation=None
        self.azimuth=None
        self.elevation=None
        # outlier sources at or below this elevation (degrees) are ignored
        self.low_elevation_limit=5
        # outlier sources within this separation (degrees) of the target are always included
        self.low_separation_limit=20
        # outlier (Patch) names
        self.ateam_names=None
        self.chosen_patches=None
        self.given_patches=patch_list
        # minimum number of outlier sources to select; increase to include more
        # (2 for HBA, 3 for LBA; set automatically from the observing frequency)
        self.inclusiveness=2

        self.target_skymodel='./target.sky.txt'
        self.final_skymodel='./combined.sky.txt'
        self.parset_demix='tunedemix.parset'
        self.Nparallel=Nparallel
        # RSS: residual noise level (std of Stokes I residuals), ~ sqrt(residual sum of squares)
        self.RSS=None
        self.AIC=None
        self.AIC_probs=None
        self.RSS_freq=None
        self.AIC_freq=None
        self.RSS_iter=None
        self.AIC_iter=None
        self.RSS_dof=None
        self.AIC_dof=None

    def __del__(self):
        for ms in self.extracted_ms:
           sb.run('rm -rf '+ms,shell=True)


    def select_subbands(self,mslist,subband_ids):
        sbidx=['SB{:03d}'.format(int(x)) for x in subband_ids]
        selected_ms=[]
        for ms in mslist:
            for sb_id in sbidx:
                if sb_id in ms:
                    selected_ms.append(ms)
        return selected_ms

    def extract_dataset(self):
        self.mslist.sort()
        msname=self.mslist[0]
        # Parset for extracting and averaging
        parset_sample='extract_sample.parset'
        parset=open(parset_sample,'w+')
        # in order to overcome storage manager issues, make a small
        # copy of the data keeping only the (0,0) autocorrelation;
        # this copy is only used to read the observation time range
        parset.write('steps=[fil]\n'
           +'fil.type=filter\n'
           +'fil.baseline=\"0,0 &&\"\n'
           +'fil.remove=True\n'
           +'msin.datacolumn=DATA\n')
        parset.close()
        # remove old files
        sb.run('rm -rf L_SB0.MS',shell=True)
        # now process first selected MS
        MS='L_SB0.MS'
        proc1=sb.Popen(self.DP3+' '+parset_sample+' msin='+msname+' msout='+MS, shell=True)
        proc1.wait()
        msname=MS

        tt=ctab.table(msname,readonly=True)
        starttime= tt[0]['TIME']
        endtime=tt[tt.nrows()-1]['TIME']
        self.N=tt.nrows()
        tt.close()
        Nms=len(self.mslist)

        # need to have at least Nf MS
        assert(Nms>=self.Nf)

        parset=open(parset_sample,'w+')
        # sample a random time interval of length timesec within the
        # observation; MS TIME values are in MJD seconds (UTC)
        t_start=np.random.rand()*max(endtime-starttime-self.timesec,0)+starttime
        t_end=t_start+self.timesec
        # format as YYYY/MM/DD/hh:mm:ss, as expected by DP3 msin.starttime/endtime
        t0=atime.Time(t_start/(24*60*60),format='mjd',scale='utc')
        dt=t0.to_datetime()
        str_tstart=str(dt.year)+'/'+str(dt.month)+'/'+str(dt.day)+'/'+str(dt.hour)+':'+str(dt.minute)+':'+str(dt.second)
        t0=atime.Time(t_end/(24*60*60),format='mjd',scale='utc')
        dt=t0.to_datetime()
        str_tend=str(dt.year)+'/'+str(dt.month)+'/'+str(dt.day)+'/'+str(dt.hour)+':'+str(dt.minute)+':'+str(dt.second)

        parset.write('steps=[fil,avg]\n'
           +'fil.type=filter\n'
           +'fil.baseline=[CR]S*&\n'
           +'fil.remove=True\n'
           +'avg.type=average\n'
           +'avg.timestep=1\n'
           +'avg.freqstep=1\n'
           +'msin.datacolumn=DATA\n'
           +'msin.starttime='+str_tstart+'\n'
           +'msin.endtime='+str_tend+'\n')
        parset.close()

        # select a subset of Nf MS: always keep the first and last subband,
        # plus Nf-2 randomly chosen in between
        submslist=list()
        submslist.append(self.mslist[0])
        aa=list(np.random.choice(np.arange(1,Nms-1),self.Nf-2,replace=False))
        aa.sort()
        for ms_id in aa:
          submslist.append(self.mslist[ms_id])
        submslist.append(self.mslist[-1])

        # remove old files
        sb.run('rm -rf L_SB*.MS',shell=True)
        # now process each of selected MS
        self.extracted_ms=list()
        for ci in range(self.Nf):
           MS='L_SB'+str(ci)+'.MS'
           proc1=sb.Popen(self.DP3+' '+parset_sample+' msin='+submslist[ci]+' msout='+MS, shell=True)
           proc1.wait()
           self.extracted_ms.append(MS)
           # also apply the weights to the data and renormalize them
           self.reset_weights(MS)

        # get metadata from extracted MS
        self.freqlist=np.zeros(self.Nf)
        ci=0
        for ms in self.extracted_ms:
            tf=ctab.table(ms+'/SPECTRAL_WINDOW',readonly=True)
            ch0=tf.getcol('CHAN_FREQ')
            reffreq=tf.getcol('REF_FREQUENCY')
            self.freqlist[ci]=ch0[0,0]
            tf.close()
            ci+=1

        # set inclusiveness based on the observing band: LBA (below 90 MHz)
        # demixes more sources than HBA
        if np.min(self.freqlist)<90e6:
            self.inclusiveness=3
        else:
            self.inclusiveness=2

        tf=ctab.table(self.extracted_ms[0]+'/SPECTRAL_WINDOW',readonly=True)
        self.bandwidth=tf.getcol('TOTAL_BANDWIDTH')[0]
        self.nchan=tf.getcol('NUM_CHAN')[0]
        tf.close()

        # get antennas
        ant=ctab.table(self.extracted_ms[0]+'/ANTENNA',readonly=True)
        self.N=ant.nrows()
        # baselines
        self.B=self.N*(self.N-1)//2
        # Antenna location
        xyz=ant.getcol('POSITION')
        ant.close()
        # Get target coords
        field=ctab.table(self.extracted_ms[0]+'/FIELD',readonly=True)
        phase_dir=field.getcol('PHASE_DIR')
        self.ra0=phase_dir[0][0][0]
        self.dec0=phase_dir[0][0][1]
        field.close()

        # get integration time
        tt=ctab.table(self.extracted_ms[0],readonly=True)
        tt1=tt.getcol('INTERVAL')
        Tdelta=tt[0]['INTERVAL']
        t0=tt[0]['TIME']
        tt.close()
        Tslots=math.ceil(self.timesec/Tdelta)

        # set up a casacore measures frame (observer position and observation
        # epoch, UTC) for coordinate conversions
        self.mydm=measures()
        mypos=self.mydm.position('ITRF',str(xyz[0][0])+'m',str(xyz[0][1])+'m',str(xyz[0][2])+'m')
        mytime=self.mydm.epoch('UTC',str(t0)+'s')
        self.mydm.doframe(mytime)
        self.mydm.doframe(mypos)


    def read_skymodel(self):
        lsm=lsmtool.load(self.skymodel)
        skymodel_patches=lsm.getPatchNames()
        # only select patches already given by user
        if len(self.given_patches)>0:
            patches=list(set(self.given_patches).intersection(skymodel_patches))
        else:
            patches=skymodel_patches
        assert(len(patches)>0)
        self.ateam_ra=np.zeros(len(patches))
        self.ateam_dec=np.zeros(len(patches))
        self.K=len(patches)
        self.ateam_names=patches
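        # use the position of the first source in each patch as the patch direction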
        for ci in range(self.K):
            patch=patches[ci]
            lsm=lsmtool.load(self.skymodel)
            lsm.select('Patch=='+patch)
            t=lsm.table
            self.ateam_ra[ci]=np.array(t['Ra'])[0]
            self.ateam_dec[ci]=np.array(t['Dec'])[0]
        self.separation=np.zeros(self.K)
        self.azimuth=np.zeros(self.K)
        self.elevation=np.zeros(self.K)
        ra0_q=quantity(self.ra0,'rad')
        dec0_q=quantity(self.dec0,'rad')
        target=self.mydm.direction('j2000',ra0_q,dec0_q)
        for ci in range(self.K):
            mra_q=quantity(self.ateam_ra[ci],'deg')
            mdec_q=quantity(self.ateam_dec[ci],'deg')
            cluster_dir=self.mydm.direction('j2000',mra_q,mdec_q)
            sep=self.mydm.separation(target,cluster_dir)
            self.separation[ci]=sep.get_value()
            azel=self.mydm.measure(cluster_dir,'AZEL')
            self.azimuth[ci]=azel['m0']['value']/math.pi*180
            self.elevation[ci]=azel['m1']['value']/math.pi*180

    def get_target_skymodel(self,overwrite=True):
        # If target is given as input, use it
        if self.target_skymodel_in:
            self.target_skymodel=self.target_skymodel_in
        else:
          if overwrite:
            sb.run('rm -rvf '+self.target_skymodel,shell=True)
          sb.run('python3 '+self.LINC_GET_TARGET+' --Radius 5 --targetname CENTER '+self.extracted_ms[0]+' '+self.target_skymodel,shell=True)

    # convert integer to binary bits, return array of size K
    def scalar_to_kvec(self,n):
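      # e.g. with K=4, n=5 (binary 0101) maps to array([0.,1.,0.,1.])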
      ll=[1 if digit=='1' else 0 for digit in bin(n)[2:]]
      a=np.zeros(self.K)
      a[-len(ll):]=ll
      return a

    # merge the target sky model with the outlier (A-Team) sky model
    def create_final_skymodel(self):
        lsm0=lsmtool.load(self.target_skymodel)
        lsm=lsmtool.load(self.skymodel)
        lsm0.concatenate(lsm,keep='all')
        lsm0.write(self.final_skymodel,clobber=True)
 
    # write a demix parset that subtracts the subset of outlier patches
    # encoded by category index 'index' (see scalar_to_kvec)
    def create_parset_dir(self,index):
        chosen_dirs=self.scalar_to_kvec(index)
        chosen_patches=itertools.compress(self.ateam_names,chosen_dirs)
        parset_name=str(index)+'_'+self.parset_demix
        bbsdem=open(parset_name,'w+')
        bbsdem.write('steps=[demix]\n'
             +'demix.type=demixer\n'
             +'demix.baseline=[CR]S*&\n'
             +'demix.ignoretarget=False\n'
             +'demix.targetsource=\"CENTER\"\n'
             +'demix.demixtimeresolution=10\n'
             +'demix.demixfreqresolution=50kHz\n'
             +'demix.uselbfgssolver=true\n'
             +'demix.lbfgs.historysize=10\n'
             +'demix.lbfgs.robustdof=200\n'
             +'demix.freqstep=1\n'
             +'demix.timestep=1\n'
             +'demix.instrumentmodel=instrument_'+str(index)+'\n'
             +'demix.skymodel='+self.final_skymodel+'\n'
             )

        bbsdem.write('demix.subtractsources=[')
        firstcomma=False
        if sum(chosen_dirs)>0:
          for patch in chosen_patches:
            if not firstcomma:
                firstcomma=True
            else:
                bbsdem.write(',')
            bbsdem.write('\"'+patch+'\"')
        bbsdem.write(']\n')
        bbsdem.close()
        return parset_name

    def create_parset_freq(self,channel_index,channels):
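        # candidate demixing frequency resolution: channels[channel_index] channels
        # averaged together; single-channel width = bandwidth/nchan (Hz), /1e3 -> kHz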
        bandwidth=self.bandwidth/self.nchan*channels[channel_index]/1e3
        parset_name=str(channel_index)+'_'+self.parset_demix
        bbsdem=open(parset_name,'w+')
        bbsdem.write('steps=[demix]\n'
             +'demix.type=demixer\n'
             +'demix.baseline=[CR]S*&\n'
             +'demix.ignoretarget=False\n'
             +'demix.targetsource=\"CENTER\"\n'
             +'demix.demixtimeresolution=10\n'
             +'demix.demixfreqresolution='+str(bandwidth)+'kHz\n'
             +'demix.uselbfgssolver=true\n'
             +'demix.lbfgs.historysize=10\n'
             +'demix.lbfgs.robustdof=200\n'
             +'demix.freqstep=1\n'
             +'demix.timestep=1\n'
             +'demix.instrumentmodel=instrument_'+str(channel_index)+'\n'
             +'demix.skymodel='+self.final_skymodel+'\n'
             )

        bbsdem.write('demix.subtractsources=[')
        firstcomma=False
        for patch in self.chosen_patches:
            if not firstcomma:
                firstcomma=True
            else:
                bbsdem.write(',')
            bbsdem.write('\"'+patch+'\"')
        bbsdem.write(']\n')
        bbsdem.close()
        return parset_name

    def create_parset_iter(self,iter_index):
        parset_name=str(iter_index)+'_'+self.parset_demix
        bbsdem=open(parset_name,'w+')
        bbsdem.write('steps=[demix]\n'
             +'demix.type=demixer\n'
             +'demix.baseline=[CR]S*&\n'
             +'demix.ignoretarget=False\n'
             +'demix.targetsource=\"CENTER\"\n'
             +'demix.demixtimeresolution=10\n'
             +'demix.demixfreqresolution='+str(self.bandwidth_to_use)+'\n'
             +'demix.uselbfgssolver=true\n'
             +'demix.lbfgs.historysize=10\n'
             +'demix.lbfgs.robustdof='+str(self.robust_dof_to_use)+'\n'
             +'demix.maxiter='+str(self.maxiter[iter_index])+'\n'
             +'demix.freqstep=1\n'
             +'demix.timestep=1\n'
             +'demix.instrumentmodel=instrument_'+str(iter_index)+'\n'
             +'demix.skymodel='+self.final_skymodel+'\n'
             )

        bbsdem.write('demix.subtractsources=[')
        firstcomma=False
        for patch in self.chosen_patches:
            if not firstcomma:
                firstcomma=True
            else:
                bbsdem.write(',')
            bbsdem.write('\"'+patch+'\"')
        bbsdem.write(']\n')
        bbsdem.close()
        return parset_name

    def create_parset_dof(self,dof_index):
        parset_name=str(dof_index)+'_'+self.parset_demix
        bbsdem=open(parset_name,'w+')
        bbsdem.write('steps=[demix]\n'
             +'demix.type=demixer\n'
             +'demix.baseline=[CR]S*&\n'
             +'demix.ignoretarget=False\n'
             +'demix.targetsource=\"CENTER\"\n'
             +'demix.demixtimeresolution=10\n'
             +'demix.demixfreqresolution='+str(self.bandwidth_to_use)+'\n'
             +'demix.uselbfgssolver=true\n'
             +'demix.lbfgs.historysize=10\n'
             +'demix.lbfgs.robustdof='+str(self.dof[dof_index])+'\n'
             +'demix.maxiter=50\n'
             +'demix.freqstep=1\n'
             +'demix.timestep=1\n'
             +'demix.instrumentmodel=instrument_'+str(dof_index)+'\n'
             +'demix.skymodel='+self.final_skymodel+'\n'
             )

        bbsdem.write('demix.subtractsources=[')
        firstcomma=False
        for patch in self.chosen_patches:
            if not firstcomma:
                firstcomma=True
            else:
                bbsdem.write(',')
            bbsdem.write('\"'+patch+'\"')
        bbsdem.write(']\n')
        bbsdem.close()
        return parset_name


    def iter_over_directions(self):
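        # exhaustively demix every non-empty subset of the K outlier patches
        # (subset index 1..2^K-1, decoded by scalar_to_kvec) and record the
        # residual Stokes I noise per sampled subband; index 0 (demix nothing)
        # is filled in afterwards from the raw data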
        self.RSS=np.zeros((self.Nf,2**self.K))
        shmRSS=shared_memory.SharedMemory(create=True,size=self.RSS.nbytes)
        RSSsh=np.ndarray(self.RSS.shape,dtype=self.RSS.dtype,buffer=shmRSS.buf)

        @globalize
        def process_scenario(index):
           # if any of the sources in this subset is below the elevation
           # limit, skip the demixing run and assign a large error
           chosen_dirs=self.scalar_to_kvec(index)
           low_elevations=any(ele <= self.low_elevation_limit for ele in itertools.compress(self.elevation,chosen_dirs))
           if not low_elevations:
               parset_name=self.create_parset_dir(index)
               # copy data
               ci=0
               for ms in self.extracted_ms:
                 out_ms='out_'+str(index)+'_'+ms
                 sb.run('rm -rf '+out_ms,shell=True)
                 proc1=sb.Popen('rsync -a '+ms+'/ '+out_ms,shell=True)
                 proc1.wait()
                 demixout_ms='residual_'+out_ms
                 sb.run('rm -rf '+demixout_ms,shell=True)
                 proc1=sb.Popen(self.DP3+' '+parset_name+'  msin='+out_ms+' msout='+demixout_ms,shell=True)
                 proc1.wait()
                 res_sigmaI,res_sigmaQ,res_sigmaU,res_sigmaV=self.get_noise_var(demixout_ms)
                 print('%s %d %e %e %e %e'%(out_ms,index,res_sigmaI,res_sigmaQ,res_sigmaU,res_sigmaV))
                 RSSsh[ci,index]=res_sigmaI
                 ci+=1
           else:
               RSSsh[:,index]=self.INF

        pool=Pool(self.Nparallel)
        pool.map(process_scenario,range(1,2**self.K))
        pool.close()
        pool.join()

        self.RSS[:]=RSSsh[:]
        shmRSS.close()
        shmRSS.unlink()
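        # baseline case (index 0, nothing demixed): noise of the raw data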
        ci=0
        for ms in self.extracted_ms:
           sigmaI,sigmaQ,sigmaU,sigmaV=self.get_noise_var(ms)
           self.RSS[ci,0]=sigmaI
           ci+=1

    def iter_over_frequency(self):
        assert(len(self.chosen_patches)>0)
        # channels to consider
        channels=[2**x for x in range(int(np.ceil(np.log(self.nchan)/np.log(2)))+1)]
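        # e.g. nchan=64 gives candidate averaging widths [1,2,4,...,64] channels,
        # from a single channel up to the whole subband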
        self.RSS_freq=np.zeros((self.Nf,len(channels)))
        shmRSS=shared_memory.SharedMemory(create=True,size=self.RSS_freq.nbytes)
        RSSsh=np.ndarray(self.RSS_freq.shape,dtype=self.RSS_freq.dtype,buffer=shmRSS.buf)

        @globalize
        def process_scenario_freq(channel_index):
           parset_name=self.create_parset_freq(channel_index,channels)
           # copy data
           ci=0
           for ms in self.extracted_ms:
              out_ms='out_'+str(channel_index)+'_'+ms
              sb.run('rm -rf '+out_ms,shell=True)
              proc1=sb.Popen('rsync -a '+ms+'/ '+out_ms,shell=True)
              proc1.wait()
              demixout_ms='residual_'+out_ms
              sb.run('rm -rf '+demixout_ms,shell=True)
              proc1=sb.Popen(self.DP3+' '+parset_name+'  msin='+out_ms+' msout='+demixout_ms,shell=True)
              proc1.wait()
              res_sigmaI,res_sigmaQ,res_sigmaU,res_sigmaV=self.get_noise_var(demixout_ms)
              print('%s %d %e %e %e %e'%(out_ms,channel_index,res_sigmaI,res_sigmaQ,res_sigmaU,res_sigmaV))
              RSSsh[ci,channel_index]=res_sigmaI
              ci+=1

        pool=Pool(self.Nparallel)
        pool.map(process_scenario_freq,range(len(channels)))
        pool.close()
        pool.join()

        self.RSS_freq[:]=RSSsh[:]
        shmRSS.close()
        shmRSS.unlink()

    def iter_over_maxiter(self):
        assert(len(self.chosen_patches)>0)
        assert(len(self.maxiter)>0)
        self.RSS_iter=np.zeros((self.Nf,len(self.maxiter)))
        shmRSS=shared_memory.SharedMemory(create=True,size=self.RSS_iter.nbytes)
        RSSsh=np.ndarray(self.RSS_iter.shape,dtype=self.RSS_iter.dtype,buffer=shmRSS.buf)

        @globalize
        def process_scenario_iter(iter_index):
           parset_name=self.create_parset_iter(iter_index)
           # copy data
           ci=0
           for ms in self.extracted_ms:
              out_ms='out_'+str(iter_index)+'_'+ms
              sb.run('rm -rf '+out_ms,shell=True)
              proc1=sb.Popen('rsync -a '+ms+'/ '+out_ms,shell=True)
              proc1.wait()
              demixout_ms='residual_'+out_ms
              sb.run('rm -rf '+demixout_ms,shell=True)
              proc1=sb.Popen(self.DP3+' '+parset_name+'  msin='+out_ms+' msout='+demixout_ms,shell=True)
              proc1.wait()
              res_sigmaI,res_sigmaQ,res_sigmaU,res_sigmaV=self.get_noise_var(demixout_ms)
              print('%s %d %e %e %e %e'%(out_ms,iter_index,res_sigmaI,res_sigmaQ,res_sigmaU,res_sigmaV))
              RSSsh[ci,iter_index]=res_sigmaI
              ci+=1

        pool=Pool(self.Nparallel)
        pool.map(process_scenario_iter,range(len(self.maxiter)))
        pool.close()
        pool.join()

        self.RSS_iter[:]=RSSsh[:]
        shmRSS.close()
        shmRSS.unlink()


    def iter_over_robust_dof(self):
        assert(len(self.chosen_patches)>0)
        assert(len(self.dof)>0)
        self.RSS_dof=np.zeros((self.Nf,len(self.dof)))
        shmRSS=shared_memory.SharedMemory(create=True,size=self.RSS_dof.nbytes)
        RSSsh=np.ndarray(self.RSS_dof.shape,dtype=self.RSS_dof.dtype,buffer=shmRSS.buf)
        @globalize
        def process_scenario_dof(iter_index):
           parset_name=self.create_parset_dof(iter_index)
           # copy data
           ci=0
           for ms in self.extracted_ms:
              out_ms='out_'+str(iter_index)+'_'+ms
              sb.run('rm -rf '+out_ms,shell=True)
              proc1=sb.Popen('rsync -a '+ms+'/ '+out_ms,shell=True)
              proc1.wait()
              demixout_ms='residual_'+out_ms
              sb.run('rm -rf '+demixout_ms,shell=True)
              proc1=sb.Popen(self.DP3+' '+parset_name+'  msin='+out_ms+' msout='+demixout_ms,shell=True)
              proc1.wait()
              res_sigmaI,res_sigmaQ,res_sigmaU,res_sigmaV=self.get_noise_var(demixout_ms)
              print('%s %d %e %e %e %e'%(out_ms,iter_index,res_sigmaI,res_sigmaQ,res_sigmaU,res_sigmaV))
              RSSsh[ci,iter_index]=res_sigmaI
              ci+=1

        pool=Pool(self.Nparallel)
        pool.map(process_scenario_dof,range(len(self.dof)))
        pool.close()
        pool.join()

        self.RSS_dof[:]=RSSsh[:]
        shmRSS.close()
        shmRSS.unlink()

    # scale the data by a representative sqrt(weight) and renormalize the
    # weights, so that noise estimates from different subbands are comparable
    def reset_weights(self,msname):
        tt=ctab.table(msname,readonly=False)
        data=tt.getcol('DATA')
        if 'WEIGHT_SPECTRUM' in tt.colnames():
          weight_col='WEIGHT_SPECTRUM'
        else:
          weight_col='IMAGING_WEIGHT'
        weight=tt.getcol(weight_col)
        # weights scale as 1/sigma^2, hence the sqrt() when applying them to the data
        weight_std=np.sqrt(weight.std())
        data *=weight_std
        weight /=weight_std
        tt.putcol('DATA',data)
        tt.putcol(weight_col,weight)
        tt.close()

    # extract noise info
    def get_noise_var(self,msname):
        tt=ctab.table(msname,readonly=True)
        t1=tt.query('ANTENNA1 != ANTENNA2',columns='DATA,FLAG')
        data0=t1.getcol('DATA')
        # zero out flagged samples (FLAG is boolean, so no NaN check is needed)
        flag=t1.getcol('FLAG')
        data=data0*(1-flag)
        tt.close()
        # set nans to 0
        data[np.isnan(data)]=0.
        # form IQUV
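        # assuming the usual linear correlation order XX,XY,YX,YY; only the
        # standard deviations are used downstream, so the exact Stokes
        # conventions do not matter here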
        sI=(data[:,:,0]+data[:,:,3])*0.5
        sQ=(data[:,:,0]-data[:,:,3])*0.5
        sU=(data[:,:,1]-data[:,:,2])*0.5
        sV=(data[:,:,1]+data[:,:,2])*0.5
        return sI.std(),sQ.std(),sU.std(),sV.std()

    # Akaike-like information criterion for each direction subset:
    # AIC = (residual noise / raw-data noise)^2 * N_stat^2 + (number of demixed directions) * N_stat
    # i.e. a goodness-of-fit term plus a penalty for model complexity
    def calc_AIC(self):
        self.AIC=np.zeros(2**self.K)
        for index in range(2**self.K):
          chosen_dirs=self.scalar_to_kvec(index)
          rss=(np.mean(self.RSS[:,index])/np.mean(self.RSS[:,0]))**2
          self.AIC[index]=rss*self.N*self.N+np.sum(chosen_dirs)*self.N

        # convert the AIC values (scaled by the best non-empty subset's AIC) into
        # relative probabilities, then accumulate a selection probability per direction
        indx=np.argsort(self.AIC[1:])+1
        probs=np.exp(-self.AIC/self.AIC[indx[0]])/np.sum(np.exp(-self.AIC/self.AIC[indx[0]]))
        self.AIC_probs=np.zeros(self.K)
        for ci in range(2**self.K):
            self.AIC_probs+=probs[ci]*self.scalar_to_kvec(ci)
        AIC_probs=self.AIC_probs.copy()
        chosen_dirs=np.zeros(self.K)
        for index in range(self.inclusiveness):
            max_id=np.argmax(AIC_probs)
            chosen_dirs[max_id]=1
            AIC_probs[max_id]=0

        # remove directions below the elevation limit, again, to be safe
        chosen_dirs[self.elevation<=self.low_elevation_limit]=0
        # also add sources that are close
        close_sources=[index for index in range(self.K) if self.separation[index]<=self.low_separation_limit]
        chosen_dirs[self.separation<=self.low_separation_limit]=1
        chosen_patches=list(itertools.compress(self.ateam_names,chosen_dirs))
        chosen_separation=list(itertools.compress(self.separation,chosen_dirs))
        sorted_patches=[patch for _,patch in sorted(zip(chosen_separation,chosen_patches))]
        if len(chosen_patches)>self.inclusiveness and len(close_sources)==0:
            self.chosen_patches=sorted_patches[:-1]
        else:
            self.chosen_patches=sorted_patches

    def calc_bandwidth_to_use(self):
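        # AIC over candidate frequency resolutions: goodness-of-fit term plus a
        # penalty proportional to the number of solution intervals (nchan/channels)
        # per station and demixed direction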
        channels=[2**x for x in range(int(np.ceil(np.log(self.nchan)/np.log(2)))+1)]
        for ci in range(self.Nf):
            self.RSS_freq[ci,:]/=self.RSS[ci,0]
        rss=np.mean(self.RSS_freq,axis=0)**2
        self.AIC_freq=rss*self.N*self.N+self.nchan/np.array(channels)*self.N*len(self.chosen_patches)
        index=np.argmin(self.AIC_freq)
        self.bandwidth_to_use=self.bandwidth*channels[index]/self.nchan


    def calc_maxiter_to_use(self):
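        # same AIC form as calc_bandwidth_to_use; here the penalty grows as
        # N_stat * (number of demixed directions) * log(maxiter)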
        for ci in range(self.Nf):
            self.RSS_iter[ci,:]/=self.RSS[ci,0]
        rss=np.mean(self.RSS_iter,axis=0)**2
        self.AIC_iter=rss*self.N*self.N+self.N*len(self.chosen_patches)*np.log(np.array(self.maxiter))
        index=np.argmin(self.AIC_iter)
        self.maxiter_to_use=self.maxiter[index]

    def calc_dof_to_use(self):
        for ci in range(self.Nf):
            self.RSS_dof[ci,:]/=self.RSS[ci,0]
        rss=np.mean(self.RSS_dof,axis=0)**2
        self.AIC_dof=rss*self.N*self.N+np.array(self.dof)*self.N
        index=np.argmin(self.AIC_dof)
        self.robust_dof_to_use=self.dof[index]

    def report(self):
        info={}
        print('*************************')
        print('Patches')
        print(self.ateam_names)
        info['patches']=self.ateam_names
        print('Probabilities of being selected')
        print(self.AIC_probs)
        info['probabilities']=self.AIC_probs
        print('Separation')
        print(self.separation)
        info['separation']=self.separation
        print('Elevation')
        print(self.elevation)
        info['elevation']=self.elevation
        print('Direction selection')
        print(self.RSS)
        info['RSS_dir']=self.RSS
        print(self.AIC)
        info['AIC_dir']=self.AIC
        print('Frequency resolution selection')
        print(self.RSS_freq)
        info['RSS_freq']=self.RSS_freq
        print(self.AIC_freq)
        info['AIC_freq']=self.AIC_freq
        print('Robust DOF selection')
        print(self.RSS_dof)
        info['RSS_dof']=self.RSS_dof
        print(self.AIC_dof)
        info['AIC_dof']=self.AIC_dof
        print('Maxiter selection')
        print(self.RSS_iter)
        info['RSS_iter']=self.RSS_iter
        print(self.AIC_iter)
        info['AIC_iter']=self.AIC_iter
        print('Expected reduction in noise is %f %%'%((1-np.mean(self.RSS_iter[:,np.argmin(self.AIC_iter)]))*100.0))
        print('*************************')
        return info

    def cleanup(self):
       for ms in self.extracted_ms:
         out_ms='out_*_'+ms
         sb.run('rm -rf '+out_ms,shell=True)
         demixout_ms='residual_'+out_ms
         sb.run('rm -rf '+demixout_ms,shell=True)
       
       parset_name='*_'+self.parset_demix
       sb.run('rm -rf '+parset_name,shell=True)
       sb.run('rm -rf instrument_*',shell=True)

    def run(self):
        # Note: the order of following being run is important
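        # 1. sample data, 2. choose demix directions (AIC over all subsets),
        # 3. choose frequency resolution, 4. choose robust DOF, 5. choose maxiter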
        self.extract_dataset()
        self.read_skymodel()
        self.get_target_skymodel(overwrite=True)
        self.create_final_skymodel()
        self.iter_over_directions()
        self.cleanup()
        self.calc_AIC()
        self.iter_over_frequency()
        self.calc_bandwidth_to_use()
        self.cleanup()
        self.iter_over_robust_dof()
        self.calc_dof_to_use()
        self.cleanup()
        self.iter_over_maxiter()
        self.calc_maxiter_to_use()
        self.cleanup()
        info=self.report()
        return info

def main(args):
   tuner=DemixingTuner(args.MS,args.sky_model,args.target_sky_model,args.time_interval,args.subbands,args.subband_list,args.parallel_jobs,path_DP3=args.DP3, path_LINC=args.LINC, patch_list=args.patches)
   info=tuner.run()

   # tuned values, keyed by the corresponding DP3 demixer (demix.*) parset options
   result={}
   result['maxiter']=tuner.maxiter_to_use
   result['demixfreqresolution']=tuner.bandwidth_to_use
   result['subtractsources']=tuner.chosen_patches
   result['lbfgs.robustdof']=tuner.robust_dof_to_use
   result['demixtimeresolution']=10
   result['lbfgs.historysize']=10

   return result,info


if __name__=='__main__':
    parser=argparse.ArgumentParser(
      description='Tune demixing parameters',
      formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )

    parser.add_argument('--MS',type=str,metavar='\'*.MS\'',
        help='absolute path of MS pattern to use')
    parser.add_argument('--sky_model',type=str,metavar='s',
        help='A-Team sky model (text)')
    parser.add_argument('--target_sky_model',type=str,default=None,metavar='st',
        help='target sky model (text)')
    parser.add_argument('--subbands',default=3,type=int,metavar='f',
        help='number of randomly selected subbands to process (minimum 3)')
    parser.add_argument('--subband_list',nargs='+',default=[],
        type=int,metavar='sb',
        help='list of subband numbers [0,1,..] to process (will override number of subbands)')
    parser.add_argument('--time_interval',type=float,default=30.,metavar='t',
        help='total time interval to sample in seconds')
    parser.add_argument('--parallel_jobs',type=int,default=4,metavar='p',
        help='number of parallel jobs')
    parser.add_argument('--patches',nargs='+',default=[],
        metavar='Patch1',
        help='names of patches to consider for subtraction (if found in the sky model); if not given, all patches are considered')
    parser.add_argument('--DP3',type=str,metavar='DP3',default=default_DP3,
        help='DP3 command')
    parser.add_argument('--LINC',type=str,metavar='LINC',default=default_LINC_GET_TARGET,
        help='path to download_skymodel_target.py script (not used if target sky model is given)')


    args=parser.parse_args()
    format_stream = logging.Formatter("%(asctime)s\033[1m %(levelname)s:\033[0m %(message)s","%Y-%m-%d %H:%M:%S")
    format_file   = logging.Formatter("%(asctime)s %(levelname)s: %(message)s","%Y-%m-%d %H:%M:%S")
    logging.root.setLevel(logging.INFO)

    log = logging.StreamHandler()
    log.setFormatter(format_stream)
    logging.root.addHandler(log)

    if args.MS and args.sky_model:
      res,info=main(args)
      print(res)
    else:
      parser.print_help()