Skip to content
Snippets Groups Projects
Code owners
Assign users and groups as approvers for specific file changes. Learn more.
calibration_table.py 9.21 KiB
import logging
from copy import deepcopy
from dataclasses import dataclass, asdict, field, fields
from glob import glob
from glob import escape as escape_glob_pattern
from os import path
from re import fullmatch
from struct import iter_unpack, pack
from typing import BinaryIO
from typing import Dict, Tuple
from typing import List

from h5py import File
from numpy import empty as empty_ndarray, ndarray, fromiter as array_from_iter, float64, \
    array_equal, arange, array

logger = logging.getLogger(name=__name__)

# Observation mode -> sampling clock (200 or 160; presumably MHz — confirm).
# Read by CalibrationTable.frequencies().
_MODE_TO_CLOCK = {
    1: 200,
    3: 200,
    5: 200,
    6: 160,
    7: 200,
}
# Observation mode -> 1-based Nyquist zone the mode samples in.
_MODE_TO_NYQ_ZONE = {
    1: 1,
    3: 1,
    5: 2,
    6: 1,
    7: 3,
}

# Dataclass field name -> key used for that field in a calibration table file header.
_ATTRIBUTE_NAME_TO_SERIALIZED_NAME = dict(
    observation_station='CalTableHeader.Observation.Station',
    observation_mode='CalTableHeader.Observation.Mode',
    observation_antennaset='CalTableHeader.Observation.AntennaSet',
    observation_band='CalTableHeader.Observation.Band',
    observation_source='CalTableHeader.Observation.Source',
    observation_date='CalTableHeader.Observation.Date',
    calibration_version='CalTableHeader.Calibration.Version',
    calibration_name='CalTableHeader.Calibration.Name',
    calibration_date='CalTableHeader.Calibration.Date',
    calibration_ppsdelay='CalTableHeader.Calibration.PPSDelay',
    comment='CalTableHeader.Comment',
)
# Glob pattern, relative to a search directory and used with recursive=True,
# that selects calibration table files.
_CALIBRATION_TABLE_FILENAME_PATTERN = '**/*CalTable_???_mode?.dat'


class InvalidFileException(Exception):
    """Raised when a calibration table file (header or data section) is malformed."""

    def __init__(self, message):
        # Forward to Exception so args/str() carry the message even when it is
        # passed by keyword (Exception.__new__ only records positional args).
        super().__init__(message)
        self.message = message


@dataclass(init=True, repr=True, frozen=False)
class CalibrationTable:
    """A station calibration table: header metadata plus complex calibration data.

    The header consists of every field below except ``data``; ``data`` is a
    complex array of shape ``(_FREQUENCIES, n_antennas)``. A table can be read
    from / written to the calibration table file format (a ``HeaderStart`` ...
    ``HeaderStop`` block of ``key = value`` text lines followed by raw
    native-endian float64 ``(real, imag)`` pairs) and to / from an HDF5 dataset
    whose attributes hold the header fields.
    """

    # Maximum number of lines _extract_header reads while looking for 'HeaderStop'.
    _MAX_HEADER_LINES = 100
    # Matches 'A.B.C = value' and 'A.B = value' header lines.
    # NOTE(review): the class [A-z] also matches '[', '\', ']', '^', '_' and '`';
    # [A-Za-z] was probably intended — confirm before tightening it.
    _HEADER_LINE_PATTERN = r'(^[A-z]*\.[A-z]*\.[A-z]*\s=\s.*$)|(^[A-z]*\.[A-z]*\s=\s.*$)'
    # Number of frequency channels (subbands) stored per antenna.
    _FREQUENCIES = 512
    # Each complex sample is stored as two float64 values: real, imaginary.
    _FLOATS_PER_FREQUENCY = 2
    # The two antenna counts accepted by _parse_data.
    _N_ANTENNAS_DUTCH = 96
    _N_ANTENNAS_INTERNATIONAL = 192

    observation_station: str
    # Header values may arrive as strings; __post_init__ converts this to int.
    observation_mode: int
    observation_source: str
    observation_date: str
    # Converted to int in __post_init__, like observation_mode.
    calibration_version: int
    calibration_name: str
    calibration_date: str
    # Accepts a list or the serialized '[a b c ]' string (parsed in __post_init__).
    calibration_ppsdelay: List[int]
    # Complex array (frequencies x antennas). Excluded from the dataclass field
    # comparison; __eq__ compares it with array_equal instead.
    data: ndarray = field(compare=False)
    comment: str = ''

    observation_antennaset: str = ''
    observation_band: str = ''

    @staticmethod
    def load_from_file(file_path):
        """Read a calibration table from a file in the calibration table file format.

        :param file_path: path of the file to read
        :return: the loaded CalibrationTable
        :raises InvalidFileException: if the header or the data section is malformed;
            other errors raised while decoding the data (e.g. struct.error) are
            logged and re-raised unchanged
        """
        logger.info('loading file %s', file_path)
        with open(file_path, 'rb') as file_stream:
            header = CalibrationTable._extract_header(file_stream)
            data_raw = CalibrationTable._strip_trailing_newlines(file_stream.read())
        try:
            data = CalibrationTable._parse_data(data_raw)
        except Exception as e:
            logger.error('error reading file %s', file_path)
            logger.debug(data_raw)
            logger.exception(e)
            # Bare raise re-raises the same exception with its original traceback.
            raise
        return CalibrationTable(**header, data=data)

    @staticmethod
    def _strip_trailing_newlines(data_raw: bytes) -> bytes:
        """Drop newline (0x0A) bytes appended after the binary data section.

        Only newlines that leave the buffer length a non-multiple of the
        float64 size (8 bytes) are removed. A 0x0A byte that is part of the
        last float64 value is therefore preserved; a blind rstrip would
        corrupt it.
        """
        float64_size = 8
        while len(data_raw) % float64_size and data_raw.endswith(b'\n'):
            data_raw = data_raw[:-1]
        return data_raw

    @staticmethod
    def load_from_hdf(file_descriptor: File, uri: str):
        """Read a calibration table from an HDF5 dataset written by store_to_hdf.

        :param file_descriptor: open HDF5 file
        :param uri: path of the dataset inside the file
        :return: CalibrationTable built from the dataset values and its attributes
        :raises ValueError: if uri does not exist in the file
        """
        if uri not in file_descriptor:
            raise ValueError('specified uri does not exist in %s' % file_descriptor.filename)

        dataset = file_descriptor[uri]
        # Dataset attributes hold the header fields, keyed by dataclass field name.
        return CalibrationTable(data=array(dataset), **dict(dataset.attrs.items()))

    def derive_calibration_table_from_gain_fit(self,
                                               observation_source: str,
                                               observation_date: str,
                                               calibration_name: str,
                                               calibration_date: str,
                                               comment: str,
                                               gains):
        """Return a deep copy of this table with new metadata and data multiplied by gains.

        The original table is not modified. ``gains`` is presumably a scalar or an
        array broadcastable against ``data`` (in-place ``*=``) — confirm at call sites.
        """
        new_calibration_table = deepcopy(self)
        new_calibration_table.observation_source = observation_source
        new_calibration_table.observation_date = observation_date
        new_calibration_table.calibration_name = calibration_name
        new_calibration_table.calibration_date = calibration_date
        new_calibration_table.comment = comment

        new_calibration_table.data *= gains
        return new_calibration_table

    def frequencies(self) -> ndarray:
        """Return the frequency of each of the _FREQUENCIES subbands for this table's mode.

        Subband s (1.._FREQUENCIES) maps to
        ``s * clock / 1024 + (nyquist_zone - 1) * clock / 2``. The result is in
        the units of the _MODE_TO_CLOCK values.

        :raises KeyError: if observation_mode is not in _MODE_TO_CLOCK
        """
        subbands = arange(1, self._FREQUENCIES + 1, 1.)
        clock = _MODE_TO_CLOCK[self.observation_mode]
        nyquist_zone = _MODE_TO_NYQ_ZONE[self.observation_mode]
        return subbands * clock / 1024. + (nyquist_zone - 1) * clock / 2.

    def store_to_hdf(self, file_descriptor: File, uri: str):
        """Write the table to an HDF5 file and flush it.

        The data goes to the dataset at ``uri``; every header field goes to that
        dataset's attributes.

        NOTE(review): if ``uri`` already exists its data is left untouched and only
        the attributes are overwritten — confirm this is intended.
        """
        if uri not in file_descriptor:
            file_descriptor[uri] = self.data
        attributes = file_descriptor[uri].attrs
        for name, value in self._header_items():
            attributes[name] = value
        file_descriptor.flush()

    def store_to_file(self, file_path):
        """Write the table to file_path in the calibration table file format."""
        with open(file_path, 'wb') as file_stream:
            self._serialize_header(file_stream)
            self._serialize_data(file_stream)

    def _header_items(self):
        """Yield (field name, value) for every field except data, in declaration order.

        Reads the fields directly instead of using asdict(), which would
        deep-copy the (large) data array only for it to be skipped.
        """
        for table_field in fields(self):
            if table_field.name != 'data':
                yield table_field.name, getattr(self, table_field.name)

    def _serialize_header(self, f_stream: BinaryIO):
        """Write the HeaderStart ... HeaderStop text block as UTF-8 lines."""
        f_stream.write(b'HeaderStart\n')

        for name, value in self._header_items():
            if name == 'calibration_ppsdelay':
                # '[a b c ]' form, which _parse_header reads back.
                serialized_value = '[%s ]' % ' '.join(map(str, value))
            else:
                serialized_value = str(value)
            serialized_name = _ATTRIBUTE_NAME_TO_SERIALIZED_NAME[name]
            serialized_line = '{} = {}\n'.format(serialized_name, serialized_value).encode('utf8')
            f_stream.write(serialized_line)

        f_stream.write(b'HeaderStop\n')

    def _serialize_data(self, f_stream: BinaryIO):
        """Write data as native-endian float64 values ordered [frequency][antenna][real, imag]."""
        interleaved = empty_ndarray(list(self.data.shape) + [2], dtype=float64)
        interleaved[..., 0] = self.data.real
        interleaved[..., 1] = self.data.imag
        # tobytes() emits the same native-order C-contiguous bytes that
        # struct.pack('<n>d', ...) produced, without ~200k positional arguments.
        f_stream.write(interleaved.tobytes())

    @staticmethod
    def _extract_header(fstream: BinaryIO):
        """Read header lines from fstream up to and including 'HeaderStop'.

        A line such as ``CalTableHeader.Observation.Station = CS001`` becomes
        ``header['observation_station'] = 'CS001'``. Values stay strings;
        __post_init__ converts the numeric ones.

        :return: dict mapping dataclass field names to raw string values
        :raises InvalidFileException: on an unrecognized line or an empty header
        """
        header = {}
        for _ in range(CalibrationTable._MAX_HEADER_LINES):
            # Strip '\r' too, so files with CRLF line endings are accepted.
            line = fstream.readline().decode('utf8').rstrip('\r\n')

            if line == 'HeaderStop':
                break
            if line == 'HeaderStart':
                continue
            if not fullmatch(CalibrationTable._HEADER_LINE_PATTERN, line):
                logger.error('unrecognized line \"%s\"', line)
                raise InvalidFileException('unrecognized line \"%s\"' % line)

            # Split on the first '=' only: values (e.g. comments) may contain '='.
            key, value = line.split('=', 1)
            key = key.lower().replace('caltableheader.', '').strip().replace('.', '_')
            header[key] = value.strip()
        if not header:
            raise InvalidFileException('empty header')
        return header

    def _parse_header(self):
        """Convert header fields that may have been given as strings to their declared types."""
        self.observation_mode = int(self.observation_mode)
        self.calibration_version = int(self.calibration_version)
        if isinstance(self.calibration_ppsdelay, str):
            # '[a b c ]' -> [a, b, c]. split() without arguments tolerates repeated
            # blanks and yields [] for '[ ]', which _serialize_header writes for an
            # empty list.
            self.calibration_ppsdelay = [int(token) for token in
                                         self.calibration_ppsdelay.lstrip('[').rstrip(']').split()]

    @staticmethod
    def _parse_data(data_buffer: bytes):
        """Decode the binary data section into a complex (_FREQUENCIES, n_antennas) array.

        The buffer is read as native-endian float64 values laid out as
        [frequency][antenna][real, imag].

        :raises InvalidFileException: if the value count does not correspond to
            exactly _N_ANTENNAS_DUTCH or _N_ANTENNAS_INTERNATIONAL antennas
        :raises struct.error: if the buffer length is not a multiple of 8 bytes
        """
        values = array_from_iter((value for (value,) in iter_unpack('d', data_buffer)), dtype=float)

        values_per_antenna = CalibrationTable._FREQUENCIES * CalibrationTable._FLOATS_PER_FREQUENCY
        n_antennas = values.shape[0] // values_per_antenna

        if n_antennas not in [CalibrationTable._N_ANTENNAS_DUTCH, CalibrationTable._N_ANTENNAS_INTERNATIONAL]:
            raise InvalidFileException('invalid data range expected %s or %s antennas got %s' %
                                       (CalibrationTable._N_ANTENNAS_DUTCH,
                                        CalibrationTable._N_ANTENNAS_INTERNATIONAL,
                                        n_antennas))
        # The floor division above hides leftover values; reject them here rather
        # than letting reshape fail with a bare ValueError.
        expected_values = n_antennas * values_per_antenna
        if values.shape[0] != expected_values:
            raise InvalidFileException('invalid data size expected %s values got %s' %
                                       (expected_values, values.shape[0]))

        interleaved = values.reshape((CalibrationTable._FREQUENCIES, n_antennas,
                                      CalibrationTable._FLOATS_PER_FREQUENCY))
        complex_data = empty_ndarray([CalibrationTable._FREQUENCIES, n_antennas], dtype=complex)
        complex_data.real = interleaved[:, :, 0]
        complex_data.imag = interleaved[:, :, 1]

        return complex_data

    def __post_init__(self):
        self._parse_header()

    def __eq__(self, other):
        """Tables are equal when every header field and the data array are equal.

        An explicit __eq__ stops @dataclass from generating one, so the fields are
        compared here. array_equal also handles calibration_ppsdelay arriving as a
        list or as an array (e.g. when loaded from HDF5).
        """
        if not isinstance(other, CalibrationTable):
            return NotImplemented
        headers_equal = all(array_equal(getattr(self, table_field.name), getattr(other, table_field.name))
                            for table_field in fields(self) if table_field.compare)
        return headers_equal and array_equal(self.data, other.data)


def read_calibration_tables_in_directory(directory_path: str):
    """Load every calibration table file found (recursively) under a directory.

    Files are selected with _CALIBRATION_TABLE_FILENAME_PATTERN. They are loaded
    in sorted path order, so the result (and any duplicate resolution done by
    callers) is deterministic.

    :param directory_path: directory to search
    :return: list of CalibrationTable, one per matching file
    :raises NotADirectoryError: if directory_path is not an existing directory
    """
    if not path.isdir(directory_path):
        raise NotADirectoryError(directory_path)
    # Escape the literal directory part so glob metacharacters in it
    # ('[', '*', '?') are not interpreted as a pattern.
    files = path.join(escape_glob_pattern(directory_path), _CALIBRATION_TABLE_FILENAME_PATTERN)

    return [CalibrationTable.load_from_file(file_path)
            for file_path in sorted(glob(files, recursive=True))]


def read_calibration_tables_per_station_mode(directory_path: str) -> Dict[Tuple[str, int],
                                                                          CalibrationTable]:
    """Load the calibration tables under directory_path, keyed by (station, mode).

    If several files map to the same (station, mode), the one loaded last wins
    and a warning is logged instead of overwriting silently.

    :param directory_path: directory searched by read_calibration_tables_in_directory
    :return: dict mapping (observation_station, observation_mode) to its table
    :raises NotADirectoryError: if directory_path is not an existing directory
    """
    result = dict()
    for calibration_table in read_calibration_tables_in_directory(directory_path):
        key = (calibration_table.observation_station, calibration_table.observation_mode)
        if key in result:
            logger.warning('duplicate calibration table for station %s mode %s, '
                           'keeping the last one loaded', *key)
        result[key] = calibration_table
    return result