From dbd9aa08903f7fd2c22c45266f53a2fa4fe9156a Mon Sep 17 00:00:00 2001
From: Sarod Yatawatta <yatawatta@astron.nl>
Date: Wed, 5 Feb 2025 10:36:55 +0000
Subject: [PATCH] Fix plot demix

---
 scripts/plot_demixing_solutions.py         | 46 ++++++++++++++++++----
 steps/dp3_make_parset_cal.cwl              |  2 +-
 steps/dp3_make_parset_target.cwl           |  2 +-
 steps/dp3_prep_cal.cwl                     |  5 +++
 steps/dp3_prep_target.cwl                  |  5 +++
 steps/plot_demix.cwl                       | 36 +++++++++++++++++
 workflows/linc_calibrator.cwl              |  4 +-
 workflows/linc_calibrator/dp3_prep_cal.cwl |  5 +++
 workflows/linc_calibrator/prep.cwl         | 37 ++++++++++++++++-
 workflows/linc_target/dp3_prep_targ.cwl    |  5 +++
 workflows/linc_target/prep.cwl             | 30 ++++++++++++++
 11 files changed, 163 insertions(+), 14 deletions(-)
 create mode 100644 steps/plot_demix.cwl

diff --git a/scripts/plot_demixing_solutions.py b/scripts/plot_demixing_solutions.py
index 71d6aa88..78c7afd9 100755
--- a/scripts/plot_demixing_solutions.py
+++ b/scripts/plot_demixing_solutions.py
@@ -16,6 +16,23 @@ import matplotlib as mpl
 mpl.use('Agg')
 import matplotlib.pyplot as plt
 
+def input2strlist_nomapfile(invar):
+   """ 
+   from bin/download_IONEX.py
+   give the list of MSs from the list provided as a string
+   """
+   str_list = None
+   if type(invar) is str:
+       if invar.startswith('[') and invar.endswith(']'):
+           str_list = [f.strip(' \'\"') for f in invar.strip('[]').split(',')]
+       else:
+           str_list = [invar.strip(' \'\"')]
+   elif type(invar) is list:
+       str_list = [str(f).strip(' \'\"') for f in invar]
+   else:
+       raise TypeError('input2strlist: Type '+str(type(invar))+' unknown!')
+   return str_list
+
 def globalize(func):
   def result(*args, **kwargs):
     return func(*args, **kwargs)
@@ -26,7 +43,7 @@ def globalize(func):
 
 class PlotGenerator:
     def __init__(self,in_soltables,clipval):
-        self.in_soltab=sorted(glob.glob(in_soltables))
+        self.in_soltab=in_soltables
         self.clipval=clipval
         # initialize values
         self.N=0
@@ -35,6 +52,7 @@ class PlotGenerator:
         self.T=0
         self.F=0
         self.directions=None
+        self.stations=None
 
         self.Nparallel=4
         self.Jnorm=None
@@ -50,10 +68,11 @@ class PlotGenerator:
         self.F=vl.shape[2]
         sol_names_tab=ctab.table(tt.getkeyword('NAMES'),readonly=True)
         sol_names=sol_names_tab.getcol('NAME')
-        self.directions=self.get_directions(sol_names)
+        self.directions,self.stations=self.get_directions(sol_names)
         self.directions.reverse()
         self.K=len(self.directions)
         self.N=n_prod//(8*self.K)
+        assert(self.N==len(self.stations))
         print(f'Processing {self.B} subbands, {self.N} stations, {self.K} directions, time {self.T} freq {self.F}')
         tt.close()
 
@@ -105,33 +124,44 @@ class PlotGenerator:
             for cj in range(ncols):
                 ck=ci*ncols+cj
                 if ck<self.N:
-                  im=axs[ci,cj].imshow(self.Jnorm[ndir,ck],interpolation=None,aspect='auto')
+                  if ncols > 1 :
+                    im=axs[ci,cj].imshow(self.Jnorm[ndir,ck],interpolation=None,aspect='auto')
+                    axs[ci,cj].text(0,0,self.stations[ck])
+                  else :
+                    im=axs[ci].imshow(self.Jnorm[ndir,ck],interpolation=None,aspect='auto')
+                    axs[ci].text(0,0,self.stations[ck])
                 else:
