#!/usr/bin/env python3

from os.path import basename
from glob import glob
import sys
import os
import casacore.tables as pt
import numpy as np

import subprocess

def extract_median_autocorrelations(input_directory):
    """
    Extract median IQUV of autocorrelations from all measurement sets saved in input_directory.
    Saves in a casacore table in a new subdirectory of the working directory.

    One ``taql`` process is started per ``*.MS`` found directly inside
    ``input_directory``; all of them run concurrently and are waited for
    before returning. Each process writes ``<name>.tab`` (the measurement
    set's basename with the ``.MS`` suffix swapped) into the new subdirectory.

    The subdirectory is named ``<obsname>_tmp``, where ``obsname`` is the part
    of the first (sorted) measurement set's basename before the first ``_``.

    Returns the name of the subdirectory.

    Raises:
        ValueError: if input_directory contains no ``*.MS`` entries.
        FileExistsError: if the output subdirectory already exists.
        RuntimeError: if any taql process exits with a non-zero status.
    """
    input_ms_list = sorted(glob(f"{input_directory}/*.MS"))
    if not input_ms_list:
        # Fail with a clear message instead of an IndexError on [0] below.
        raise ValueError(
            f"No measurement sets (*.MS) found in {input_directory!r}"
        )

    obsname = basename(input_ms_list[0]).split("_")[0]
    tmp_directory = obsname + "_tmp"

    # Deliberately strict: refuse to reuse an existing directory so stale
    # tables from an earlier run can never be mixed into the results.
    os.mkdir(tmp_directory)

    taql_processes = []
    for input_ms in input_ms_list:
        # Swap only the trailing extension; str.replace would also rewrite
        # any ".MS" occurring earlier in the file name.
        table_name = os.path.splitext(basename(input_ms))[0] + ".tab"
        output_table = tmp_directory + "/" + table_name
        taql_cmd = (
            "select medians(abs(mscal.stokes(DATA, 'IQUV')), 0) as IQUV "
            f"from {input_ms} where ANTENNA1==ANTENNA2 "
            f"giving {output_table} as plain"
        )
        # Argument list (no shell), so the query is passed as a single argv entry.
        taql_processes.append(
            subprocess.Popen(["taql", "-m", "0", "-nopr", taql_cmd])
        )

    # Wait for every process before reporting, so that no child is left
    # running when an error is raised.
    failed = [
        input_ms
        for input_ms, taql_process in zip(input_ms_list, taql_processes)
        if taql_process.wait() != 0
    ]
    if failed:
        raise RuntimeError(
            f"taql failed for {len(failed)} of {len(input_ms_list)} "
            f"measurement sets: {', '.join(failed)}"
        )

    print("All taql done")
    return tmp_directory

def combine_bands(tmp_directory, input_directory):
    """
    Stack the IQUV columns of all tables in tmp_directory into one array.

    The ``*.tab`` tables in ``tmp_directory`` are read in sorted name order
    and their ``IQUV`` column is copied into slice ``i`` of the result.
    The number of tables must equal the number of ``*.MS`` entries in
    ``input_directory``, and every table must have the same number of rows
    as the first one.

    Returns a float array of shape (number of tables, rows per table, 4).
    NOTE(review): each row presumably holds the four Stokes medians of one
    autocorrelation row of the measurement set -- confirm against the TaQL
    query that produced the tables.

    Raises:
        ValueError: if no tables are found, if the table count differs from
            the measurement-set count, or if a table's IQUV column does not
            have the expected (rows, 4) shape.
    """
    input_ms_list = sorted(glob(f"{input_directory}/*.MS"))
    tmp_table_list = sorted(glob(f"{tmp_directory}/*.tab"))

    if not tmp_table_list:
        raise ValueError(f"No tables (*.tab) found in {tmp_directory!r}")

    # Explicit check rather than assert: asserts are stripped under -O.
    if len(input_ms_list) != len(tmp_table_list):
        raise ValueError(
            f"Found {len(input_ms_list)} measurement sets in "
            f"{input_directory!r} but {len(tmp_table_list)} tables in "
            f"{tmp_directory!r}"
        )

    # len() of a casacore table is its row count; close it once measured.
    with pt.table(tmp_table_list[0], ack=False) as first_table:
        num_rows = len(first_table)
    num_subbands = len(tmp_table_list)

    total_data = np.zeros([num_subbands, num_rows, 4])

    for i, tmp_table in enumerate(tmp_table_list):
        with pt.table(tmp_table, ack=False) as table:
            iquv = table.getcol("IQUV")
        # Guard against silent broadcasting (e.g. a one-row table would
        # otherwise be replicated over the whole slice).
        if iquv.shape != total_data[i].shape:
            raise ValueError(
                f"{tmp_table}: IQUV has shape {iquv.shape}, "
                f"expected {total_data[i].shape}"
            )
        total_data[i] = iquv

    return total_data

if __name__ == "__main__":
    # The first command-line argument is the directory holding the *.MS
    # inputs; any further arguments are ignored.
    if len(sys.argv) < 2:
        # sys.exit with a string prints it to stderr and exits with status 1.
        sys.exit(f"Usage: {sys.argv[0]} <input_directory>")

    input_dir = sys.argv[1]
    tmp_dir = extract_median_autocorrelations(input_dir)
    combine_bands(tmp_dir, input_dir)