# classes.py
import re
import os
from collections import defaultdict
import datetime
from casacore.tables import table as MS_Table


class HolographySpecification(object):
    """
    Holography specification file: metadata from its name plus its parsed content.

    The file name is expected to look like ``Holog-<YYYYMMDD>-<comment>-<NNN>.txt``.
    The content is one header line (start date, end date, rcu mode and beam switch
    delay) followed by one line per station specification. Nothing is read from
    disk until :meth:`read_file` is called.
    """

    # The dot before ``txt`` is escaped and the pattern is anchored at the end so
    # that names such as ``...-001Xtxt`` or ``...-001.txt.bak`` are rejected.
    hs_name_pattern = r'Holog-(?P<date>\d{8})-(?P<comment>.*)-(?P<id>\d{3})\.txt$'

    # The patterns are compiled once here instead of being rebuilt for every line.
    _date_regex = r'\d{4}-\d{2}-\d{2}\s\d{2}:\d{2}:\d{2}'
    _header_regex = re.compile(
        (r'(?P<start_date>{date_regex})\s*'
         r'(?P<end_date>{date_regex})\s*'
         r'(?P<rcu_mode>\d*)\s*'
         # integer or decimal number; the decimal separator is a literal dot
         r'(?P<beam_switch_delay>\d*\.?\d*)').format(date_regex=_date_regex))

    _range_regex = r'(\d*\:\d*)|(\d*)'
    _ra_dec_regex = r'\d*\.\d*,-?\d*\.\d*,\w*'
    _line_regex = re.compile(
        (r'^(?P<station_name>\w*)\s*'
         r'(?P<mode_description>\w*)\s*'
         r'(?P<sub_band>[\d,]*)\s*'
         r'(?P<beamlets>{range_regex})\s*'
         r'(?P<rcus>{range_regex})\s*'
         r'(?P<rcus_mode>(\d*))\s*'
         r'(?P<virtual_pointing>{ra_dec_regex})\s*'
         r'(?P<station_pointing>{ra_dec_regex})').format(range_regex=_range_regex,
                                                        ra_dec_regex=_ra_dec_regex))

    def __init__(self, name, path):
        """
        Extracts id, date and comment from the file name; the file is not opened here.

        :param name: file name of the holography specification
        :param path: directory that contains the file
        :raises ValueError: if the name is not a valid holography specification name
        """
        self.path = os.path.join(path, name)
        self.name = name
        self.id, self.date, self.comment = HolographySpecification.\
            extract_id_date_comment_from_name(name)

        # station name -> list of parsed specification lines of that station
        self.station_specification_map = defaultdict(list)
        # header fields, kept as the strings found in the file
        self.start_datetime = None
        self.end_datetime = None
        self.rcu_mode = None
        self.beam_set_interval = None
        # filled in by read_file(); defined here so they always exist
        self.station_names = []
        self.reference_station_names = []
        self.target_station_names = []

    def __repr__(self):
        return 'HolographySpecification(%s, %s, %s, %s, %s)' % (
            self.id,
            self.date,
            self.comment,
            self.name,
            self.path,
        )

    def _read_lines(self):
        """
        Reads the whole file located at self.path.

        :return: list of the lines of the file without line terminators
        """
        with open(self.path, 'r') as fstream_in:
            return fstream_in.read().splitlines()

    @staticmethod
    def _split_header(line):
        """
        Splits the header line into its fields.

        :param line: header line of the specification file
        :return: dict with the keys start_date, end_date, rcu_mode and
                 beam_switch_delay; all the values are strings
        :raises ValueError: if the line does not start with two date times
        """
        match = HolographySpecification._header_regex.match(line)
        if match is None:
            raise ValueError('Cannot parse header {}'.format(line))
        return match.groupdict()

    def _parse_header(self, header):
        """
        Stores the header fields in the instance attributes (as strings).

        :param header: header line of the specification file
        """
        splitted_header = HolographySpecification._split_header(header)
        self.start_datetime = splitted_header['start_date']
        self.end_datetime = splitted_header['end_date']
        self.rcu_mode = splitted_header['rcu_mode']
        self.beam_set_interval = splitted_header['beam_switch_delay']

    @staticmethod
    def _split_line(line):
        """
        Splits a station specification line into its fields.

        :param line: one line following the header
        :return: dict with the keys station_name, mode_description, sub_band,
                 beamlets, rcus, rcus_mode, virtual_pointing and station_pointing;
                 all the values are strings
        :raises ValueError: if the line does not match the expected layout
        """
        match = HolographySpecification._line_regex.match(line)
        if match is None:
            raise ValueError('Cannot parse line {}'.format(line))
        return match.groupdict()

    @staticmethod
    def _split_lines(lines):
        """
        Applies :meth:`_split_line` to every line.

        :param lines: iterable of station specification lines
        :return: list with one dict of string fields per line
        """
        return [HolographySpecification._split_line(line)
                for line in lines]

    @staticmethod
    def _parse_pointing(pointing_string):
        """
        Converts a ``<ra>,<dec>,<coordinate system>`` string into a dict.

        :param pointing_string: comma separated pointing
        :return: dict with ra and dec as floats and coordinate_system as string
        """
        ra, dec, coordinate_system = pointing_string.split(',')
        return dict(ra=float(ra), dec=float(dec), coordinate_system=coordinate_system)

    @staticmethod
    def _parse_line(splitted_line):
        """
        Converts the string fields of a splitted line into typed values.

        :param splitted_line: dict as returned by :meth:`_split_line`
        :return: dict with the keys station_name, rcus_mode (int), sub_band_ids
                 (list of int), mode_description, rcus_involved, beamlets,
                 station_pointing, virtual_pointing (dicts) and station_type
        """
        sub_band_ids = [int(sub_band) for sub_band in splitted_line['sub_band'].split(',')]
        # a line with a single sub band is labelled 'target', any other 'reference'
        station_type = 'target' if len(sub_band_ids) == 1 else 'reference'

        return dict(station_name=splitted_line['station_name'],
                    rcus_mode=int(splitted_line['rcus_mode']),
                    sub_band_ids=sub_band_ids,
                    mode_description=splitted_line['mode_description'],
                    rcus_involved=splitted_line['rcus'],
                    beamlets=splitted_line['beamlets'],
                    station_pointing=HolographySpecification._parse_pointing(
                        splitted_line['station_pointing']),
                    virtual_pointing=HolographySpecification._parse_pointing(
                        splitted_line['virtual_pointing']),
                    station_type=station_type)

    def _parse_lines(self, lines):
        """
        Parses the station specification lines and groups them by station name.

        :param lines: iterable of the lines following the header
        """
        for splitted_line in HolographySpecification._split_lines(lines):
            parsed_line = HolographySpecification._parse_line(splitted_line)
            self.station_specification_map[parsed_line['station_name']].append(parsed_line)

    def _update_class_attributes(self):
        """
        Derives the station name lists from the station specification map.

        Stations with exactly one specification line are listed as reference
        stations, stations with more than one line as target stations.
        """
        specification_map = self.station_specification_map
        # a list (not a dict view) so it is a snapshot like the two lists below
        self.station_names = list(specification_map)
        self.reference_station_names = [station_name for station_name in specification_map
                                        if len(specification_map[station_name]) == 1]
        self.target_station_names = [station_name for station_name in specification_map
                                     if len(specification_map[station_name]) > 1]

    def read_file(self):
        """
        Reads and parses the specification file located at self.path.

        Blank lines are ignored. Previously parsed station specifications are
        discarded first, so calling this method more than once does not
        duplicate the entries (which would alter the reference/target split).

        :raises ValueError: if the file contains no lines or a line cannot be parsed
        """
        lines = [line for line in self._read_lines() if line.strip()]
        if not lines:
            raise ValueError('The holography specification file %s is empty' % self.path)

        self.station_specification_map.clear()
        self._parse_header(lines[0])
        self._parse_lines(lines[1:])
        self._update_class_attributes()

    @staticmethod
    def create_hs_list_from_name_list_and_path(name_list, path):
        """
        Creates one HolographySpecification per file name.

        :param name_list: iterable of holography specification file names
        :param path: directory that contains the files
        :return: list of HolographySpecification
        """
        return [HolographySpecification(name, path) for name in name_list]

    @staticmethod
    def is_holography_specification_file_name(name):
        """
        :param name: file name to check
        :return: True if the name matches hs_name_pattern, False otherwise
        """
        return re.match(HolographySpecification.hs_name_pattern, name) is not None

    @staticmethod
    def extract_id_date_comment_from_name(name):
        """
        Extracts the metadata encoded in the file name.

        :param name: holography specification file name
        :return: tuple (id as int, date as datetime.datetime, comment as string)
        :raises ValueError: if the name does not match hs_name_pattern or the
                            date is not a valid calendar date
        """
        match = re.match(HolographySpecification.hs_name_pattern, name)
        if match is None:
            raise ValueError('%s is not a valid holography specification file name' % name)
        date = datetime.datetime.strptime(match.group('date'), '%Y%m%d')
        return int(match.group('id')), date, match.group('comment')