-                  axs[ci,cj].axis('off')
+                  if ncols > 1 :
+                    axs[ci,cj].axis('off')
+                  else:
+                    axs[ci].axis('off')
 
           for ax in fig.get_axes():
             ax.label_outer()
           cb_ax=fig.add_axes([0.9, 0.1, 0.02, 0.8])
           cbar=fig.colorbar(im,cax=cb_ax)
+          cbar.set_label('solution norm')
           fig.suptitle('direction '+str(self.directions[ndir]))
           fig.supxlabel('subbands')
           fig.supylabel('time')
-          plt.savefig(os.path.basename(self.in_soltab[0])+'_dir_'+str(ndir)+'.png')
+          plt.savefig(os.path.basename(self.in_soltab[0])+'_'+str(self.directions[ndir])+'.png')
           plt.close()
 
     def get_directions(self,solution_names):
         """
          solution_names will include all solutions,
          like 'DirectionGain:x:y:Real/Imag:Station:Direction'
-         parse this and find unique direction names
+         parse this and find unique direction names and unique station names
         """
         sourcenames=[]
+        stationnames=[]
         for solname in solution_names:
            sourcenames.append(solname.split(':')[-1])
-        return list(set(sourcenames))
+           stationnames.append(solname.split(':')[-2])
+        return list(set(sourcenames)),list(set(stationnames))
 
 def main(args):
-    pg=PlotGenerator(args.instrument_tables,args.clip)
+    pg=PlotGenerator(input2strlist_nomapfile(args.instrument_tables),args.clip)
     pg.read_solutions()
     pg.plot_solutions()
 
diff --git a/steps/dp3_make_parset_cal.cwl b/steps/dp3_make_parset_cal.cwl
index 17641729..2eb6fd53 100755
--- a/steps/dp3_make_parset_cal.cwl
+++ b/steps/dp3_make_parset_cal.cwl
@@ -113,7 +113,7 @@ requirements:
           demix.ntimechunk                    =   $(inputs.ntimechunk)
           demix.freqstep                      =   1
           demix.timestep                      =   1
-          demix.instrumentmodel               =   instrument
+          demix.instrumentmodel               =   demix_solutions
           demix.uselbfgssolver                =   True
           demix.lbfgs.historysize             =   $(inputs.lbfgs_historysize)
           demix.lbfgs.robustdof               =   $(inputs.lbfgs_robustdof)
diff --git a/steps/dp3_make_parset_target.cwl b/steps/dp3_make_parset_target.cwl
index 52179bce..29d422e9 100755
--- a/steps/dp3_make_parset_target.cwl
+++ b/steps/dp3_make_parset_target.cwl
@@ -167,7 +167,7 @@ requirements:
           demix.ntimechunk                    =   $(inputs.ntimechunk)
           demix.freqstep                      =   1
           demix.timestep                      =   1
-          demix.instrumentmodel               =   instrument
+          demix.instrumentmodel               =   demix_solutions
           demix.uselbfgssolver                =   True
           demix.lbfgs.historysize             =   $(inputs.lbfgs_historysize)
           demix.lbfgs.robustdof               =   $(inputs.lbfgs_robustdof)
diff --git a/steps/dp3_prep_cal.cwl b/steps/dp3_prep_cal.cwl
index 98cc6c96..1d2a09bf 100755
--- a/steps/dp3_prep_cal.cwl
+++ b/steps/dp3_prep_cal.cwl
@@ -104,6 +104,11 @@ outputs:
     type: Directory
     outputBinding:
       glob: '$(inputs.msout_name=="." ? inputs.msin.basename : inputs.msout_name)'
+  - id: instrument_table
+    doc: Output Measurement Set
+    type: Directory?
+    outputBinding:
+      glob: 'demix_solutions'
   - id: flagged_fraction_dict
     type: string
     outputBinding:
diff --git a/steps/dp3_prep_target.cwl b/steps/dp3_prep_target.cwl
index 3d080d08..beb20815 100755
--- a/steps/dp3_prep_target.cwl
+++ b/steps/dp3_prep_target.cwl
@@ -154,6 +154,11 @@ outputs:
     type: Directory
     outputBinding:
       glob: '$(inputs.msout_name=="." ? inputs.msin.basename : inputs.msout_name)'
+  - id: instrument_table
+    doc: Output Measurement Set
+    type: Directory?
+    outputBinding:
+      glob: 'demix_solutions'
   - id: flagged_fraction_dict_initial
     type: string
     outputBinding:
diff --git a/steps/plot_demix.cwl b/steps/plot_demix.cwl
new file mode 100644
index 00000000..efd9a3b6
--- /dev/null
+++ b/steps/plot_demix.cwl
@@ -0,0 +1,36 @@
+class: CommandLineTool
+cwlVersion: v1.2
+id: plot_demix
+baseCommand:
+  - plot_demixing_solutions.py
+inputs:
+  - id: instrument_tables
+    type: Directory[]?
+    inputBinding:
+      position: 0
+      prefix: '--instrument_tables'
+      itemSeparator: ','
+      valueFrom: "[$(self.map(function(directory){ return directory.path; }).join(','))]"
+  - id: clip_solutions
+    type: float?
+
+label: plot_demix_solutions
+
+outputs:
+  - id: demix_images
+    doc: Output image
+    type: File[]
+    outputBinding:
+      glob: 'demix_solutions_*.png'
+  - id: logfile
+    type: File[]
+    outputBinding:
+      glob: 'plot_demix_solutions*.log'
+
+hints:
+  - class: DockerRequirement
+    dockerPull: 'astronrd/linc'
+requirements:
+  - class: InlineJavascriptRequirement
+stdout: plot_demix_solutions.log
+stderr: plot_demix_solutions_err.log
diff --git a/workflows/linc_calibrator.cwl b/workflows/linc_calibrator.cwl
index 2698eb3f..481523a1 100644
--- a/workflows/linc_calibrator.cwl
+++ b/workflows/linc_calibrator.cwl
@@ -138,7 +138,7 @@ outputs:
   - id: inspection
     linkMerge: merge_flattened
     outputSource:
-      - prep/check_Ateam_separation.png
+      - prep/inspection
       - pa/inspection
       - fr/inspection
       - bp/inspection
@@ -235,7 +235,7 @@ steps:
       - id: outh5parm
       - id: logfiles
       - id: outh5parm_logfile
-      - id: check_Ateam_separation.png
+      - id: inspection
       - id: check_Ateam_separation.json
       - id: flagged_fraction_dict_refant
       - id: flagged_fraction_dict
diff --git a/workflows/linc_calibrator/dp3_prep_cal.cwl b/workflows/linc_calibrator/dp3_prep_cal.cwl
index 5c1f5bf5..c1e7c4fc 100644
--- a/workflows/linc_calibrator/dp3_prep_cal.cwl
+++ b/workflows/linc_calibrator/dp3_prep_cal.cwl
@@ -76,6 +76,10 @@ outputs:
     outputSource:
       - dp3_execute/flagged_fraction_dict
     type: string
+  - id: instrument_tables
+    outputSource:
+      - dp3_execute/instrument_table
+    type: Directory?
 steps:
   - id: define_parset
     in:
@@ -146,6 +150,7 @@ steps:
         source: skymodel
     out:
       - id: msout
+      - id: instrument_table
       - id: flagged_fraction_dict
       - id: logfile
     run: ../../steps/dp3_prep_cal.cwl
diff --git a/workflows/linc_calibrator/prep.cwl b/workflows/linc_calibrator/prep.cwl
index 6f18f41c..c7b1cb9d 100644
--- a/workflows/linc_calibrator/prep.cwl
+++ b/workflows/linc_calibrator/prep.cwl
@@ -88,10 +88,13 @@ outputs:
     outputSource:
       - h5parm_collector/outh5parm
     type: File
-  - id: check_Ateam_separation.png
+  - id: inspection
     outputSource:
       - check_ateam_separation/output_imag
-    type: File
+      - plot_demix/demix_images
+    type: File[]
+    linkMerge: merge_flattened
+    pickValue: all_non_null
   - id: msout
     outputSource: predict_calibrate/msout
     type: Directory[]
@@ -126,8 +129,10 @@ outputs:
       - check_ateam_separation/logfile
       - check_demix/logfile
       - aoflag/logfile
+      - concat_logfiles_plot_demix/output
     type: File[]
     linkMerge: merge_flattened
+    pickValue: all_non_null
 steps:
   - id: select
     in:
@@ -193,6 +198,7 @@ steps:
         source: lbfgs_robustdof
     out:
       - id: msout
+      - id: instrument_tables
       - id: flagged_fraction_dict
       - id: logfile
     run: ./dp3_prep_cal.cwl
@@ -277,6 +283,18 @@ steps:
       - id: logfile
     run: ../../steps/check_demix.cwl
     label: check_demix
