Skip to content
Snippets Groups Projects
Commit 01cbb695 authored by Mattia Mancini's avatar Mattia Mancini
Browse files

SSB-47: various changes (see extended description)

- included function to load from hdf5 file
- error handling when loading from file
- frequency axis generation
- derive new calibration table from existing one
- relative tests
parent 58457b05
No related branches found
No related tags found
1 merge request!44Merge back holography to master
...@@ -9,7 +9,9 @@ from typing import List ...@@ -9,7 +9,9 @@ from typing import List
from dataclasses import dataclass, asdict, field from dataclasses import dataclass, asdict, field
from h5py import File from h5py import File
from numpy import empty as empty_ndarray, ndarray, fromiter as array_from_iter, float64, array_equal from numpy import empty as empty_ndarray, ndarray, fromiter as array_from_iter, float64,\
array_equal, arange, array
from copy import deepcopy
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
...@@ -19,7 +21,9 @@ __FREQUENCIES = 512 ...@@ -19,7 +21,9 @@ __FREQUENCIES = 512
__FLOATS_PER_FREQUENCY = 2 __FLOATS_PER_FREQUENCY = 2
__N_ANTENNAS_DUTCH = 96 __N_ANTENNAS_DUTCH = 96
__N_ANTENNAS_INTERNATIONAL = 192 __N_ANTENNAS_INTERNATIONAL = 192
__CALIBRATION_TABLE_FILENAME_PATTERN = '*CalTable-???-???-???_???.dat' __CALIBRATION_TABLE_FILENAME_PATTERN = '**/*CalTable-???-???-???_???.dat'
_MODE_TO_CLOCK = {1: 200, 3: 200, 5: 200, 6: 160, 7: 200}
_MODE_TO_NYQ_ZONE = {1: 1, 3: 1, 5: 2, 6: 1, 7: 3}
_ATTRIBUTE_NAME_TO_SERIALIZED_NAME = { _ATTRIBUTE_NAME_TO_SERIALIZED_NAME = {
'observation_station': 'CalTableHeader.Observation.Station', 'observation_station': 'CalTableHeader.Observation.Station',
...@@ -41,6 +45,7 @@ class UnvalidFileException(Exception): ...@@ -41,6 +45,7 @@ class UnvalidFileException(Exception):
self.message = message self.message = message
def _extract_header(fstream: BinaryIO): def _extract_header(fstream: BinaryIO):
header = {} header = {}
for i in range(__MAX_HEADER_LINES): for i in range(__MAX_HEADER_LINES):
...@@ -82,12 +87,10 @@ def parse_data(data_buffer): ...@@ -82,12 +87,10 @@ def parse_data(data_buffer):
return complex_data return complex_data
@dataclass(init=True, repr=True, frozen=False, eq=True) @dataclass(init=True, repr=True, frozen=False)
class CalibrationTable: class CalibrationTable:
observation_station: str observation_station: str
observation_mode: int observation_mode: int
observation_antennaset: str
observation_band: str
observation_source: str observation_source: str
observation_date: str observation_date: str
calibration_version: int calibration_version: int
...@@ -97,6 +100,9 @@ class CalibrationTable: ...@@ -97,6 +100,9 @@ class CalibrationTable:
data: ndarray = field(compare=False) data: ndarray = field(compare=False)
comment: str = '' comment: str = ''
observation_antennaset: str = ''
observation_band: str = ''
def __parse_attributes(self): def __parse_attributes(self):
self.observation_mode = int(self.observation_mode) self.observation_mode = int(self.observation_mode)
self.calibration_version = int(self.calibration_version) self.calibration_version = int(self.calibration_version)
...@@ -110,14 +116,41 @@ class CalibrationTable: ...@@ -110,14 +116,41 @@ class CalibrationTable:
def __post_init__(self): def __post_init__(self):
self.__parse_attributes() self.__parse_attributes()
def frequencies(self) -> ndarray:
subbands = arange(1, 513, 1.)
clock = _MODE_TO_CLOCK[self.observation_mode]
nyquist_zone = _MODE_TO_NYQ_ZONE[self.observation_mode]
frequencies = subbands * clock / 1024. + (nyquist_zone - 1) * clock / 2.
return frequencies
def derive_calibration_table_from_gain_fit(self,
observation_source: str,
observation_date: str,
calibration_name: str,
commment: str,
gains):
new_calibration_table = deepcopy(self)
new_calibration_table.observation_source = observation_source
new_calibration_table.observation_date = observation_date
new_calibration_table.calibration_name = calibration_name,
new_calibration_table.comment = commment
new_calibration_table.data = gains
@staticmethod @staticmethod
def load_from_file(file_path): def load_from_file(file_path):
logger.info('loading file %s', file_path)
with open(file_path, 'rb') as file_stream: with open(file_path, 'rb') as file_stream:
header = _extract_header(file_stream) header = _extract_header(file_stream)
data_raw = file_stream.read() data_raw = file_stream.read().rstrip(b'\n')
try:
data = parse_data(data_raw) data = parse_data(data_raw)
except Exception as e:
logger.error('error reading file %s', file_path)
logger.debug(data_raw)
logger.exception(e)
raise e
calibration_table = CalibrationTable(**header, calibration_table = CalibrationTable(**header,
data=data) data=data)
return calibration_table return calibration_table
...@@ -158,6 +191,15 @@ class CalibrationTable: ...@@ -158,6 +191,15 @@ class CalibrationTable:
file_descriptor[uri].attrs[key] = value file_descriptor[uri].attrs[key] = value
file_descriptor.flush() file_descriptor.flush()
@staticmethod
def load_from_hdf(file_descriptor: File, uri: str):
if uri not in file_descriptor:
raise ValueError('specified uri does not exist in %s' % file_descriptor.filename)
data = array(file_descriptor[uri])
return CalibrationTable(data=data, **dict(file_descriptor[uri].attrs.items()))
def __eq__(self, other): def __eq__(self, other):
return super().__eq__(other) and array_equal(self.data, other.data) return super().__eq__(other) and array_equal(self.data, other.data)
...@@ -173,7 +215,7 @@ def read_calibration_tables_in_directory(directory_path: str): ...@@ -173,7 +215,7 @@ def read_calibration_tables_in_directory(directory_path: str):
files = path.join(directory_path, __CALIBRATION_TABLE_FILENAME_PATTERN) files = path.join(directory_path, __CALIBRATION_TABLE_FILENAME_PATTERN)
return [CalibrationTable.load_from_file(file_path) return [CalibrationTable.load_from_file(file_path)
for file_path in glob(files, recursive=False)] for file_path in glob(files, recursive=True)]
def read_calibration_tables_per_station_mode(directory_path: str) -> Dict[Tuple[str, int], def read_calibration_tables_per_station_mode(directory_path: str) -> Dict[Tuple[str, int],
......
...@@ -6,7 +6,7 @@ from os import getcwd ...@@ -6,7 +6,7 @@ from os import getcwd
import logging import logging
from tempfile import NamedTemporaryFile from tempfile import NamedTemporaryFile
from numpy.testing import assert_array_equal from numpy.testing import assert_array_equal
from numpy import array from numpy import array, linspace
CALIBRATION_TABLE_FILENAME = 't_calibration_table.in_CalTable-401-HBA-110_190.dat' CALIBRATION_TABLE_FILENAME = 't_calibration_table.in_CalTable-401-HBA-110_190.dat'
...@@ -49,6 +49,36 @@ class TestCalibrationTable(unittest.TestCase): ...@@ -49,6 +49,36 @@ class TestCalibrationTable(unittest.TestCase):
) )
assert_array_equal(array(h5_file['/calibration_table']), test_calibration_table.data) assert_array_equal(array(h5_file['/calibration_table']), test_calibration_table.data)
def test_loading_to_hdf(self):
with NamedTemporaryFile('w+b') as temp_file:
h5_file = H5File(temp_file.name, 'w')
test_calibration_table = CalibrationTable.load_from_file(CALIBRATION_TABLE_FILENAME)
test_calibration_table.store_to_hdf(h5_file, '/calibration_table')
h5_file.close()
h5_file = H5File(temp_file.name, 'r')
caltable = CalibrationTable.load_from_hdf(file_descriptor=h5_file,
uri='/calibration_table')
self.assertEqual(caltable, test_calibration_table)
def test_loading_to_hdf_raise_value_error(self):
with NamedTemporaryFile('w+b') as temp_file:
h5_file = H5File(temp_file.name, 'w')
h5_file.close()
h5_file = H5File(temp_file.name, 'r')
with self.assertRaises(ValueError):
caltable = CalibrationTable.load_from_hdf(file_descriptor=h5_file,
uri='/calibration_table')
def test_frequency_generation(self):
test_calibration_table = CalibrationTable.load_from_file(CALIBRATION_TABLE_FILENAME)
sb = linspace(1, 512, 512)
clock = 200
nyq = 2
expected_frequency_array = freq = clock / 1024 * sb + (nyq-1) * clock / 2.
assert_array_equal(expected_frequency_array, test_calibration_table.frequencies())
def test_list_calibration_tables_in_path(self): def test_list_calibration_tables_in_path(self):
calibration_tables = read_calibration_tables_in_directory(getcwd()) calibration_tables = read_calibration_tables_in_directory(getcwd())
self.assertEqual(len(calibration_tables), 1) self.assertEqual(len(calibration_tables), 1)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment