Skip to content
Snippets Groups Projects
Code owners
Assign users and groups as approvers for specific file changes. Learn more.
holography_measurementset.py 13.27 KiB
import os
import re
from enum import Enum

import numpy
from astropy.time import Time
from casacore.tables import table as MS_Table
from lofar.calibration.common.coordinates import pqr_from_icrs

from .holography_dataset_definitions import *


class CASA_POLARIZATION_INDEX(Enum):
    """
    Indices of the polarization axes used when reading LOFAR measurement sets.

    XX, XY, YX and YY index the polarization axis of the correlation products
    (DATA column); X and Y index the per-polarization axis of the ELEMENT_FLAG
    cell of LOFAR_ANTENNA_FIELD. Members that share a value are enum aliases,
    so use ``.value`` when indexing arrays.
    """
    XX = 0
    XY = 1
    YX = 2
    YY = 3
    # Single-polarization indices. X is an alias of XX (0) and Y an alias of XY (1);
    # Y previously had value 0, which made it re-read the X column.
    X = 0
    Y = 1


class HolographyMeasurementSet(object):
    """
    Read-only access to a LOFAR holography measurement set stored on disk.

    The measurement set name must match ``ms_name_pattern``. The SAS id and the
    sub band id encoded in the name are parsed at construction time and exposed
    as ``sas_id`` (str) and ``beamlet`` (int). Every ``get_*_table`` accessor
    opens the corresponding casacore table read-only; the caller owns the
    returned table and must close it.
    """
    ms_name_pattern = r'L(?P<sas_id>\d{6})_SB(?P<sub_band_id>\d{3})_uv\.MS'

    def __init__(self, ms_name, ms_path):
        """
        :param ms_name: file name of the measurement set, e.g. ``L123456_SB042_uv.MS``
        :type ms_name: str
        :param ms_path: directory that contains the measurement set
        :type ms_path: str
        :raises ValueError: if ``ms_name`` does not match ``ms_name_pattern``
        """
        self.path = os.path.join(ms_path, ms_name)

        if not HolographyMeasurementSet.is_a_valid_ms_name(ms_name):
            raise ValueError('The measurement set located in %s has not a valid name' % self.path)

        self.name = ms_name
        self.sas_id, self.beamlet = \
            HolographyMeasurementSet.parse_sas_id_and_sub_band_from_ms_name(self.name)

    def _open_table(self, sub_table_name=None):
        """
        Open the main table of the measurement set, or one of its sub tables, read-only.

        :param sub_table_name: name of the sub table (e.g. 'ANTENNA'); None for the main table
        :type sub_table_name: str or None
        :return: the opened casacore table; the caller is responsible for closing it
        """
        if sub_table_name is None:
            table_path = self.path
        else:
            table_path = os.path.join(self.path, sub_table_name)
        return MS_Table(table_path, ack=False, readonly=True)

    def get_frequency(self):
        """
        Return the reference frequency of the first spectral window.

        :return: first value of the REF_FREQUENCY column of the SPECTRAL_WINDOW table
        """
        spectral_window_table = self.get_spectral_window_table()
        try:
            return spectral_window_table.getcol('REF_FREQUENCY')[0]
        finally:
            spectral_window_table.close()

    def get_subband(self):
        """
        Return the sub band associated to this measurement set.

        The sub band is derived from the observation centre frequency and the
        clock frequency stored in the OBSERVATION table.

        :return: sub band number
        :rtype: int
        """
        observation_table = self.get_observation_table()
        try:
            clock = observation_table.getcol('LOFAR_CLOCK_FREQUENCY')[0]
            central_frequency = observation_table.getcol('LOFAR_OBSERVATION_FREQUENCY_CENTER')[0]
        finally:
            observation_table.close()

        bit_sampling = 1024
        return int(round(central_frequency / clock * bit_sampling) % 512)

    def get_data_table(self):
        """Open the main (visibility) table of the measurement set."""
        return self._open_table()

    def get_pointing_table(self):
        """Open the POINTING sub table."""
        return self._open_table('POINTING')

    def get_antenna_table(self):
        """Open the ANTENNA sub table."""
        return self._open_table('ANTENNA')

    def get_spectral_window_table(self):
        """Open the SPECTRAL_WINDOW sub table."""
        return self._open_table('SPECTRAL_WINDOW')

    def get_lofar_antenna_field_table(self):
        """Open the LOFAR_ANTENNA_FIELD sub table."""
        return self._open_table('LOFAR_ANTENNA_FIELD')

    def get_station_position_tile_offsets_and_axes_coordinate_for_station_name(self, station_name):
        """
        Read the position, the offsets of the used tiles and the coordinate axes of a station.

        :param station_name: station name as stored in the NAME column of the ANTENNA table
        :type station_name: str
        :return: the POSITION, the ELEMENT_OFFSET rows of the tiles whose X and Y
                 polarizations are both unflagged, and the COORDINATE_AXES of the station,
                 as stored in LOFAR_ANTENNA_FIELD
        :rtype: numpy.ndarray, numpy.ndarray, numpy.ndarray
        """
        antenna_table = self.get_antenna_table()
        try:
            # NOTE(review): assumes a station's row in ANTENNA is also its row in
            # LOFAR_ANTENNA_FIELD (one antenna field per station) -- confirm.
            station_name_index = antenna_table.index('NAME').rownr(station_name)
        finally:
            antenna_table.close()

        # Opened only after ANTENNA is closed, so a failure here cannot leak that table.
        antenna_field_table = self.get_lofar_antenna_field_table()
        try:
            station_position = antenna_field_table.getcell('POSITION', station_name_index)
            tile_offsets = antenna_field_table.getcell('ELEMENT_OFFSET', station_name_index)
            tiles_not_used = antenna_field_table.getcell('ELEMENT_FLAG', station_name_index)
            axes_coordinate = antenna_field_table.getcell('COORDINATE_AXES', station_name_index)
        finally:
            antenna_field_table.close()

        # A tile is used only when neither its X nor its Y polarization is flagged.
        tile_used = numpy.logical_not(
            tiles_not_used[:, CASA_POLARIZATION_INDEX.X.value] |
            tiles_not_used[:, CASA_POLARIZATION_INDEX.Y.value])
        tile_offsets = tile_offsets[tile_used, :]

        return station_position, tile_offsets, axes_coordinate

    def __extract_source_name_from_pointing(self):
        """
        Return the single source name listed in the NAME column of the POINTING table.

        :raises ValueError: if the POINTING table does not contain exactly one distinct name
        """
        pointing_table = self.get_pointing_table()
        try:
            unique_names = set(pointing_table.getcol('NAME'))
        finally:
            pointing_table.close()

        if len(unique_names) != 1:
            raise ValueError('Expected only a source as a target for the observation')
        return unique_names.pop()

    def get_observation_table(self):
        """Open the OBSERVATION sub table."""
        return self._open_table('OBSERVATION')

    def get_start_end_observation(self):
        """
        Return the start and end time of the observation.

        :return: the two values of the first TIME_RANGE entry of the OBSERVATION table
        """
        observation_table = self.get_observation_table()
        try:
            start_time, end_time = observation_table.getcol('TIME_RANGE')[0]
        finally:
            observation_table.close()

        return start_time, end_time

    def get_source_name(self):
        """
        Return the name of the observed source.

        :raises ValueError: if the POINTING table does not name exactly one source
        """
        return self.__extract_source_name_from_pointing()

    def read_cross_correlation_time_flags_per_station_names(self, target_station,
                                                            reference_stations):
        """
        Read the cross correlations between a target station and a list of reference stations.

        :param target_station: name of the target station
        :type target_station: str
        :param reference_stations: list of reference station names to extract
        :type reference_stations: list[str]
        :return: the reference station names, in the order of the first axis of the
                 returned array (baselines sorted by ANTENNA1, ANTENNA2), and an array of
                 shape (n_baselines, n_timestamps) with dtype HDS_data_sample_type
                 holding time, XX/XY/YX/YY and flag for every sample
        :rtype: list[str], numpy.ndarray
        :raises ValueError: if the timestamps do not all hold the same number of baselines
        """
        baseline_selection = ','.join(reference_stations) + '&' + target_station
        # Build the TaQL condition explicitly instead of relying on python-casacore's
        # implicit ``$name`` substitution, which reads the caller's local variables.
        taql_query = 'mscal.baseline("{}")'.format(baseline_selection)

        data_table = self.get_data_table()
        try:
            selection = data_table.query(taql_query, columns='TIME,'
                                                             'FLAG_ROW,'
                                                             'DATA,'
                                                             'ANTENNA1,'
                                                             'ANTENNA2,'
                                                             'mscal.ant1name() as ANTENNA_NAME1,'
                                                             'mscal.ant2name() as ANTENNA_NAME2',
                                         sortlist='TIME, ANTENNA1, ANTENNA2')
            try:
                times = selection.getcol('TIME')
                flags = selection.getcol('FLAG_ROW')
                crosscorrelations = numpy.squeeze(selection.getcol('DATA'))
                antenna1 = selection.getcol('ANTENNA_NAME1')
                antenna2 = selection.getcol('ANTENNA_NAME2')
            finally:
                selection.close()
        finally:
            data_table.close()

        # Rows are sorted by TIME, so every timestamp is followed by its baselines. Derive
        # the number of baselines from the data rather than from the request: a station
        # absent from the MS, a duplicate, or the target listed among the references
        # would otherwise misalign the reshape below.
        timestamps, rows_per_timestamp = numpy.unique(times, return_counts=True)
        n_timestamps = len(timestamps)
        if n_timestamps == 0:
            n_baselines = len(reference_stations)
        else:
            n_baselines = int(rows_per_timestamp[0])
            if numpy.any(rows_per_timestamp != n_baselines):
                raise ValueError('Expected the same number of baselines for every timestamp '
                                 'selecting %s' % baseline_selection)

        reference_station_names = [a2 if a1 == target_station else a1
                                   for a1, a2 in zip(antenna1[:n_baselines],
                                                     antenna2[:n_baselines])]

        n_polarizations = crosscorrelations.shape[-1]
        # Fortran order: row r = baseline + n_baselines * timestamp_index.
        crosscorrelations = crosscorrelations.reshape([n_baselines,
                                                       n_timestamps,
                                                       n_polarizations], order='F')
        flags = flags.reshape([n_baselines, n_timestamps], order='F')

        beams_crosscorrelations_array = numpy.full([n_baselines, n_timestamps],
                                                   fill_value=-999999.9,
                                                   dtype=HDS_data_sample_type)

        # The time axis is shared by all baselines and broadcasts over the first axis.
        beams_crosscorrelations_array['t'] = timestamps
        for polarization in ('XX', 'XY', 'YX', 'YY'):
            polarization_index = CASA_POLARIZATION_INDEX[polarization].value
            beams_crosscorrelations_array[polarization] = \
                crosscorrelations[:, :, polarization_index]
        beams_crosscorrelations_array['flag'] = flags

        return reference_station_names, beams_crosscorrelations_array

    def read_cross_correlation_time_flags_lm_per_station_name(self, target_station,
                                                              reference_stations,
                                                              pointing,
                                                              rotation_matrix):
        """
        Read the cross correlations for a target station and a list of reference stations
        and fill in the l and m of the target station for the given pointing.

        :param target_station: name of the target station
        :type target_station: str
        :param reference_stations: list of reference station names to extract
        :type reference_stations: list[str]
        :param pointing: ra, dec and epoch of the pointing as a numpy array
                         (unpacked with ``tolist()`` into three values)
        :type pointing: numpy.ndarray
        :param rotation_matrix: matrix to rotate the station frame to the earth frame
        :type rotation_matrix: numpy.ndarray
        :return: the reference station names extracted and the beam crosscorrelation array
        :rtype: list[str], numpy.ndarray
        """
        reference_station_names, beam_crosscorrelations_array = \
            self.read_cross_correlation_time_flags_per_station_names(target_station,
                                                                     reference_stations)

        timestamps = beam_crosscorrelations_array[0, :]['t']
        lm_for_target_station = HolographyMeasurementSet._compute_lm_from_ra_dec_station_position_rotation_matrix_and_time(
            pointing, rotation_matrix, timestamps)

        # l and m depend only on time, so the same values apply to every baseline row.
        beam_crosscorrelations_array['l'][:] = lm_for_target_station['l']
        beam_crosscorrelations_array['m'][:] = lm_for_target_station['m']

        return reference_station_names, beam_crosscorrelations_array

    @staticmethod
    def __mjd_to_astropy_time(mjd_time_seconds):
        """
        Convert a modified julian date expressed in seconds into an astropy Time.

        :param mjd_time_seconds: modified julian date in seconds
        :return: the corresponding time, in the UTC scale
        :rtype: astropy.time.Time
        """
        day_in_seconds = 24 * 60 * 60
        return Time(mjd_time_seconds / day_in_seconds, format='mjd', scale='utc')

    @staticmethod
    def _compute_lm_from_ra_dec_station_position_rotation_matrix_and_time(ra_dec_epoch,
                                                                          rotation_matrix,
                                                                          mjd_times):
        """
        Compute l and m of a direction for a series of times.

        :param ra_dec_epoch: ra, dec and epoch as a numpy array; only ra and dec are used
        :type ra_dec_epoch: numpy.ndarray
        :param rotation_matrix: matrix passed unchanged to ``pqr_from_icrs``
        :param mjd_times: modified julian dates in seconds
        :return: structured array with float64 fields 'l' and 'm', one entry per time
        :rtype: numpy.ndarray
        :raises TypeError: if ``ra_dec_epoch`` is not a numpy array
        """
        if not isinstance(ra_dec_epoch, numpy.ndarray):
            raise TypeError('Expected a structured numpy array for ra_dec obtained {}'.
                            format(ra_dec_epoch))

        ra, dec, _ = ra_dec_epoch.tolist()
        astropy_times = [HolographyMeasurementSet.__mjd_to_astropy_time(mjd_time)
                         for mjd_time in mjd_times]

        l_m_arrays = pqr_from_icrs(numpy.array((ra, dec)), astropy_times, rotation_matrix)

        lm = numpy.empty(len(astropy_times), dtype=[('l', numpy.float64),
                                                    ('m', numpy.float64)])
        lm['l'] = l_m_arrays[:, 0]
        lm['m'] = l_m_arrays[:, 1]
        return lm

    def __repr__(self):
        return 'MeasurementSet(%d) located in %s for sas_id %s and sub_band_id %d' % (id(self),
                                                                                      self.path,
                                                                                      self.sas_id,
                                                                                      self.beamlet)

    @staticmethod
    def parse_sas_id_and_sub_band_from_ms_name(ms_name):
        """
        Extract the SAS id and the sub band id from a measurement set name.

        Surrounding whitespace is ignored, consistently with ``is_a_valid_ms_name``.

        :param ms_name: measurement set name, e.g. ``L123456_SB042_uv.MS``
        :type ms_name: str
        :return: the SAS id as a string and the sub band id as an int
        :rtype: (str, int)
        :raises ValueError: if the name does not match ``ms_name_pattern``
        """
        match = re.match(HolographyMeasurementSet.ms_name_pattern, ms_name.strip())
        if match is None:
            raise ValueError('The measurement set %s has not a valid name' % ms_name)
        return str(match.group('sas_id')), int(match.group('sub_band_id'))

    @staticmethod
    def is_a_valid_ms_name(ms_name):
        """
        Check whether a name (ignoring surrounding whitespace) starts with ``ms_name_pattern``.

        :return: the ``re.Match`` object when valid (truthy), None otherwise
        """
        return re.match(HolographyMeasurementSet.ms_name_pattern, ms_name.strip())

    @staticmethod
    def filter_valid_ms_names(list_of_ms_names):
        """Return, in their original order, the names accepted by ``is_a_valid_ms_name``."""
        return list(filter(HolographyMeasurementSet.is_a_valid_ms_name, list_of_ms_names))