+  - id: plot_demix
+    in:
+      - id: instrument_tables
+        source: dp3_prep_cal/instrument_tables
+      - id: demix
+        source: check_demix/out_demix
+    out:
+      - id: demix_images
+      - id: logfile
+    run: ../../steps/plot_demix.cwl
+    label: plot_demix_solutions
+    when: $(inputs.demix)
   - id: find_skymodel_cal
     in:
       - id: msin
@@ -361,6 +379,21 @@ steps:
       - id: output
     run: ../../steps/concatenate_files.cwl
     label: concat_logfiles_calib
+  - id: concat_logfiles_plot_demix
+    in:
+      - id: file_list
+        source:
+          - plot_demix/logfile
+        pickValue: all_non_null
+      - id: file_prefix
+        default: plot_demix_solutions
+      - id: demix
+        source: check_demix/out_demix
+    out:
+      - id: output
+    run: ../../steps/concatenate_files.cwl
+    label: concat_logfiles_plot_demix
+    when: $(inputs.demix)
   - id: h5parm_collector
     in:
       - id: h5parmFiles
diff --git a/workflows/linc_target/dp3_prep_targ.cwl b/workflows/linc_target/dp3_prep_targ.cwl
index 778520f7..8a61bba6 100644
--- a/workflows/linc_target/dp3_prep_targ.cwl
+++ b/workflows/linc_target/dp3_prep_targ.cwl
@@ -126,6 +126,10 @@ outputs:
       - Ateamclipper/output
     type: File
     pickValue: all_non_null
+  - id: instrument_tables
+    outputSource:
+      - dp3_execute/instrument_table
+    type: Directory?
 steps:
   - id: define_parset
     in:
@@ -253,6 +257,7 @@ steps:
         source: skymodel
     out:
       - id: msout
+      - id: instrument_table
       - id: flagged_fraction_dict_initial
       - id: flagged_fraction_dict_prep
       - id: logfile
diff --git a/workflows/linc_target/prep.cwl b/workflows/linc_target/prep.cwl
index 72ba3799..abcd945c 100644
--- a/workflows/linc_target/prep.cwl
+++ b/workflows/linc_target/prep.cwl
@@ -145,6 +145,7 @@ outputs:
       - check_ateam_separation/output_imag
       - losoto_plot_RM/output_plots
       - plot_Ateamclipper/output_imag
+      - plot_demix/demix_images
     type: File[]
     linkMerge: merge_flattened
     pickValue: all_non_null
@@ -168,6 +169,7 @@ outputs:
       - concat_logfiles_prep_targ/output
       - concat_logfiles_predict_targ/output
       - concat_logfiles_clipper_targ/output
+      - concat_logfiles_plot_demix/output
     type: File[]
     linkMerge: merge_flattened
     pickValue: all_non_null
@@ -478,6 +480,7 @@ steps:
       - id: clipper_logfile
       - id: msout
       - id: clipper_output
+      - id: instrument_tables
     run: ./dp3_prep_targ.cwl
     label: dp3_prep_target
     scatter:
@@ -493,6 +496,18 @@ steps:
     run: ../../steps/plot_Ateamclipper.cwl
     when: $(inputs.execute)
     label: concat_logfiles_clipper_output
+  - id: plot_demix
+    in:
+      - id: instrument_tables
+        source: dp3_prep_target/instrument_tables
+      - id: demix
+        source: check_demix/out_demix
+    out:
+      - id: demix_images
+      - id: logfile
+    run: ../../steps/plot_demix.cwl
+    label: plot_demix_solutions
+    when: $(inputs.demix)
   - id: concat_logfiles_clipper_output
     in:
       - id: file_list
@@ -590,6 +605,21 @@ steps:
       - id: output
     run: ../../steps/concatenate_files.cwl
     label: concat_logfiles_skymodels
+  - id: concat_logfiles_plot_demix
+    in:
+      - id: file_list
+        source:
+          - plot_demix/logfile
+        pickValue: all_non_null
+      - id: file_prefix
+        default: plot_demix_solutions
+      - id: demix
+        source: check_demix/out_demix
+    out:
+      - id: output
+    run: ../../steps/concatenate_files.cwl
+    label: concat_logfiles_plot_demix
+    when: $(inputs.demix)
 requirements:
   - class: SubworkflowFeatureRequirement
   - class: ScatterFeatureRequirement
-- 
GitLab