class MeasurementSet(object):
    """
    Handle to a measurement set whose name follows ``L<sas id>_SB<sub band id>_uv.MS``.

    The sas id and the sub band id are parsed from the name. Tables are opened
    through ``MS_Table`` only when one of the reading methods is called.
    """
    ms_name_pattern = r'L(?P<sas_id>\d{6})_SB(?P<sub_band_id>\d{3})_uv\.MS'

    def __init__(self, ms_name, ms_path):
        """
        :param ms_name: name of the measurement set
        :param ms_path: directory that contains the measurement set
        :raises ValueError: if ms_name is not a valid measurement set name
        """
        self.path = os.path.join(ms_path, ms_name)

        if not MeasurementSet.is_a_valid_ms_name(ms_name):
            raise ValueError('The measurement set located in %s has not a valid name' % self.path,)

        self.name = ms_name
        # NOTE: self.beamlet holds the sub band id parsed from the name
        self.sas_id, self.beamlet = \
            MeasurementSet.parse_sas_id_and_sub_band_from_ms_name(self.name)

    def get_data_table(self):
        """
        Opens the main table of the measurement set.

        :return: the opened table; closing it is up to the caller
        """
        return MS_Table(self.path)

    def get_antenna_table(self):
        """
        Opens the ANTENNA sub table of the measurement set.

        :return: the opened table; closing it is up to the caller
        """
        return MS_Table(os.path.join(self.path, 'ANTENNA'))

    def read_cross_correlation_per_station_names(self, reference, target):
        """
        Reads the samples whose ANTENNA1 is the reference station and whose
        ANTENNA2 is the target station.

        :param reference: name of the reference station as in the NAME column
                          of the ANTENNA table
        :param target: name of the target station as in the NAME column of the
                       ANTENNA table
        :return: tuple (TIME values, DATA values) of the selected rows
        :raises KeyError: if one of the station names is not in the ANTENNA table
        """
        # Resolve the names first: an unknown station fails before the (large)
        # columns of the main table are read.
        antennas_table = self.get_antenna_table()
        try:
            antenna_names = antennas_table.getcol('NAME')
        finally:
            antennas_table.close()
        antenna_name_id_map = {name: index for index, name in enumerate(antenna_names)}
        reference_antenna_id = antenna_name_id_map[reference]
        target_antenna_id = antenna_name_id_map[target]

        data_table = self.get_data_table()
        try:
            antenna1_list = data_table.getcol('ANTENNA1')
            antenna2_list = data_table.getcol('ANTENNA2')
            timestamp = data_table.getcol('TIME')
            cross_correlation = data_table.getcol('DATA')
        finally:
            # the columns are already in memory, release the table
            data_table.close()

        # getcol returns numpy arrays for these columns, hence the row selection
        # is an element wise boolean mask instead of a python level loop
        selected_rows = (antenna1_list == reference_antenna_id) & \
                        (antenna2_list == target_antenna_id)

        return (timestamp[selected_rows], cross_correlation[selected_rows])

    def __repr__(self):
        return 'MeasurementSet(%d) located in %s for sas_id %d and sub_band_id %d' % (
            id(self),
            self.path,
            self.sas_id,
            self.beamlet)

    @staticmethod
    def create_ms_dict_from_ms_name_list_and_path(list_of_ms_names, path):
        """
        Creates a MeasurementSet for every valid name; invalid names are skipped.

        :param list_of_ms_names: iterable of candidate measurement set names
        :param path: directory that contains the measurement sets
        :return: dict of MeasurementSet indexed by the sub band id parsed from
                 the name; a later name with the same id replaces an earlier one
        """
        filtered_list_of_ms_names = MeasurementSet.filter_valid_ms_names(list_of_ms_names)
        ms_list = [MeasurementSet(ms_name, path) for ms_name in filtered_list_of_ms_names]
        return {ms.beamlet: ms for ms in ms_list}

    @staticmethod
    def parse_sas_id_and_sub_band_from_ms_name(ms_name):
        """
        :param ms_name: name of the measurement set
        :return: tuple (sas id, sub band id), both as int
        :raises ValueError: if ms_name is not a valid measurement set name
        """
        # Match the stripped name, exactly as is_a_valid_ms_name does, so that a
        # name accepted by the validator can always be parsed.
        match = re.match(MeasurementSet.ms_name_pattern, ms_name.strip())
        if match is None:
            raise ValueError('The measurement set %s has not a valid name' % ms_name,)
        return (int(match.group('sas_id')), int(match.group('sub_band_id')))

    @staticmethod
    def is_a_valid_ms_name(ms_name):
        """
        :param ms_name: name to check; surrounding white spaces are ignored
        :return: True if the name starts with the ms_name_pattern, False otherwise
        """
        return re.match(MeasurementSet.ms_name_pattern, ms_name.strip()) is not None

    @staticmethod
    def filter_valid_ms_names(list_of_ms_names):
        """
        :param list_of_ms_names: iterable of candidate measurement set names
        :return: list of the names that are valid measurement set names
        """
        return [ms_name for ms_name in list_of_ms_names
                if MeasurementSet.is_a_valid_ms_name(ms_name)]


