# holography_dataset.py
from .holography_specification import HolographySpecification
from .holography_dataset_definitions import *

from lofar.calibration.common.datacontainers.holography_observation import HolographyObservation
import logging
import numpy
import h5py

logger = logging.getLogger(__name__)

class HolographyDataset():
    def __init__(self):
        """
        Creates an empty holography dataset.

        Every attribute starts out empty (or None); the loader methods of this
        class fill them in afterwards.
        """
        # HDS version (float)
        self.version = HOLOGRAPHY_DATA_SET_VERSION

        # RCUs involved in the observation (list of ints)
        self.rcu_list = []

        # RCU mode (integer)
        self.mode = None

        # SAS ids of the observations (list of strings)
        self.sas_ids = []

        # name of the target station (string)
        self.target_station_name = None

        # position of the target station (list of 3 floats)
        self.target_station_position = None

        # name of the observed source (string)
        self.source_name = None

        # pointing towards the source; the loaders of this class store it as a
        # (ra, dec, coordinate system) tuple
        self.source_position = None

        # date time when the observation started, MJD in seconds (float)
        self.start_time = None

        # date time when the observation ended, MJD in seconds (float)
        self.end_time = None

        # array(3,3), matrix used for the (RA, DEC) <-> (l, m) conversion
        self.rotation_matrix = None

        # list of beamlet numbers
        self.beamlets = []

        # coordinates of the antenna positions in the target station
        self.antenna_field_position = []

        # list of reference station names
        self.reference_stations = []

        # list of frequencies
        self.frequencies = []

        # ra_dec[str(frequency)][str(beamlet)] holds the direction a beam points
        # at for that frequency and beamlet number, stored with
        # numpy.dtype([('RA', numpy.float64),
        #              ('DEC', numpy.float64),
        #              ('EPOCH', 'S10')])
        self.ra_dec = {}

        # data[reference station name][str(frequency)][str(beamlet)] holds the
        # samples recorded for that combination.  According to the original
        # notes each sample carries the four polarization cross correlations
        # (XX, YY, XY, YX), the l and m coordinates, the timestamp t in MJD and
        # a flag telling whether or not the sample has been flagged.
        self.data = {}

    @staticmethod
    def compare_dicts(dict1, dict2):
        result = True
        for key in dict1.keys():
            if key in dict2.keys():
                if isinstance(dict1[key], dict) and isinstance(dict2[key], dict):
                    result = result and HolographyDataset.compare_dicts(dict1[key], dict2[key])
                else:
                    if isinstance(dict1[key], numpy.ndarray) and isinstance(dict2[key], numpy.ndarray):
                        # Compares element by element the two arrays
                        return numpy.array_equal(dict1[key], dict2[key])
                    else:
                        return dict1[key] == dict2[key]
            else:

                return False
        return result
    
    def __eq__(self, hds=None):
        equality = False
        if hds is not None and isinstance(hds, HolographyDataset) is True:
            equality = True
            for attribute_name, attribute_value in self.__dict__.items():
                other_value = getattr(hds, attribute_name)
                this_equality = True
                try:
                    if isinstance(attribute_value, numpy.ndarray) is True and isinstance(other_value, numpy.ndarray) is True:
                        this_equality = numpy.array_equal(attribute_value, other_value)
                    elif isinstance(attribute_value, dict) is True and isinstance(other_value, dict) is True:
                        this_equality = HolographyDataset.compare_dicts(attribute_value, other_value)
                    elif attribute_value != other_value:
                        this_equality = False
                except Exception as e:
                    print("***", attribute_name, type(attribute_value), e)

                if this_equality is False:
                    try:
                        print("###", attribute_name, type(attribute_value), type(other_value))
                        if attribute_name == "ra_dec":
                            print("ra_dec:", attribute_value, other_value)
                    except:
                        print("%%%", attribute_name, type(attribute_value), type(other_value))

                try:
                    equality = equality and this_equality
                except Exception as e:
                    print("&&&", attribute_name, type(attribute_value), type(other_value))
        return equality

    def load_from_beam_specification_and_ms(self, station_name, list_of_hbs_ms_tuples):
        """
        Loads the dataset from the specification files and the measurements for the given station
        name
        :param station_name: target station name
        :param hb_specifications: a list containing (hbs, ms) per frequency

        """
        self.__collect_preliminary_information(station_name, list_of_hbs_ms_tuples)
        self.__read_data(station_name, list_of_hbs_ms_tuples)


    def __read_data(self, station_name, list_of_hbs_ms_tuples):
        """

        :param station_name:
        :param list_of_hbs_ms_tuples:
        :type list_of_hbs_ms_tuples: list[(HolographySpecification, HolographyObservation)]
        :return:
        """

        self.data = dict()
        for hbs, ho in list_of_hbs_ms_tuples:
            if station_name in hbs.target_station_names:
                frequency = ho.frequency
                frequency_string = str(frequency)
                for beamlet in self.beamlets:
                    beamlet_string = str(beamlet)
                    reference_station_names, cross_correlation =\
                        ho.ms_for_a_given_beamlet_number[beamlet].\
                        read_cross_correlation_time_flags_per_station_names(station_name,
                                                                 self.reference_stations)

                    for reference_station_index, reference_station in\
                            enumerate(reference_station_names):

                        if reference_station not in self.data:
                            self.data[reference_station] = dict()

                        if frequency_string not in self.data[reference_station]:
                            self.data[reference_station][frequency_string] = dict()

                        self.data[reference_station][frequency_string][beamlet_string] = \
                            cross_correlation[reference_station_index, :]


    def __collect_preliminary_information(self, station_name, list_of_hbs_ho_tuples):
        """
        This routines reads both the holography beam specifications files and the holography
        observation to gather the list of rcus, the mode, the target station name and position,
        the source name and position, the start and the end time, the rotation matrix to convert
        from ra and dec to l and m, the antenna field positions, the list of the reference
        stations, the frequencies, the ra and dec at which the beams point at.

        All this data is essential to interpret the recorded holography beams cross correlations
        :param list_of_hbs_ho_tuples: a list containing (hbs, ho) per frequency
        :type list_of_hbs_ho_tuples: list[(HolographySpecification, HolographyObservation)]
        """
        mode = set()
        source_name = set()
        source_position = set()
        target_stations = set()
        reference_stations = set()
        beamlets = set()
        virtual_pointing = dict()
        frequencies = set()
        sas_ids = set()
        rcu_list = set()
        start_mjd = None
        end_mjd = None
        for hbs, ho in list_of_hbs_ho_tuples:

            target_stations.update(hbs.target_station_names)

            if station_name in hbs.target_station_names:
                beam_specifications = hbs.get_beam_specifications_per_station_name(station_name)
                for beam_specification in beam_specifications:
                    rcu_list.update(beam_specification.rcus_involved)

                    mode.add(beam_specification.rcus_mode)

                    source_name.add(ho.source_name)
                    source_position.add(
                        (beam_specification.station_pointing['ra'],
                         beam_specification.station_pointing['dec'],
                         beam_specification.station_pointing['coordinate_system']
                        ))
                    if start_mjd is None or start_mjd > ho.start_mjd:
                        start_mjd = ho.start_mjd

                    if end_mjd is None or end_mjd < ho.end_mjd:
                        end_mjd = ho.end_mjd

                    frequencies.add(ho.frequency)

                    sas_ids.add(ho.sas_id)

                    self.target_station_name = station_name
                    reference_stations.update(hbs.reference_station_names)
                    try:
                        single_beamlet = int(beam_specification.beamlets)
                    except ValueError as e:
                        logger.exception('Target station specification incorrect')
                        raise e

                    beamlets.add(single_beamlet)

                    virtual_pointing[(ho.frequency, single_beamlet)] =\
                        (beam_specification.virtual_pointing['ra'],
                         beam_specification.virtual_pointing['dec'],
                         beam_specification.virtual_pointing['coordinate_system'])
            else:
                continue
        self.frequencies = sorted(frequencies)
        self.beamlets = sorted(beamlets)
        self.start_time = start_mjd
        self.end_time = end_mjd
        self.sas_ids = list(sas_ids)
        self.reference_stations = list(reference_stations)
        self.rcu_list = list(rcu_list)

        self.ra_dec = dict()
        coordinate_type = numpy.dtype([('RA', numpy.float64), 
                                       ('DEC', numpy.float64),
                                       ('EPOCH', 'S10')])
             
        for frequency in self.frequencies:
            frequency_string = str(frequency)
            if frequency not in self.ra_dec:
                self.ra_dec[frequency_string] = dict()
            for beamlet in self.beamlets:
                beamlet_string = str(beamlet)
                self.ra_dec[frequency_string][beamlet_string] = numpy.array(
                    virtual_pointing[(frequency, beamlet)],
                    dtype=coordinate_type)


        # reads the target station position and the coordinate of its axes
        # and does this only once since the coordinate will not change
        first_holography_observation = list_of_hbs_ho_tuples[0][1]
        first_ms = first_holography_observation.ms_for_a_given_beamlet_number.values()[0]
        station_position, tile_offset, axes_coordinates = first_ms.\
            get_station_position_tile_offsets_and_axes_coordinate_for_station_name(
            station_name)

        self.antenna_field_position = [list(station_position - antenna_offset)
                                       for antenna_offset in tile_offset]
        self.target_station_position = list(station_position)
        self.rotation_matrix = axes_coordinates

        if station_name not in target_stations:
            logger.error('Station %s was not involved in the observation.'
                         ' The target stations for this observation are %s',
                         station_name, target_stations)
            raise Exception('Station %s was not involved in the observation.'
                            % station_name,)

        if len(mode) > 1:
            raise ValueError('Multiple RCUs mode are not supported')
        else:
            self.mode = mode.pop()

        if len(source_position) > 1:
            logger.error('Multiple source positions are not supported: %s', source_position)
            raise ValueError('Multiple source positions are not supported')
        else:
            self.source_position = source_position.pop()

        if len(source_name) > 1:
            raise ValueError('Multiple source name are not supported')
        else:
            self.source_name = source_name.pop()

    @staticmethod
    def print_info(hds, text = None):
        if text is not None and isinstance(text, str):
            print("%s" % (text))
        if hds is not None and isinstance(hds, HolographyDataset) is True:
            print("Version = ", hds.version)
            print("Mode = ", hds.mode)
            print("RCU list = ", hds.rcu_list)
            print("SAS ID = ", hds.sas_ids)
            print("Target station name = ", hds.target_station_name)
            print("Target station position = ", hds.target_station_position)
            print("Source name = ", hds.source_name)
            print("Source position = ", hds.source_position)
            print("Start time = ", hds.start_time)
            print("End time = ", hds.end_time)
            print("Rotation matrix = ", hds.rotation_matrix)
            print("Antenna field position = ", hds.antenna_field_position)
            print("Reference stations = ", hds.reference_stations)
            print("Frequencies = ", hds.frequencies)
            print("RA DEC = ", hds.ra_dec)
            print("Data = ", hds.data)
            print("Beamlets = ", hds.beamlets)
        else:
            print("The object is not a HolographyDataset instance.  Cannot print any data.")

    @staticmethod
    def load_from_file(path):
        """
        It reads a holography dataset from an HDF5 file and returns a
        HolographyDataset class.

        The beamlet numbers are derived from the names found under
        "CROSSCORRELATION" and are returned sorted, the same order the dataset
        has when it is built from the specifications.

        :param path: path to file
        :return: the read dataset
        :raises Exception: any error raised while reading the file is logged and
                           raised again
        """
        try:
            # The context manager closes the file on success and on failure
            with h5py.File(path, "r") as f:
                result = HolographyDataset()
                result.version = f.attrs[HDS_VERSION]
                result.mode = f.attrs[HDS_MODE]
                result.rcu_list = list(f.attrs[HDS_RCU_LIST])
                result.sas_ids = list(f.attrs[HDS_SAS_ID])
                result.target_station_name = f.attrs[HDS_TARGET_STATION_NAME]
                result.target_station_position = list(f.attrs[HDS_TARGET_STATION_POSITION])
                result.source_name = f.attrs[HDS_SOURCE_NAME]
                result.source_position = tuple(f.attrs[HDS_SOURCE_POSITION].tolist())
                (result.start_time, result.end_time) = f.attrs[HDS_OBSERVATION_TIME]
                result.rotation_matrix = f.attrs[HDS_ROTATION_MATRIX]
                result.antenna_field_position = f.attrs[HDS_ANTENNA_FIELD_POSITION].tolist()
                result.reference_stations = list(f[HDS_REFERENCE_STATION])
                result.frequencies = list(f[HDS_FREQUENCY])

                # ra_dec[frequency][beamlet], both keys being the names used in
                # the file
                result.ra_dec = dict()
                ra_dec_group = f["RA_DEC"]
                for frequency in ra_dec_group.keys():
                    for beamlet in ra_dec_group[frequency].keys():
                        result.ra_dec.setdefault(frequency, dict())[beamlet] = \
                            numpy.array(ra_dec_group[frequency][beamlet])

                # data[reference station][frequency][beamlet]
                beamlets = set()
                result.data = dict()
                cross_correlation_group = f["CROSSCORRELATION"]
                for reference_station in cross_correlation_group.keys():
                    station_group = cross_correlation_group[reference_station]
                    for frequency in station_group.keys():
                        frequency_group = station_group[frequency]
                        for beamlet in frequency_group.keys():
                            beamlets.add(int(beamlet))
                            result.data.setdefault(reference_station, dict()).\
                                setdefault(frequency, dict())[beamlet] = \
                                numpy.array(frequency_group[beamlet])

                # A set has no defined order: sort so that a stored and reloaded
                # dataset lists its beamlets in a reproducible order.
                result.beamlets = sorted(beamlets)
        except Exception as e:
            logger.exception("Cannot read the Holography Data Set data from the HDF5 file \"%s\".  This is the exception that was thrown:  %s", path, e)
            raise

        return result

    def store_to_file(self, path):
        """
        Stores the holography dataset at the given path as an HDF5 file.

        The scalar and list-like information is written as file attributes, the
        reference stations and the frequencies as datasets, and the pointings
        and the samples in the "RA_DEC" and "CROSSCORRELATION" groups.  Any
        error is logged and raised again.

        :param path: path to file
        """
        try:
            with h5py.File(path, "w") as f:
                attributes = f.attrs

                # Start with the version information
                attributes[HDS_VERSION] = HOLOGRAPHY_DATA_SET_VERSION

                # RCU list and RCU mode
                attributes[HDS_RCU_LIST] = numpy.array(self.rcu_list, dtype=int)
                attributes[HDS_MODE] = self.mode

                # The SAS ids are stored as variable length strings
                variable_length_string = h5py.special_dtype(vlen=str)
                attributes[HDS_SAS_ID] = numpy.array(self.sas_ids,
                                                     dtype = variable_length_string)
                attributes[HDS_TARGET_STATION_NAME] = self.target_station_name
                attributes[HDS_TARGET_STATION_POSITION] = self.target_station_position
                attributes[HDS_SOURCE_NAME] = self.source_name

                coordinate_type = numpy.dtype([('RA', numpy.float64),
                                               ('DEC', numpy.float64),
                                               ('EPOCH', 'S10')])
                attributes[HDS_SOURCE_POSITION] = numpy.array(self.source_position,
                                                              dtype=coordinate_type)
                observation_time = [self.start_time, self.end_time]
                attributes[HDS_OBSERVATION_TIME] = numpy.array(observation_time)
                attributes[HDS_ROTATION_MATRIX] = self.rotation_matrix
                attributes[HDS_ANTENNA_FIELD_POSITION] = self.antenna_field_position

                # The reference stations and the frequencies are stored as
                # well, to keep them around for quick reference.
                f[HDS_REFERENCE_STATION] = self.reference_stations
                f[HDS_FREQUENCY] = self.frequencies

                # One group per frequency, one dataset per beamlet
                pointings = f.create_group("RA_DEC")
                for frequency, per_beamlet in self.ra_dec.items():
                    frequency_group = pointings.create_group(frequency)
                    for beamlet, pointing in per_beamlet.items():
                        frequency_group[beamlet] = pointing

                # One group per reference station and per frequency, one
                # dataset per beamlet.  The three names together allow random
                # access to the samples.
                correlations = f.create_group("CROSSCORRELATION")
                for reference_station, per_frequency in self.data.items():
                    station_group = correlations.create_group(reference_station)
                    for frequency, per_beamlet in per_frequency.items():
                        frequency_group = station_group.create_group(frequency)
                        for beamlet, samples in per_beamlet.items():
                            frequency_group[beamlet] = samples
        except Exception as e:
            logger.exception("Cannot write the Holography Data Set data to the HDF5 file \"%s\".  This is the exception that was thrown:  %s", path, e)
            raise e