Skip to content
Snippets Groups Projects
mshologextract.py 4.86 KiB
Newer Older
from classes import *
from utils import *
import numpy
import argparse
from casacore.tables import table as MS_Table


def main():
    """Command-line entry point.

    Parses the command line, loads the holography datasets found under
    ``input_path`` and writes the cross-correlations of the first target /
    reference station pair of the first beam specification file to the
    locations ``'dummy'`` and ``'dummy2'``.
    """
    cla_parser = specify_command_line_arguments()
    arguments = parse_command_line_arguments(cla_parser)
    # Forward the optional overrides instead of silently ignoring them; the
    # loader rejects them explicitly while they are unsupported.
    bsfs, mss = read_holography_datasets(arguments.input_path,
                                         holography_bsf_path=arguments.holography_bsf,
                                         holography_ms_path=arguments.holography_ms)

    if not bsfs:
        cla_parser.error('no holography beam specification files found in %s'
                         % arguments.input_path)
    if not mss:
        cla_parser.error('no holography measurement sets found in %s'
                         % arguments.input_path)

    first_bsf = bsfs[0]
    target_station = first_bsf.target_station_names[0]
    reference_station = first_bsf.reference_station_names[0]
    station_specifications = first_bsf.station_specification_map[target_station]
    # `mss` is used as a mapping (it exposes .items()). dict views are not
    # indexable on Python 3, so take the first value through an iterator.
    # This works on Python 2 as well.
    ms_per_sub_band = next(iter(mss.values()))

    # NOTE(review): both outputs receive identical data; the hard-coded
    # 'dummy' locations look like placeholders - confirm intended targets.
    for location in ('dummy', 'dummy2'):
        write_to_numpy_array(station_specifications,
                             target_station,
                             reference_station,
                             ms_per_sub_band, location)

def parse_command_line_arguments(parser):
    """Run *parser* over the process command line.

    Args:
        parser: an object exposing ``parse_args()`` (normally the
            ``argparse.ArgumentParser`` built by the script).

    Returns:
        Whatever ``parser.parse_args()`` returns (the parsed namespace).
    """
    parsed_arguments = parser.parse_args()
    return parsed_arguments


def specify_command_line_arguments():
    """Build the command-line parser for the holography extraction script.

    Returns:
        argparse.ArgumentParser accepting a positional ``input_path`` and the
        optional ``--holography_bsf`` and ``--holography_ms`` path overrides.
        Both overrides default to ``None``.
    """
    parser = argparse.ArgumentParser(description='This program is meant to convert the holography'
                                                 ' observation\'s data into a holography dataset')
    parser.add_argument('input_path', help='path to the holography observation data')
    parser.add_argument('--holography_bsf', help='override default path for the holography beam'
                                                 ' specification file')
    parser.add_argument('--holography_ms', help='override default path for the holography'
                                                ' observation MS files')
    return parser


def read_holography_datasets(holography_observation_path,
                             holography_bsf_path=None,
                             holography_ms_path=None):
    """Locate the beam specification files and MS files under a directory.

    Args:
        holography_observation_path: directory searched by
            ``list_bsf_files_in_path`` and ``list_ms_files_in_path``. Both
            helpers come from the star-imported project modules.
        holography_bsf_path: override for the beam specification file
            location. Not supported yet.
        holography_ms_path: override for the MS files location. Not
            supported yet.

    Returns:
        ``(bsf_files, ms_files)``. ``read_file()`` has already been called on
        every element of ``bsf_files``.

    Raises:
        NotImplementedError: if either override path is given. The MS
            override is checked first.
    """
    if holography_ms_path is not None:
        raise NotImplementedError('overriding the holography MS path'
                                  ' (--holography_ms) is not supported yet')
    if holography_bsf_path is not None:
        raise NotImplementedError('overriding the holography beam specification file path'
                                  ' (--holography_bsf) is not supported yet')

    list_of_bsf_files = list_bsf_files_in_path(holography_observation_path)
    list_of_ms = list_ms_files_in_path(holography_observation_path)
    # Parse every specification file before handing the objects to callers.
    for bsf in list_of_bsf_files:
        bsf.read_file()

    return list_of_bsf_files, list_of_ms


def write_to_numpy_array(bsf_specifications, target, reference, ms_per_sub_band, location):
    """Extract one station pair's cross-correlations and save them with ``numpy.savez``.

    Args:
        bsf_specifications: beam specification entries for the target station.
        target: target station name.
        reference: reference station name.
        ms_per_sub_band: mapping from beamlet number to MS object.
        location: destination passed straight to ``numpy.savez``.

    The archive holds three arrays: ``ra_dec``, ``timestamp`` and
    ``crosscorrelation``.
    """
    flat_samples = extract_crosscorrelation_per_ra_dec_timestamp(bsf_specifications,
                                                                 target,
                                                                 reference,
                                                                 ms_per_sub_band)
    pointings, timestamps, cube = create_crosscorrelation_matrix(*flat_samples)
    numpy.savez(location,
                ra_dec=numpy.array(list(pointings)),
                timestamp=timestamps,
                crosscorrelation=cube)


def extract_crosscorrelation_per_ra_dec_timestamp(bsf_specifications,
                                                  target,
                                                  reference,
                                                  ms_per_sub_band):
    """Collect the cross-correlation samples of every beam specification entry.

    For each entry, the MS selected by ``int(entry['beamlets'])`` is asked for
    the ``(timestamps, samples)`` of the target/reference station pair. The
    entry's virtual pointing contributes one ``(ra, dec)`` tuple.

    Returns:
        ``(ra_dec_list, samples, timestamps)``:

        - ``ra_dec_list`` has one ``(ra, dec)`` tuple per entry, in input order.
        - ``samples`` is a numpy array of all entries' samples, concatenated in
          input order.
        - ``timestamps`` is a plain list of the matching timestamps, in the
          same order.
    """
    pointings = []
    samples = []
    times = []
    for spec in bsf_specifications:
        ms = ms_per_sub_band[int(spec['beamlets'])]
        times_in_ms, samples_in_ms = ms.read_cross_correlation_per_station_names(target,
                                                                                  reference)
        pointing = spec['virtual_pointing']
        pointings.append((pointing['ra'], pointing['dec']))
        samples.extend(samples_in_ms)
        times.extend(times_in_ms)
    return pointings, numpy.array(samples), times


def create_crosscorrelation_matrix(ra_dec_list, crosscorrelation_list, timestamp_list):
    """Arrange flat cross-correlation samples into a (pointing, time, polarization) cube.

    Samples that share a timestamp fill consecutive pointing rows, in the
    order they appear in ``crosscorrelation_list``. Row ``i`` therefore
    corresponds to ``ra_dec_list[i]`` when the pointings' samples were
    appended in ``ra_dec_list`` order with one sample per pointing per
    timestamp.

    Args:
        ra_dec_list: one ``(ra, dec)`` entry per pointing, in row order.
        crosscorrelation_list: indexable sequence of samples. Element ``[0]``
            must be at least 2-D; its axis 1 is taken as the polarization
            count. Presumably each element is (channel, polarization) with a
            single channel - TODO confirm.
        timestamp_list: one timestamp per sample, aligned with
            ``crosscorrelation_list``.

    Returns:
        ``(ra_dec, unique_timestamp, matrix)``:

        - ``ra_dec`` lists the ``ra_dec_list`` entries in matrix-row order.
          It is an ordered list rather than a set, so it stays aligned with
          the rows.
        - ``unique_timestamp`` holds the sorted distinct timestamps, one per
          matrix column.
        - ``matrix`` is a complex64 array of shape
          ``(len(ra_dec_list), len(unique_timestamp), polarizations)``; cells
          without a sample stay zero.

    Raises:
        IndexError: if ``crosscorrelation_list`` is empty, or if more samples
            share one timestamp than there are pointings.
    """
    polarizations = crosscorrelation_list[0].shape[1]
    # One pass: sorted distinct timestamps plus each sample's column index.
    unique_timestamp, time_indexes = numpy.unique(numpy.asarray(timestamp_list),
                                                  return_inverse=True)
    shape = (len(ra_dec_list), len(unique_timestamp), polarizations)
    cross_correlation_matrix = numpy.zeros(shape, dtype=numpy.complex64)

    # Next free pointing row for each timestamp column. Walking the samples in
    # order reproduces "n-th sample with this timestamp -> row n".
    next_row = numpy.zeros(len(unique_timestamp), dtype=int)
    for sample_index, time_index in enumerate(numpy.ravel(time_indexes)):
        row = next_row[time_index]
        cross_correlation_matrix[row, time_index, :] = crosscorrelation_list[sample_index]
        next_row[time_index] += 1

    return list(ra_dec_list), unique_timestamp, cross_correlation_matrix



if __name__ == '__main__':
    # Run the extraction only when executed as a script, not when imported.
    main()