class HolographyDataset(object):
    """
    Container of the data of an holography observation for one target station.

    Every attribute starts empty (None or an empty tuple). None of the loading
    and storing methods is implemented yet.
    """

    def __init__(self):
        self.rcu_list = ()  # array of ints
        self.mode = None  # integer
        self.sas_ids = ()  # array of strings
        self.target_station_name = None  # string
        self.target_station_position = None  # list of 3 floats
        self.source_name = None  # string
        self.source_position = None  # list of 3 floats
        self.start_time = None  # date time when the observation started in MJD (float)
        self.end_time = None  # date time when the observation ended in MJD (float)
        self.rotation_matrix = None  # array(3,3), translation matrix for
        # (RA, DEC) <-> (l, m) conversion
        self.antenna_field_position = ()  # coordinates of the antenna position in the target
        # station
        self.reference_stations = ()  # list of reference station names
        self.frequencies = ()  # list of frequencies
        self.ra_dec = ()  # array(Nfrequency, Nbeamlets, 2) contains the ra_dec of which a beam
        # points at given a frequency and a beamlet number
        self.data = ()  # array(NreferenceStations, Nfrequencies, Nbeamlets) that contains the
        # 4 polarization crosscorrelation for the 4 polarizations, the l and m coordinates, and
        # the timestamp in mjd of the sample

    def load_from_specification_and_ms(self, station_name, hb_specifications, h_measurement_set):
        """
        Loads the dataset from the specification files and the measurements for the given station
        name.

        Not implemented yet: the call currently does nothing and leaves the
        dataset unchanged.

        :param station_name: target station name
        :param hb_specifications: list of holography beam specification files
        :param h_measurement_set: map of the measurement set indexed by beamlet number
        """

    @staticmethod
    def load_from_file(path):
        """
        It reads the dataset from file and returns a HolographyDataset class
        :param path: path to file
        :return: the read dataset
        :raises NotImplementedError: always, the method is not implemented yet
        """
        raise NotImplementedError

    def store_to_file(self, path):
        """
        Stores the holography dataset at the given path
        :param path: path to file
        :raises NotImplementedError: always, the method is not implemented yet
        """
        raise NotImplementedError