diff --git a/.gitlab-ci.yml b/.gitlab-ci.yml index 08d581be103c3f99247b7200ce294e30aa7880a6..1c38e70837b4152be3d0c0cd181a2aeed30deaa9 100644 --- a/.gitlab-ci.yml +++ b/.gitlab-ci.yml @@ -10,8 +10,6 @@ default: stages: - prepare - lint - # check if this needs to be a separate step - # - build_extensions - test - package - images @@ -34,29 +32,12 @@ trigger_prepare: strategy: depend include: .prepare.gitlab-ci.yml -run_black: +run_lint: stage: lint script: - - tox -e black + - tox -e lint allow_failure: true -run_flake8: - stage: lint - script: - - tox -e pep8 - allow_failure: true - -run_pylint: - stage: lint - script: - - tox -e pylint - allow_failure: true - -# build_extensions: -# stage: build_extensions -# script: -# - echo "build fortran/c/cpp extension source code" - sast: variables: SAST_EXCLUDED_ANALYZERS: brakeman, flawfinder, kubesec, nodejs-scan, phpcs-security-audit, @@ -89,7 +70,7 @@ run_unit_tests: - tox -e py3${PY_VERSION} parallel: matrix: # use the matrix for testing - - PY_VERSION: [9, 10, 11, 12, 13] + - PY_VERSION: [10, 11, 12, 13] # Run code coverage on the base image thus also performing unit tests run_unit_tests_coverage: @@ -123,31 +104,6 @@ package_docs: script: - tox -e docs -docker_build: - stage: images - image: docker:latest - needs: - - package_files - tags: - - dind - before_script: [] - script: - - docker login -u $CI_REGISTRY_USER -p $CI_REGISTRY_PASSWORD $CI_REGISTRY - - docker build -f docker/lofar_lotus/Dockerfile . --build-arg BUILD_ENV=copy --tag $CI_REGISTRY_IMAGE/lofar_lotus:$CI_COMMIT_REF_SLUG - # enable this push line once you have configured docker registry cleanup policy - # - docker push $CI_REGISTRY_IMAGE/lofar_lotus:$CI_COMMIT_REF_SLUG - -run_integration_tests: - stage: integration - allow_failure: true - needs: - - package_files - script: - - echo "make sure to move out of source dir" - - echo "install package from filesystem (or use the artefact)" - - echo "run against foreign systems (e.g. databases, cwl etc.)" - - exit 1 - publish_on_gitlab: stage: publish environment: gitlab @@ -212,14 +168,3 @@ publish_to_readthedocs: script: - echo "scp docs/* ???" - exit 1 - -release_job: - stage: publish - image: registry.gitlab.com/gitlab-org/release-cli:latest - rules: - - if: '$CI_COMMIT_TAG && $CI_COMMIT_REF_PROTECTED == "true"' - script: - - echo "running release_job" - release: - tag_name: '$CI_COMMIT_TAG' - description: '$CI_COMMIT_TAG - $CI_COMMIT_TAG_MESSAGE' diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 2980237ca655f4a550d2b1e8eee9db66c321417f..ce46e6463265d2ad9704a24ae7f467cb7b50d8d9 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,4 +1,4 @@ -default_stages: [ commit, push ] +default_stages: [ pre-commit, pre-push ] default_language_version: python: python3 exclude: '^docs/.*\.py$' @@ -14,25 +14,9 @@ repos: - id: detect-private-key - repo: local hooks: - - id: tox-black - name: tox-black (local) + - id: tox-lint + name: tox-lint (local) entry: tox language: python types: [file, python] - args: ["-e", "black", "--"] - - repo: local - hooks: - - id: tox-pep8 - name: tox-pep8 (local) - entry: tox - language: python - types: [file, python] - args: ["-e", "pep8", "--"] - - repo: local - hooks: - - id: tox-pylint - name: tox-pylint (local) - entry: tox - language: python - types: [file, python] - args: ["-e", "pylint", "--"] + args: ["-e", "lint", "--"] diff --git a/README.md b/README.md index 1125bfe071b1db4b6c2b9f74ed47fbc062713139..45165a71895eb7da86a1914884439bae5b6338c5 100644 --- a/README.md +++ b/README.md @@ -4,36 +4,38 @@  <!--  --> -An example repository of an CI/CD pipeline for building, testing and publishing a python package. +Common library containing various stuff for LOFAR2. ## Installation -``` -pip install . -``` -## Setup +Wheel distributions are available from the [gitlab package registry](https://git.astron.nl/lofar2.0/lotus/-/packages/), +install using after downloading: -One time template setup should include configuring the docker registry to regularly cleanup old images of -the CI/CD pipelines. And you can consider creating protected version tags for software releases: +```shell +python -m pip install *.whl +``` -1. [Cleanup Docker Registry Images](https://git.astron.nl/groups/templates/-/wikis/Cleanup-Docker-Registry-Images) -2. [Setup Protected Verson Tags](https://git.astron.nl/groups/templates/-/wikis/Setting-up-Protected-Version-Tags) +Alternatively install latest version on master using: -Once the cleanup policy for docker registry is setup you can uncomment the `docker push` comment in the `.gitlab-ci.yml` -file from the `docker_build` job. This will allow to download minimal docker images with your Python package installed. +```shell +python -m pip install lofar-lotus@git+https://git.astron.nl/lofar2.0/lotus +``` -## Usage -```python -from lofar_lotus import cool_module +Or install directly from the source at any branch or commit: -cool_module.greeter() # prints "Hello World" +```shell +python -m pip install ./ ``` +## Usage + +For more thorough usage explanation please consult the documentation + ## Development ### Development environment -To setup and activte the develop environment run ```source ./setup.sh``` from within the source directory. +To set up and activate the develop environment run ```source ./setup.sh``` from within the source directory. If PyCharm is used, this only needs to be done once. Afterward the Python virtual env can be setup within PyCharm. @@ -46,9 +48,9 @@ should be assigned. Verify your changes locally and be sure to add tests. Verifying local changes is done through `tox`. -```pip install tox``` +```python -m pip install tox``` -With tox the same jobs as run on the CI/CD pipeline can be ran. These +With tox the same jobs as run on the CI/CD pipeline can be executed. These include unit tests and linting. ```tox``` diff --git a/lofar_lotus/__init__.py b/lofar_lotus/__init__.py index f9a997131221a86d5539c008e7f34961146c66a5..43770c97ed04afbeb3ed6c7537dc5a093a801ef0 100644 --- a/lofar_lotus/__init__.py +++ b/lofar_lotus/__init__.py @@ -1,7 +1,7 @@ # Copyright (C) 2023 ASTRON (Netherlands Institute for Radio Astronomy) # SPDX-License-Identifier: Apache-2.0 -""" LOFAR LOTUS """ +"""LOFAR LOTUS""" try: from importlib import metadata diff --git a/lofar_lotus/cool_module.py b/lofar_lotus/cool_module.py deleted file mode 100644 index de82c46c186cce3880ff8f56d87ba4406ee195ce..0000000000000000000000000000000000000000 --- a/lofar_lotus/cool_module.py +++ /dev/null @@ -1,9 +0,0 @@ -# Copyright (C) 2023 ASTRON (Netherlands Institute for Radio Astronomy) -# SPDX-License-Identifier: Apache-2.0 - -""" Cool module containing functions, classes and other useful things """ - - -def greeter(): - """Prints a nice message""" - print("Hello World!") diff --git a/lofar_lotus/dict/__init__.py b/lofar_lotus/dict/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..3d6443f090b109eddbfa3d39b448e6d37b1ba94c --- /dev/null +++ b/lofar_lotus/dict/__init__.py @@ -0,0 +1,9 @@ +# Copyright (C) 2025 ASTRON (Netherlands Institute for Radio Astronomy) +# SPDX-License-Identifier: Apache-2.0 + +"""Common classes used in station""" + +from ._case_insensitive_dict import CaseInsensitiveDict, ReversibleKeysView +from ._case_insensitive_string import CaseInsensitiveString + +__all__ = ["CaseInsensitiveDict", "CaseInsensitiveString", "ReversibleKeysView"] diff --git a/lofar_lotus/dict/_case_insensitive_dict.py b/lofar_lotus/dict/_case_insensitive_dict.py new file mode 100644 index 0000000000000000000000000000000000000000..a3f88cb31aa4f32c69b8abf1b19e19ad61f9d074 --- /dev/null +++ b/lofar_lotus/dict/_case_insensitive_dict.py @@ -0,0 +1,150 @@ +# Copyright (C) 2025 ASTRON (Netherlands Institute for Radio Astronomy) +# SPDX-License-Identifier: Apache-2.0 + +"""Provides a special dictionary with case-insensitive keys""" + +import abc +from collections import UserDict +from typing import List +from typing import Tuple +from typing import Union + +from ._case_insensitive_string import CaseInsensitiveString + + +def _case_insensitive_comprehend_keys(data: dict) -> List[CaseInsensitiveString]: + return [CaseInsensitiveString(key) for key in data] + + +def _case_insensitive_comprehend_items( + data: dict, +) -> List[Tuple[CaseInsensitiveString, any]]: + return [(CaseInsensitiveString(key), value) for key, value in data.items()] + + +class ReversibleIterator: + """Reversible iterator using instance of self method + + See real-python for yield iterator method: + https://realpython.com/python-reverse-list/#the-special-method-__reversed__ + """ + + def __init__(self, data: List, start: int, stop: int, step: int): + self.data = data + self.current = start + self.stop = stop + self.step = step + + def __iter__(self): + return self + + def __next__(self): + if self.current == self.stop: + raise StopIteration + + elem = self.data[self.current] + self.current += self.step + return elem + + def __reversed__(self): + return ReversibleIterator(self.data, self.stop, self.current, -1) + + +class AbstractReversibleView(abc.ABC): + """An abstract reversible view""" + + def __init__(self, data: UserDict): + self.data = data + self.len = len(data) + + def __repr__(self): + return f"{self.__class__.__name__}({self.data})" + + @abc.abstractmethod + def __iter__(self): + pass + + @abc.abstractmethod + def __reversed__(self): + pass + + +class ReversibleItemsView(AbstractReversibleView): + """Reversible view on items""" + + def __iter__(self): + return ReversibleIterator( + _case_insensitive_comprehend_items(self.data.data), 0, self.len, 1 + ) + + def __reversed__(self): + return ReversibleIterator( + _case_insensitive_comprehend_items(self.data.data), self.len - 1, -1, -1 + ) + + +class ReversibleKeysView(AbstractReversibleView): + """Reversible view on keys""" + + def __iter__(self): + return ReversibleIterator( + _case_insensitive_comprehend_keys(self.data.data), 0, self.len, 1 + ) + + def __reversed__(self): + return ReversibleIterator( + _case_insensitive_comprehend_keys(self.data.data), self.len - 1, -1, -1 + ) + + +class ReversibleValuesView(AbstractReversibleView): + """Reversible view on values""" + + def __iter__(self): + return ReversibleIterator(list(self.data.data.values()), 0, self.len, 1) + + def __reversed__(self): + return ReversibleIterator(list(self.data.data.values()), self.len - 1, -1, -1) + + +class CaseInsensitiveDict(UserDict): + """Special dictionary that ignores key casing if string + + While UserDict is the least performant / flexible it ensures __set_item__ and + __get_item__ are used in all code paths reducing LoC severely. + + Background reference: + https://realpython.com/inherit-python-dict/#creating-dictionary-like-classes-in-python + + Alternative (should this stop working at some point): + https://github.com/DeveloperRSquared/case-insensitive-dict/blob/main/case_insensitive_dict/case_insensitive_dict.py + """ + + def __setitem__(self, key, value): + if isinstance(key, str): + key = CaseInsensitiveString(key) + super().__setitem__(key, value) + + def __getitem__(self, key: Union[int, str]): + if isinstance(key, str): + key = CaseInsensitiveString(key) + return super().__getitem__(key) + + def __iter__(self): + return ReversibleIterator( + _case_insensitive_comprehend_keys(self.data), 0, len(self.data), 1 + ) + + def __contains__(self, key): + if isinstance(key, str): + key = CaseInsensitiveString(key) + return super().__contains__(key) + + def keys(self) -> ReversibleKeysView: + return ReversibleKeysView(self) + + def values(self) -> ReversibleValuesView: + return ReversibleValuesView(self) + + def items(self) -> ReversibleItemsView: + return ReversibleItemsView(self) diff --git a/lofar_lotus/dict/_case_insensitive_string.py b/lofar_lotus/dict/_case_insensitive_string.py new file mode 100644 index 0000000000000000000000000000000000000000..9471fdf56b55fea6cba5a80b0bee406d44e53cba --- /dev/null +++ b/lofar_lotus/dict/_case_insensitive_string.py @@ -0,0 +1,28 @@ +# Copyright (C) 2025 ASTRON (Netherlands Institute for Radio Astronomy) +# SPDX-License-Identifier: Apache-2.0 + +"""Special string that ignores casing in comparison""" + + +class CaseInsensitiveString(str): + """Special string that ignores casing in comparison""" + + def __eq__(self, other): + if isinstance(other, str): + return self.casefold() == other.casefold() + + return self.casefold() == other + + def __hash__(self): + return hash(self.__str__()) + + def __contains__(self, key): + if isinstance(key, str): + return key.casefold() in str(self) + return key in str(self) + + def __str__(self) -> str: + return self.casefold().__str__() + + def __repr__(self) -> str: + return self.casefold().__repr__() diff --git a/lofar_lotus/file_access/README.md b/lofar_lotus/file_access/README.md new file mode 100644 index 0000000000000000000000000000000000000000..ab0e1e6a316a112f26c32acb74b1a284999ecc02 --- /dev/null +++ b/lofar_lotus/file_access/README.md @@ -0,0 +1,121 @@ +# HDF file reader + +## Define a model + +The data structure of the HDF file is defined by python objects using decorators. Currently, there are two decorators +available: + +1. `member`: defines a class property to be an HDF group or dataset depending on the type. +2. `attribute`: defines a class property to be an HDF attribute on a group or dataset. + +### Dataset definition + +A basic data structure to define the HDF file looks like this: + +```python +class Data: + list_of_ints: List[int] = member() + list_of_floats: List[float] = member() + numpy_array: ndarray = member() +``` + +It is important to always use type hints. It not only makes the classes more self-explanatory during development it is +also +important for the file reader to guesstimate the right action to perform. + +In this first example we only used arrays and lists. These types always map to a dataset within HDF. By default, +the reader is looking for a dataset with the name of the variable, if the dataset is named differently it can be +overwritten +by specifying the `name` parameter: `member(name='other_name_then_variable')`. Also, all members are required by +default. +If they don't appear in the HDF file an error is thrown. This behavior can be changed by specifying the `optional` +parameter: +`member(optional=True)`. + +### Group definition + +HDF supports to arrange the data in groups. Groups can be defined as additional classes: + +```python +class SubGroup: + list_of_ints: List[int] = member() + +class Data: + sub_group: SubGroup = member() + +``` + +Additionally, all additional settings apply in the same way as they do for datasets. + +### Dictionaries + +A special case is the `dict`. It allows to read a set of groups or datasets using the name of the group or dataset as +the key. + +```python +class Data: + data_dict: Dict[str, List[int]] = member() +``` + +### Attribute definition + +Attributes in a HDF file can appear on groups as well as on datasets and can be defined by using `attribute()`: + +```python +class Data: + an_attr: str = attribute() +``` + +The file reader will look for an attribute with the name `an_attr` on the group that is represented by the class `Data`. +The name of the attribute can be overwritten by specifying the `name` parameter: `attribute(name='other_name')`. All +attributes +are required by default and will cause an exception to be thrown if they are not available. This behavior can be changed +by specifying the `optional` parameter: +`attribute(optional=True)`. + +In HDF also datasets can contain attributes. Since they are usually mapped to primitive types it would not be possible +to access +these attributes. Therefor `attribute` allows to specify another member in the class by setting `from_member`. + +## Read a HDF file + +A file can be read using `read_hdf5`: + +```python +with read_hdf5('file_name.h5', Data) as data: + a = data.an_attr +``` + +## Create a HDF file + +A file can be created using `create_hdf5` - existing files will be overwritten: + +```python +with create_hdf5('file_name.h5', Data) as data: + data.an_attr = "data" +``` + +NB: + +1. Writes are cached until `flush()` is called or the file is closed. +2. Reading back attributes will read them from disk. + +## Change a HDF file + +A file can be changed using `open_hdf5` - the file must exist: + +```python +with open_hdf5('file_name.h5', Data) as data: + data.an_attr = "new value" +``` + +## Data write behaviour + +### members +All changes to members of the object are immediately written to the underlying HDF file. Therefore, altering the object +should be minimized to have no performance degradation. + +### attributes +Attributes are written if `flush()` is invoked on the `FileWriter` or when the `with` scope is exited. This behaviour is +necessary because attributes depend on the underlying members. Therefore, the attributes can only be written after +the members. diff --git a/lofar_lotus/file_access/__init__.py b/lofar_lotus/file_access/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..d916df1d068fdf460e55a0914f9f1115afba420e --- /dev/null +++ b/lofar_lotus/file_access/__init__.py @@ -0,0 +1,24 @@ +# Copyright (C) 2022 ASTRON (Netherlands Institute for Radio Astronomy) +# SPDX-License-Identifier: Apache-2.0 + + +""" +Contains classes to interact with (hdf5) files +""" + +from ._attribute_def import attribute +from ._member_def import member +from ._readers import FileReader +from .hdf._hdf_readers import read_hdf5 +from .hdf._hdf_writers import open_hdf5, create_hdf5 +from ._writers import FileWriter + +__all__ = [ + "FileReader", + "FileWriter", + "attribute", + "member", + "read_hdf5", + "open_hdf5", + "create_hdf5", +] diff --git a/lofar_lotus/file_access/_attribute_def.py b/lofar_lotus/file_access/_attribute_def.py new file mode 100644 index 0000000000000000000000000000000000000000..ac15641b834b724a7aa0ebc44fbbefbfccdeddfc --- /dev/null +++ b/lofar_lotus/file_access/_attribute_def.py @@ -0,0 +1,76 @@ +# Copyright (C) 2022 ASTRON (Netherlands Institute for Radio Astronomy) +# SPDX-License-Identifier: Apache-2.0 + +""" +Contains HDF5 specific classes and methods to define class members as an HDF attribute +""" + +from typing import Any, Type + +from ._readers import DataReader +from ._utils import _extract_type +from ._writers import DataWriter + + +def attribute(name: str = None, optional: bool = False, from_member: str = None): + """ + Define a class member as an attribute within a HDF5 file + """ + return AttributeDef(name, optional, from_member) + + +# pylint: disable=too-few-public-methods +class AttributeDef: + """ + Decorator to extract attributes of HDF5 groups and datasets to pythonic objects + """ + + def __init__(self, name: str, optional: bool, from_member: str = None): + self.name = name + self.property_name: str + self.from_member = from_member + self.optional = optional + self.owner: Any + self.type: Type + + def __set_name__(self, owner, name): + if self.name is None: + self.name = name + self.property_name = name + self.owner = owner + self.type = _extract_type(owner, name) + + def __set__(self, instance, value): + setattr(instance, self.attr_name, value) + + if hasattr(instance, "_data_writer"): + writer: DataWriter = getattr(instance, "_data_writer") + writer.write_attribute( + instance, self.name, self.owner, self.from_member, self.optional, value + ) + + def __get__(self, instance, obj_type=None): + if instance is None: + # attribute is accessed as a class attribute + return self + + if hasattr(instance, self.attr_name): + return getattr(instance, self.attr_name) + + if hasattr(instance, "_data_reader"): + reader: DataReader = getattr(instance, "_data_reader") + attr = reader.read_attribute( + self.name, self.owner, self.from_member, self.optional + ) + setattr(instance, self.attr_name, attr) + return attr + return None + + @property + def attr_name(self): + """ + Name used to store the value in the owning object + """ + if self.from_member is None: + return f"_a_{self.name}" + return f"_a_{self.from_member}_{self.name}" diff --git a/lofar_lotus/file_access/_lazy_dict.py b/lofar_lotus/file_access/_lazy_dict.py new file mode 100644 index 0000000000000000000000000000000000000000..046ff0c9c7313b535baf58308370aaca80e718e0 --- /dev/null +++ b/lofar_lotus/file_access/_lazy_dict.py @@ -0,0 +1,72 @@ +# Copyright (C) 2023 ASTRON (Netherlands Institute for Radio Astronomy) +# SPDX-License-Identifier: Apache-2.0 + +""" +Provides a dictionary that dynamically resolves its values to reduce memory usage +""" + +from abc import abstractmethod +from typing import TypeVar, Dict, Type + +K = TypeVar("K") +V = TypeVar("V") + + +class LazyDict: + """ + Lazy evaluated dictionary + """ + + @abstractmethod + def setup_write(self, writer): + """ + Set up the lazy dict to support write actions + """ + + @classmethod + def __subclasshook__(cls, subclass): + return ( + hasattr(subclass, "setup_write") + and callable(subclass.setup_write) + or NotImplemented + ) + + +def lazy_dict(base_dict: Type[Dict[K, V]], reader): + """ + Dynamically derive lazy dict of given type + """ + + class LazyDictImpl(base_dict, LazyDict): + """ + Implementation of the lazy dict dynamically derived from base dict + """ + + def __init__(self, reader, *args, **kwargs): + super().__init__(*args, **kwargs) + self._reader = reader + self._writer = None + + def __setitem__(self, item, value): + if callable(value): + super().__setitem__(item, value) + return + + # write value somewhere + if self._writer is not None: + self._writer(item, value) + + super().__setitem__(item, lambda: self._reader(item)) + + def __getitem__(self, item): + return super().__getitem__(item)() + + def items(self): + """D.items() -> a set-like object providing a view on D's items""" + for key, value in super().items(): + yield key, value() + + def setup_write(self, writer): + self._writer = writer + + return LazyDictImpl(reader) diff --git a/lofar_lotus/file_access/_member_def.py b/lofar_lotus/file_access/_member_def.py new file mode 100644 index 0000000000000000000000000000000000000000..5d31f6d462611c026e1733372b28b410c5afecb6 --- /dev/null +++ b/lofar_lotus/file_access/_member_def.py @@ -0,0 +1,72 @@ +# Copyright (C) 2022 ASTRON (Netherlands Institute for Radio Astronomy) +# SPDX-License-Identifier: Apache-2.0 + +""" +Contains HDF5 specific classes and methods to define class members as members +of HDF5 files +""" + +from typing import Type + +from ._readers import DataReader +from ._utils import _extract_type +from ._writers import DataWriter + + +def member(name: str = None, optional: bool = False, compression: str = None): + """ + Define a class member as a member of a HDF5 file + """ + return MemberDef(name, optional, compression) + + +# pylint: disable=too-few-public-methods +class MemberDef: + """ + Decorator to handle the transformation of HDF5 groups + and datasets to pythonic objects + """ + + def __init__(self, name: str, optional: bool, compression: str): + self.name = name + self.property_name: str + self.optional = optional + self.compression = compression + self.type: Type + + def __set_name__(self, owner, name): + if self.name is None: + self.name = name + self.property_name = name + self.type = _extract_type(owner, name) + + def __get__(self, instance, obj_type=None): + if instance is None: + # attribute is accessed as a class attribute + return self + + if hasattr(instance, "_data_reader"): + reader: DataReader = getattr(instance, "_data_reader") + return reader.read_member(instance, self.name, self.type, self.optional) + + if hasattr(instance, self.attr_name): + return getattr(instance, self.attr_name) + return None + + def __set__(self, instance, value): + if not hasattr(instance, "_data_writer"): + setattr(instance, self.attr_name, value) + return + + writer: DataWriter = getattr(instance, "_data_writer") + writer.write_member(self.name, self.type, value) + + if hasattr(instance, self.attr_name): + delattr(instance, self.attr_name) + + @property + def attr_name(self): + """ + Name used to store the value in the owning object + """ + return f"_v_{self.name}" diff --git a/lofar_lotus/file_access/_monitoring.py b/lofar_lotus/file_access/_monitoring.py new file mode 100644 index 0000000000000000000000000000000000000000..8e1b1f9be845123a81b3ddfb1cde9fe7ec55b066 --- /dev/null +++ b/lofar_lotus/file_access/_monitoring.py @@ -0,0 +1,49 @@ +# Copyright (C) 2022 ASTRON (Netherlands Institute for Radio Astronomy) +# SPDX-License-Identifier: Apache-2.0 + +""" +Class wrappers for lists and dictionaries monitoring changes of itself and notifying +the registered event handler about these changes. +""" + +from typing import Any + + +class MonitoredWrapper: + """ + A wrapper monitoring changes of itself and notifying the registered event handler + about changes. + """ + + def __init__(self, event, instance): + self._event = event + self._instance = instance + + def __setitem__(self, key, value): + self._instance.__setitem__(key, value) + self._event(self._instance) + + def __getitem__(self, item): + return self._instance.__getitem__(item) + + def __setattr__(self, name: str, value: Any) -> None: + if name in ["_instance", "_event"]: + object.__setattr__(self, name, value) + else: + self._instance.__setattr__(name, value) + self._event(self._instance) + + def __getattribute__(self, name): + if name in ["_instance", "_event"]: + return object.__getattribute__(self, name) + attr = object.__getattribute__(self._instance, name) + if hasattr(attr, "__call__"): + + def wrapper(*args, **kwargs): + result = attr(*args, **kwargs) + self._event(self._instance) + return result + + return wrapper + + return attr diff --git a/lofar_lotus/file_access/_readers.py b/lofar_lotus/file_access/_readers.py new file mode 100644 index 0000000000000000000000000000000000000000..b3e0b2a879466567c682f27b9075911841197fd2 --- /dev/null +++ b/lofar_lotus/file_access/_readers.py @@ -0,0 +1,64 @@ +# Copyright (C) 2022 ASTRON (Netherlands Institute for Radio Astronomy) +# SPDX-License-Identifier: Apache-2.0 + +""" +Contains classes to handle reading +""" + +from abc import ABC, abstractmethod +from typing import TypeVar, Generic + +T = TypeVar("T") + + +class FileReader(Generic[T], ABC): + """ + Abstract file reader + """ + + @abstractmethod + def read(self) -> T: + """ + Read the opened file into a pythonic representation specified by target_type. + Will automatically figure out if target_type is a dict or a regular object + """ + + @abstractmethod + def close(self): + """ + Close the underlying file + """ + + def load(self, instance: T): + """ + Load all the data from the underlying HDF file + to preserve it in the objects after closing the + file. + """ + + def __enter__(self): + return self.read() + + def __exit__(self, exc_type, exc_val, exc_tb): + self.close() + + def __del__(self): + self.close() + + +class DataReader(ABC): + """ + Abstract data reader + """ + + @abstractmethod + def read_member(self, obj, name: str, target_type, optional: bool): + """ + Read given member from underlying file + """ + + @abstractmethod + def read_attribute(self, name, owner, from_member, optional): + """ + Read given attribute from underlying file + """ diff --git a/lofar_lotus/file_access/_utils.py b/lofar_lotus/file_access/_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..58b4617f5aee6dabbed903e4aa43c1cb552036e6 --- /dev/null +++ b/lofar_lotus/file_access/_utils.py @@ -0,0 +1,35 @@ +# Copyright (C) 2024 ASTRON (Netherlands Institute for Radio Astronomy) +# SPDX-License-Identifier: Apache-2.0 + +""" +General utils +""" + +from typing import Optional, Type, get_type_hints, get_args, get_origin + +from numpy import ndarray + +from ._monitoring import MonitoredWrapper + + +def _extract_type(owner: object, name: str) -> Optional[Type]: + type_hints = get_type_hints(owner) + return type_hints[name] if name in type_hints else None + + +def _extract_base_type(target_type: Type): + args = get_args(target_type) + if len(args) >= 2: + return args[1] + + return [ + get_args(b)[1] for b in target_type.__orig_bases__ if get_origin(b) is dict + ][0] + + +def _wrap(target_type, value, callback): + if get_origin(target_type) is list: + return MonitoredWrapper(callback, value) + if target_type is ndarray: + return MonitoredWrapper(callback, value) + return value diff --git a/lofar_lotus/file_access/_writers.py b/lofar_lotus/file_access/_writers.py new file mode 100644 index 0000000000000000000000000000000000000000..25529ba73aad76a2d97cdbb7acd1b9ae96a79dda --- /dev/null +++ b/lofar_lotus/file_access/_writers.py @@ -0,0 +1,58 @@ +# Copyright (C) 2023 ASTRON (Netherlands Institute for Radio Astronomy) +# SPDX-License-Identifier: Apache-2.0 + +""" +Contains classes to handle file writing +""" + +from abc import ABC, abstractmethod +from typing import TypeVar + +from ._readers import FileReader, DataReader + +T = TypeVar("T") + + +class FileWriter(FileReader[T], ABC): + """ + Abstract file writer + """ + + def __init__(self, create): + self._create = create + + @abstractmethod + def create(self) -> T: + """ + Create the object representing the file + """ + + @abstractmethod + def open(self) -> T: + """ + Create the object representing the file + """ + + def __enter__(self): + if self._create: + return self.create() + return self.open() + + +class DataWriter(DataReader, ABC): + """ + Abstract data writer + """ + + @abstractmethod + def write_member(self, name: str, target_type, value): + """ + Write given member to underlying file + """ + + @abstractmethod + # pylint: disable=too-many-arguments,too-many-positional-arguments + def write_attribute(self, instance, name, owner, from_member, optional, value): + """ + Write given attribute to underlying file + """ diff --git a/lofar_lotus/file_access/hdf/__init__.py b/lofar_lotus/file_access/hdf/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..c92b615444d854a6e87370b16cf733a5859a07e7 --- /dev/null +++ b/lofar_lotus/file_access/hdf/__init__.py @@ -0,0 +1,2 @@ +# Copyright (C) 2023 ASTRON (Netherlands Institute for Radio Astronomy) +# SPDX-License-Identifier: Apache-2.0 diff --git a/lofar_lotus/file_access/hdf/_hdf5_utils.py b/lofar_lotus/file_access/hdf/_hdf5_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..52eeedfa8d626b6be55bd40b8e950a3da748b0b3 --- /dev/null +++ b/lofar_lotus/file_access/hdf/_hdf5_utils.py @@ -0,0 +1,51 @@ +# Copyright (C) 2025 ASTRON (Netherlands Institute for Radio Astronomy) +# SPDX-License-Identifier: Apache-2.0 + +""" +Utils to handle transformation of HDF5 specific classes to pythonic objects +""" + +from collections.abc import MutableMapping +from inspect import get_annotations, getattr_static +from typing import Type, TypeVar, get_origin + +from numpy import ndarray + +T = TypeVar("T") + + +def _assert_is_dataset(value): + if issubclass(type(value), MutableMapping): + raise TypeError( + f"Only <Dataset> can be mappet do primitive type while " + f"value is of type <{type(value).__name__}>" + ) + + +def _assert_is_group(value): + if not issubclass(type(value), MutableMapping): + raise TypeError( + "Only Group can be mapped to <object> while value" + f" is of type <{type(value).__name__}>" + ) + + +def _is_attachable(target_type: Type[T]): + origin_type = get_origin(target_type) + if origin_type is dict: + return False + if get_origin(target_type) is list: + return False + if target_type is ndarray: + return False + return True + + +def _attach_object(target_type: Type[T], instance): + for cls in target_type.mro(): + annotations = get_annotations(cls) + + for annotation in annotations: + attr = getattr_static(target_type, annotation) + if hasattr(instance, attr.attr_name): + setattr(instance, attr.property_name, getattr(instance, attr.attr_name)) diff --git a/lofar_lotus/file_access/hdf/_hdf_readers.py b/lofar_lotus/file_access/hdf/_hdf_readers.py new file mode 100644 index 0000000000000000000000000000000000000000..cfc658d3e63166fdcc05be565b108dbf09bc5af5 --- /dev/null +++ b/lofar_lotus/file_access/hdf/_hdf_readers.py @@ -0,0 +1,203 @@ +# Copyright (C) 2024 ASTRON (Netherlands Institute for Radio Astronomy) +# SPDX-License-Identifier: Apache-2.0 + +""" +Contains classes to handle file reading +""" + +import inspect +import weakref +from inspect import getattr_static +from typing import TypeVar, Type, List, Dict, get_origin + +import h5py +from numpy import ndarray, zeros + +from ._hdf5_utils import ( + _assert_is_group, + _assert_is_dataset, +) +from .._attribute_def import AttributeDef +from .._lazy_dict import lazy_dict +from .._member_def import MemberDef +from .._readers import FileReader, DataReader +from .._utils import _extract_base_type + +T = TypeVar("T") + + +class HdfFileReader(FileReader[T]): + """ + HDF5 specific file reader + """ + + def __init__(self, name, target_type): + self.file_name = name + self._is_closed = None + self._target_type = target_type + self._open_file(name) + self._references: List[weakref] = [] + + def _open_file(self, name): + self._hdf5_file = h5py.File(name, "r") + self._is_closed = False + + def read(self) -> T: + """ + Read the opened file into a pythonic representation specified by target_type. + Will automatically figure out if target_type is a dict or a regular object + """ + reader = HdfDataReader.detect_reader( + self._target_type, HdfDataReader(self, self._hdf5_file) + ) + obj = reader(self._hdf5_file) + return obj + + def close(self): + """ + Close the underlying HDF file + """ + for ref in self._references: + obj = ref() + if obj is not None: + self._detach_object(obj) + self._references = [] + + if not self._is_closed: + self._is_closed = True + self._hdf5_file.close() + del self._hdf5_file + + def load(self, instance: T): + """ + Load all the data from the underlying HDF file + to preserve it in the objects after closing the + file. + """ + self._references.append(weakref.ref(instance)) + target_type = type(instance) + for annotation in [ + m[0] for m in inspect.getmembers(instance) if not m[0].startswith("_") + ]: + attr = inspect.getattr_static(target_type, annotation) + if isinstance(attr, (MemberDef, AttributeDef)): + setattr(instance, attr.attr_name, getattr(instance, attr.property_name)) + + def _detach_object(self, instance): + if not hasattr(instance, "_data_reader"): + return + delattr(instance, "_data_reader") + for attr in [ + m[0] + for m in inspect.getmembers(instance) + if not m[0].startswith("_") and m[0] != "T" + ]: + item = getattr(instance, attr) + item_type = type(item) + if ( + item is not None + and item is object + and not (item_type is ndarray or item_type is str) + ): + self._detach_object(item) + + +class HdfDataReader(DataReader): + """ + HDF data reader + """ + + def __init__(self, file_reader: HdfFileReader, data): + self.file_reader = file_reader + self.data = data + + def read_member(self, obj, name, target_type, optional): + if name not in self.data: + if optional: + return None + raise KeyError(f"Could not find required key {name}") + + reader = self.detect_reader( + target_type, self.__class__(self.file_reader, self.data[name]) + ) + return reader(self.data[name]) + + def read_attribute(self, name, owner, from_member, optional): + attrs: dict + if from_member is None: + attrs = self.data.attrs + else: + member = getattr_static(owner, from_member) + attrs = self.data[member.name].attrs + + if name not in attrs: + if optional: + return None + raise KeyError(f"Could not find required attribute key {name}") + + return attrs[name] + + @classmethod + def _read_object( + cls, target_type: Type[T], value, file_reader: "HdfDataReader" + ) -> T: + _assert_is_group(value) + obj = target_type() + setattr(obj, "_data_reader", cls(file_reader.file_reader, value)) + return obj + + @staticmethod + def _read_list(value): + _assert_is_dataset(value) + return list(value[:]) + + @classmethod + def _read_ndarray(cls, target_type: Type[T], value, file_reader: "HdfDataReader"): + _assert_is_dataset(value) + nd_value = zeros(value.shape, value.dtype) + # convert the data set to a numpy array + value.read_direct(nd_value) + if target_type is ndarray: + return nd_value + obj = nd_value.view(target_type) + setattr(obj, "_data_reader", cls(file_reader.file_reader, value)) + return obj + + @classmethod + def _read_dict( + cls, target_type: Type[T], value, dict_type, data_reader: "HdfDataReader" + ) -> Dict[str, T]: + reader = cls.detect_reader(target_type, data_reader) + result = lazy_dict(dict_type, lambda k: reader(value[k])) + for k in value.keys(): + result[k] = lambda n=k: reader(value[n]) + if dict_type is not dict: + setattr(result, "_data_reader", cls(data_reader.file_reader, value)) + return result + + @classmethod + def detect_reader(cls, target_type, data_reader: "HdfDataReader"): + """ + Detect the required reader based on expected type + """ + origin_type = get_origin(target_type) + if origin_type is dict: + return lambda value: cls._read_dict( + _extract_base_type(target_type), value, dict, data_reader + ) + if get_origin(target_type) is list: + return cls._read_list + if issubclass(target_type, ndarray): + return lambda value: cls._read_ndarray(target_type, value, data_reader) + if issubclass(target_type, dict): + return lambda value: cls._read_dict( + _extract_base_type(target_type), value, target_type, data_reader + ) + return lambda value: cls._read_object(target_type, value, data_reader) + + +def read_hdf5(name, target_type: Type[T]) -> FileReader[T]: + """ + Open a HDF5 file by name/path + """ + return HdfFileReader[T](name, target_type) diff --git a/lofar_lotus/file_access/hdf/_hdf_writers.py b/lofar_lotus/file_access/hdf/_hdf_writers.py new file mode 100644 index 0000000000000000000000000000000000000000..ab6ea6396a51d1a759b2687b56cea19718712ec2 --- /dev/null +++ b/lofar_lotus/file_access/hdf/_hdf_writers.py @@ -0,0 +1,295 @@ +# Copyright (C) 2024 ASTRON (Netherlands Institute for Radio Astronomy) +# SPDX-License-Identifier: Apache-2.0 + +""" +Contains classes to handle file writing +""" + +from inspect import getattr_static +from typing import TypeVar, Type, Dict, get_origin + +import h5py +from numpy import ndarray + +from ._hdf5_utils import ( + _is_attachable, + _attach_object, + _assert_is_group, + _assert_is_dataset, +) +from ._hdf_readers import HdfFileReader, HdfDataReader +from .._lazy_dict import LazyDict +from .._utils import _wrap, _extract_base_type +from .._writers import FileWriter, DataWriter + +T = TypeVar("T") + + +class HdfFileWriter(HdfFileReader[T], FileWriter[T]): + """ + HDF5 specific file writer + """ + + def __init__(self, name, target_type, create): + self._create = create + self.writers: list[HdfDataWriter] = [] + super().__init__(name, target_type) + + def _open_file(self, name): + self._hdf5_file = h5py.File(name, "w" if self._create else "a") + self._is_closed = False + + def flush(self): + """ + Flush all registered writers + """ + for writer in self.writers: + writer.flush() + self.writers = [] + + if not self._is_closed: + self._hdf5_file.flush() + + def close(self): + self.flush() + super().close() + + def open(self) -> T: + return self.create() + + def create(self) -> T: + """ + Create the object representing the HDF file + """ + data_writer = HdfDataWriter(self, self._hdf5_file) + reader = HdfDataWriter.detect_reader(self._target_type, data_writer) + obj = reader(self._hdf5_file) + if isinstance(obj, dict): + obj = _wrap( + self._target_type, + obj, + lambda value: HdfDataWriter.write_dict( + self._target_type, + self._hdf5_file, + value, + data_writer, + ), + ) + try: + setattr(obj, "_data_writer", data_writer) + except AttributeError: + pass + return obj + + +class HdfDataWriter(HdfDataReader, DataWriter): + """ + HDF data writer + """ + + def read_member(self, obj, name, target_type, optional): + instance = super().read_member(obj, name, target_type, optional) + + return _wrap( + target_type, + instance, + lambda a: setattr(obj, name, a), + ) + + @classmethod + def _read_dict( + cls, target_type: Type[T], value, dict_type, data_reader: "HdfDataWriter" + ) -> Dict[str, T]: + obj = super()._read_dict(target_type, value, dict_type, data_reader) + data_writer = cls(data_reader.file_writer, value) + if dict_type is not dict: + setattr(obj, "_data_writer", data_writer) + if isinstance(obj, LazyDict): + obj.setup_write( + lambda k, v: cls.write_dict_member( + target_type, value, k, v, data_writer + ) + ) + return obj + + @classmethod + def _read_object( + cls, target_type: Type[T], value, file_reader: "HdfDataWriter" + ) -> T: + obj = super()._read_object(target_type, value, file_reader) + setattr(obj, "_data_writer", cls(file_reader.file_writer, value)) + return obj + + def __init__(self, file_writer: HdfFileWriter, data): + self.file_writer = file_writer + self.file_writer.writers.append(self) + self.data = data + self.write_actions = [] + super().__init__(file_writer, data) + super(HdfDataReader, self).__init__() + + def write_member(self, name: str, target_type: Type[T], value): + data = self.data + writer = self.detect_writer(target_type, self) + writer(data, name, value) + + if _is_attachable(target_type): + _attach_object(target_type, value) + + def flush(self): + """ + Executed all pending write actions + """ + for action in self.write_actions: + action() + + # pylint: disable=too-many-arguments,too-many-positional-arguments + def write_attribute(self, instance, name, owner, from_member, optional, value): + self.write_actions.append( + lambda: self._write_attribute(name, owner, from_member, value) + ) + + def _write_attribute(self, name, owner, from_member, value): + attrs = self._resolve_attrs(owner, from_member) + + try: + attrs[name] = value + except (RuntimeError, TypeError) as exc: + raise ValueError( + f"Failed to write to attribute {self.data.name}.{name}" + ) from exc + + def _resolve_attrs(self, owner, from_member): + """ + Finds the right attribute to write into + """ + if from_member is None: + return self.data.attrs + + member = getattr_static(owner, from_member) + return self.data[member.name].attrs + + @classmethod + def detect_writer(cls, target_type, data_writer: "HdfDataWriter"): + """ + Detect required writer based on expected type + """ + origin_type = get_origin(target_type) + if origin_type is dict: + return lambda data, key, value: cls._write_dict_group( + target_type, data, key, value, data_writer + ) + if get_origin(target_type) is list: + return lambda data, key, value: cls._write_ndarray( + list, data, key, value, data_writer + ) + if target_type is ndarray or issubclass(target_type, ndarray): + return lambda data, key, value: cls._write_ndarray( + target_type, data, key, value, data_writer + ) + if issubclass(target_type, dict): + return lambda data, key, value: cls._write_dict_group( + target_type, data, key, value, data_writer + ) + return lambda data, key, value: cls._write_object( + target_type, data, key, value, data_writer + ) + + @classmethod + def _write_ndarray( + cls, target_type: Type[T], data, key, value, data_writer: "HdfDataWriter" + ): + _assert_is_group(data) + if key in data: + _assert_is_dataset(data[key]) + del data[key] + + # GZIP filter ("gzip"). Available with every installation of HDF5. + # compression_opts sets the compression level and may be an integer from 0 to 9, + # default is 4. + # https://docs.h5py.org/en/stable/high/dataset.html#lossless-compression-filters + data.create_dataset(key, data=value, compression="gzip", compression_opts=9) + if target_type is not ndarray and issubclass(target_type, ndarray): + data_writer = cls(data_writer.file_writer, data[key]) + setattr(value, "_data_writer", data_writer) + setattr(value, "_data_reader", data_writer) + _attach_object(target_type, value) + + @classmethod + # pylint: disable=too-many-arguments,too-many-positional-arguments + def _write_dict_group( + cls, target_type: Type[T], data, key, value, data_writer: "HdfDataWriter" + ): + _assert_is_group(data) + if key not in data: + data.create_group(key) + + try: + data_writer = cls(data_writer.file_writer, data[key]) + setattr(value, "_data_writer", data_writer) + setattr(value, "_data_reader", data_writer) + _attach_object(target_type, value) + except AttributeError: + pass + + cls.write_dict( + target_type, data[key], value, cls(data_writer.file_writer, data[key]) + ) + + @classmethod + def write_dict( + cls, target_type: Type[T], data, value, data_writer: "HdfDataWriter" + ): + """ + Write given dictionary to given data group + """ + _assert_is_group(data) + for k in data.keys(): + if k not in value: + del data[k] + writer = HdfDataWriter.detect_writer( + _extract_base_type(target_type), data_writer + ) + + for k in value.keys(): + writer(data, k, value[k]) + + @classmethod + def write_dict_member( + cls, target_type: Type[T], data, key, value, data_writer: "HdfDataWriter" + ): + """ + Write single given dictionary member to given data group + """ + _assert_is_group(data) + writer = HdfDataWriter.detect_writer(target_type, data_writer) + writer(data, key, value) + + @classmethod + # pylint: disable=too-many-arguments,too-many-positional-arguments + def _write_object( + cls, target_type: Type[T], data, key, value: T, data_writer: "HdfDataWriter" + ): + _assert_is_group(data) + if key in data: + _assert_is_group(data[key]) + else: + data.create_group(key) + data_writer = cls(data_writer.file_writer, data[key]) + setattr(value, "_data_writer", data_writer) + setattr(value, "_data_reader", data_writer) + _attach_object(target_type, value) + + +def open_hdf5(name, target_type: Type[T]) -> FileWriter[T]: + """ + Open a HDF5 file by name/path + """ + return HdfFileWriter[T](name, target_type, False) + + +def create_hdf5(name, target_type: Type[T]) -> FileWriter[T]: + """ + Create a HDF5 file by name/path + """ + return HdfFileWriter[T](name, target_type, True) diff --git a/lofar_lotus/zeromq/__init__.py b/lofar_lotus/zeromq/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..d52101386f5ec03e25dcfab9877fbbf4d593f917 --- /dev/null +++ b/lofar_lotus/zeromq/__init__.py @@ -0,0 +1,9 @@ +# Copyright (C) 2025 ASTRON (Netherlands Institute for Radio Astronomy) +# SPDX-License-Identifier: Apache-2.0 + +"""Classes to communicate with ZeroMQ""" + +from ._publisher import ZeroMQPublisher +from ._subscriber import ZeroMQSubscriber, AsyncZeroMQSubscriber + +__all__ = ["ZeroMQSubscriber", "AsyncZeroMQSubscriber", "ZeroMQPublisher"] diff --git a/lofar_lotus/zeromq/_pipe.py b/lofar_lotus/zeromq/_pipe.py new file mode 100644 index 0000000000000000000000000000000000000000..30fb40c719b0392fd14580f4d9ed2f29f2543b50 --- /dev/null +++ b/lofar_lotus/zeromq/_pipe.py @@ -0,0 +1,27 @@ +# Copyright (C) 2024 ASTRON (Netherlands Institute for Radio Astronomy) +# SPDX-License-Identifier: Apache-2.0 + +"""Construct a ZMQ socket pair forming a pipe.""" + +import binascii +import os +from typing import Tuple + +import zmq + + +def zpipe(ctx) -> Tuple[zmq.Socket, zmq.Socket]: + """build inproc pipe for talking to threads + + mimic pipe used in czmq zthread_fork. + + Returns a pair of PAIRs connected via inproc + """ + a = ctx.socket(zmq.PAIR) + b = ctx.socket(zmq.PAIR) + a.linger = b.linger = 0 + a.hwm = b.hwm = 1 + iface = f"inproc://{binascii.hexlify(os.urandom(8))}" + a.bind(iface) + b.connect(iface) + return a, b diff --git a/lofar_lotus/zeromq/_publisher.py b/lofar_lotus/zeromq/_publisher.py new file mode 100644 index 0000000000000000000000000000000000000000..e0623425bc9dec5baa96c3657c0d7ec1d3a89c8c --- /dev/null +++ b/lofar_lotus/zeromq/_publisher.py @@ -0,0 +1,177 @@ +# Copyright (C) 2025 ASTRON (Netherlands Institute for Radio Astronomy) +# SPDX-License-Identifier: Apache-2.0 + +"""Base class for ZMQ publishers""" + +import logging +import queue +from concurrent.futures import Future +from concurrent.futures import ThreadPoolExecutor +from datetime import datetime +from datetime import timezone +from typing import Callable, Optional + +import zmq +from zmq import Socket + +logger = logging.getLogger() + +__all__ = ["ZeroMQPublisher"] + + +class ZeroMQPublisher: # pylint: disable=too-many-instance-attributes + """Base class for ZMQ publishers""" + + def __init__( + self, + bind_uri: str, + topics: list[bytes | str], + queue_size: int = 100, + ): + """ + param bind_uri: uri to bind of pattern protocol://ip:port + param topics: List of topics to publish to, for bytearray use str.encode() + """ + # define variables early in case __del__ gets called after an + # exception in __init__ + self._thread = None + + self._queue = queue.Queue(maxsize=queue_size) + self._ctx = zmq.Context.instance() + self._publisher = self._ctx.socket(zmq.PUB) + + if isinstance(topics, list) and all(isinstance(y, str) for y in topics): + self._topics = [topic.encode() for topic in topics] + else: + self._topics = topics + + self._publisher.bind(bind_uri) + self._is_running = False + self._is_stopping = False + self._thread = ThreadPoolExecutor(max_workers=1) + self._future = self._thread.submit(self._run) + + def __del__(self): + self.shutdown() + + @staticmethod + def construct_bind_uri(protocol: str, bind: str, port: str | int) -> str: + """Combine parameters into a full bind uri for ZeroMQ""" + if isinstance(port, int): + port = str(port) + return f"{protocol}://{bind}:{port}" + + @property + def is_stopping(self): + """If the request has been made to stop the publisher + + Remains true even after fully stopping + """ + return self._is_stopping + + @property + def is_running(self): + """If the publisher has started""" + # don't use self._future.is_running, returns false if thread sleeps ;) + return self._is_running and not self.is_done + + @property + def is_done(self) -> bool: + """If the publisher has fully stopped""" + return self._future.done() + + @property + def topics(self) -> [bytes]: + """Returns the topics ZMQ is publishing to""" + return self._topics + + @property + def publisher(self) -> Socket: + """Returns ZMQ publisher socket""" + return self._publisher + + def get_result(self, timeout=None) -> object: + """Return the returned result of the publisher. + + If the publisher threw an exception, it will be raised here.""" + + return self._future.result(timeout=timeout) + + def get_exception(self, timeout=None) -> Optional[Exception]: + """Return the exception the exeption raised by the publisher, or None.""" + + return self._future.exception(timeout=timeout) + + def register_callback(self, fn: Callable[[Future], None]): + """Register a callback to run when the publisher finishes.""" + + self._future.add_done_callback(fn) + + @property + def queue_fill(self) -> int: + """Return the number of items in the queue.""" + + return self._queue.qsize() + + @property + def queue_size(self) -> int: + """Return the maximum number of items that fit in the queue.""" + + return self._queue.maxsize + + def _run(self): + """Run the publishing thread.""" + + self._is_running = True + logger.info("Publisher thread: %s starting", self) + while not self._is_stopping: + try: + msg = self._queue.get(timeout=1) + try: + now = datetime.now().astimezone(tz=timezone.utc).isoformat() + for topic in self._topics: + logger.debug( + "Publisher send message with payload of size: %s", len(msg) + ) + msg = [topic, now.encode("utf-8"), f"{msg}".encode("utf-8")] + self._publisher.send_multipart(msg) + finally: + self._queue.task_done() + except queue.Empty: + logger.debug("Queue is empty, nothing to publish") + continue + except zmq.ZMQError as e: + if e.errno != zmq.ETERM: + self._stop() + raise e + except KeyboardInterrupt as e: + self._stop() + raise e + self._stop() + + def _stop(self): + """Internal function to handle stopping""" + self._publisher.close() + logger.info("Terminated thread of %s", self) + self._is_running = False + + def shutdown(self): + """External function to request stopping / shutdown""" + logger.debug("Request to stop thread of %s", self) + self._is_stopping = True + + if self._thread: + self._thread.shutdown() + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + self.shutdown() + + def send(self, msg): + """ + param msg: The message to enqueue for transmission + raises queue.Full: If the message could not be enqueued + """ + self._queue.put_nowait(msg) diff --git a/lofar_lotus/zeromq/_subscriber.py b/lofar_lotus/zeromq/_subscriber.py new file mode 100644 index 0000000000000000000000000000000000000000..14038dc019283a908b97961ff2f0354ef7c85982 --- /dev/null +++ b/lofar_lotus/zeromq/_subscriber.py @@ -0,0 +1,211 @@ +# Copyright (C) 2025 ASTRON (Netherlands Institute for Radio Astronomy) +# SPDX-License-Identifier: Apache-2.0 + +"""Base class for ZMQ subscribers""" + +import asyncio +import logging +from concurrent.futures import CancelledError +from contextlib import suppress +from datetime import datetime +from threading import Thread +from typing import Any + +import zmq +import zmq.asyncio +from zmq.utils.monitor import recv_monitor_message + +logger = logging.getLogger() + +__all__ = ["ZeroMQSubscriber", "AsyncZeroMQSubscriber"] + + +class ZeroMQSubscriber: + """Base class for ZMQ subscribers. Usage: + + with ZeroMQSubscriber("tcp://host:port", ["topic"]) as subscriber: + (topic, timestamp, message) = subscriber.recv() + """ + + # pylint: disable=too-many-instance-attributes + + def __init__(self, connect_uri: str, topics: list[bytes | str]): + """ + + param connect_uri: uri of pattern protocol://fqdn:port + param topics: List of topics to subscribe to, must be bytearray use str.encode() + """ + self._ctx = self._new_zmq_context() + self._subscriber = self._ctx.socket(zmq.SUB) + self._thread = None + + self._connect_uri = connect_uri + self.nr_connects = 0 + self.nr_disconnects = 0 + self.is_connected = False + + if isinstance(topics, list) and all(isinstance(y, str) for y in topics): + self._topics = [topic.encode() for topic in topics] + else: + self._topics = topics + + # create monitoring socket to catch all events from the start + self.monitor = self._subscriber.get_monitor_socket() + + # subscribe + self._subscriber.connect(connect_uri) + for topic in self._topics: + self._subscriber.setsockopt(zmq.SUBSCRIBE, topic) + + @staticmethod + def _new_zmq_context(): + """Return a new ZMQ Context""" + return zmq.Context.instance() + + def __repr__(self): + return f"{self.__class__.__name__}({self._connect_uri}, {self._topics})" + + def _handle_event(self, evt: dict[str, Any]): + """Process a single monitor event.""" + + if evt["event"] == zmq.EVENT_HANDSHAKE_SUCCEEDED: + logger.info("ZeroMQ connected: %s", self) + self.nr_connects += 1 + self.is_connected = True + elif evt["event"] == zmq.EVENT_DISCONNECTED: + logger.warning("ZeroMQ disconnected: %s", self) + self.nr_disconnects += 1 + self.is_connected = False + + def _event_monitor_thread(self): + """Thread running the event monitor.""" + + logger.info("ZeroMQ event monitor started: %s", self) + + try: + while self.monitor.poll(): + evt = recv_monitor_message(self.monitor) + if evt["event"] == zmq.EVENT_MONITOR_STOPPED: + break + + self._handle_event(evt) + except Exception: + logger.exception("Error in ZeroMQ event monitor: %s", self) + raise + finally: + logger.info("ZeroMQ event monitor stopped: %s", self) + + @staticmethod + def _process_multipart( + multipart: tuple[bytes, bytes, bytes], + ) -> tuple[str, datetime, str]: + # parse the message according to the format we publish them with + topic, timestamp, msg = multipart + + # parse timestamp + timestamp = datetime.fromisoformat(timestamp.decode()) + + return topic.decode(), timestamp, msg.decode() + + def recv(self) -> tuple[str, datetime, str]: + """Receive a single message and decode it.""" + return self._process_multipart(self._subscriber.recv_multipart()) + + def close(self): + """Close I/O resources.""" + + self._subscriber.close() + self._ctx.destroy() + + self.is_connected = False + + @property + def topics(self): + """Returns the topics of the subscriber""" + return self._topics + + def __enter__(self): + self._thread = Thread(target=self._event_monitor_thread) + self._thread.start() + return self + + def __exit__(self, *args): + with suppress(zmq.ZMQError): + self._subscriber.disable_monitor() + + self._thread.join() + self.close() + + +class AsyncZeroMQSubscriber(ZeroMQSubscriber): + """Asynchronous version of ZeroMQSubscriber. Use `async_recv` instead of `recv` + to receive messages. Usage: + + with AsyncZeroMQSubscriber("tcp://host:port", ["topic"]) as subscriber: + (topic, timestamp, message) = await subscriber.async_recv() + """ + + def __init__( + self, + connect_uri: str, + topics: list[bytes | str], + event_loop=None, + ): + self._event_loop = event_loop or asyncio.get_event_loop() + self._task = None + super().__init__(connect_uri, topics) + + @staticmethod + def _new_zmq_context(): + return zmq.asyncio.Context() + + async def _event_monitor_task(self): + """Task running the event monitor.""" + + logger.info("ZeroMQ event monitor started: %s", self) + + try: + while await self.monitor.poll(): + evt = await recv_monitor_message(self.monitor) + if evt["event"] == zmq.EVENT_MONITOR_STOPPED: + break + + self._handle_event(evt) + except (zmq.error.ContextTerminated, CancelledError): + raise + except Exception: + logger.exception("Error in ZeroMQ event monitor: %s", self) + raise + finally: + logger.info("ZeroMQ event monitor stopped: %s", self) + + async def __aenter__(self): + self._task = self._event_loop.create_task(self._event_monitor_task()) + return self + + def __enter__(self): + raise NotImplementedError("Use async wait instead") + + async def __aexit__(self, *args): + # disable monitor + logger.debug("ZeroMQ teardown stopping monitor: %s", self) + with suppress(zmq.ZMQError): + self._subscriber.disable_monitor() + + # cancel task, do not wait for graceful exit + self._task.cancel() + with suppress(asyncio.CancelledError): + _ = await self._task + + # close sockets & context + logger.debug("ZeroMQ teardown closing socket: %s", self) + self.close() + logger.info("ZeroMQ teardown finished: %s", self) + + async def async_recv(self) -> tuple[str, datetime, str]: + """Receive a single message and decode it.""" + + return self._process_multipart(await self._subscriber.recv_multipart()) + + def recv(self): + raise NotImplementedError("Use async_recv() instead") diff --git a/pyproject.toml b/pyproject.toml index fbfa72dc1bc60ffa20b540d9dff4f52a8aa2d740..848fd574d6cb907d7fce46f53e2bb4b5dc8895ff 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -12,10 +12,25 @@ version_file = "lofar_lotus/_version.py" [tool.pylint] ignore = "_version.py" +[tool.ruff] +exclude = [ + ".venv", + ".git", + ".tox", + "dist", + "docs", + "*lib/python*", + "*egg", + "_version.py" +] + +[tool.ruff.lint] +ignore = ["E203"] + [tool.tox] # Generative environment list to test all supported Python versions requires = ["tox>=4.21"] -env_list = ["fix", "pep8", "black", "pylint", "py{13, 12, 11, 10, 9}"] +env_list = ["fix", "coverage", "lint", "format", "py{13, 12, 11, 10}"] [tool.tox.env_run_base] package = "editable" @@ -37,32 +52,21 @@ commands = [ ["python", "-m", "pytest", "--cov-report", "term", "--cov-report", "xml", "--cov-report", "html", "--cov=lofar_lotus"]] # Command prefixes to reuse the same virtualenv for all linting jobs. -[tool.tox.env.pep8] -deps = ["flake8"] -commands = [ - ["python", "-m", "flake8", "--version"], - ["python", "-m", "flake8", { replace = "posargs", default = ["lofar_lotus", "tests"], extend = true }] -] - -[tool.tox.env.black] -deps = ["black"] -commands = [ - ["python", "-m", "black", "--version"], - ["python", "-m", "black", "--check", "--diff", { replace = "posargs", default = ["lofar_lotus", "tests"], extend = true }] -] - -[tool.tox.env.pylint] -deps = ["pylint"] +[tool.tox.env.lint] +deps = [ + "ruff", + "-r{toxinidir}/tests/requirements.txt"] commands = [ - ["python", "-m", "pylint", "--version"], - ["python", "-m", "pylint", { replace = "posargs", default = ["lofar_lotus", "tests"], extend = true }] + ["python", "-m", "ruff", "--version"], + ["python", "-m", "ruff", "check", { replace = "posargs", default = ["lofar_lotus", "tests"], extend = true }] ] [tool.tox.env.format] -deps = ["autopep8", "black"] +deps = [ + "ruff", + "-r{toxinidir}/tests/requirements.txt"] commands = [ - ["python", "-m", "autopep8", "-v", "-aa", "--in-place", "--recursive", { replace = "posargs", default = ["lofar_lotus", "tests"], extend = true }], - ["python", "-m", "black", "-v", { replace = "posargs", default = ["lofar_lotus", "tests"], extend = true }] + ["python", "-m", "ruff", "format", "-v", { replace = "posargs", default = ["lofar_lotus", "tests"], extend = true }] ] [tool.tox.env.docs] diff --git a/requirements.txt b/requirements.txt index 24ce15ab7ead32f98c7ac3edcd34bb2010ff4326..79d7912bf7f3da0d9cb56ca3daf60066ecc3f7c9 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1 +1,3 @@ numpy +pyzmq>=24 # LGPL + BSD +h5py >= 3.1.0 # BSD diff --git a/setup.cfg b/setup.cfg index 73b314c96eaa616a5d529d55ee5e51086685c036..120fc615f1d0e8acc3a9dfeff8fed82788f5b001 100644 --- a/setup.cfg +++ b/setup.cfg @@ -1,6 +1,6 @@ [metadata] -name = lofar_lotus -description = An example package for CI/CD working group +name = lofar-lotus +description = Lots of things used somewhere long_description = file: README.md long_description_content_type = text/markdown url = https://git.astron.nl/templates/python-package @@ -15,7 +15,6 @@ classifiers = Programming Language :: Python Programming Language :: Python :: 3 Programming Language :: Python :: 3 :: Only - Programming Language :: Python :: 3.9 Programming Language :: Python :: 3.10 Programming Language :: Python :: 3.11 Programming Language :: Python :: 3.12 @@ -28,7 +27,7 @@ classifiers = [options] include_package_data = true packages = find: -python_requires = >=3.9 +python_requires = >=3.10 install_requires = file: requirements.txt [flake8] diff --git a/setup.py b/setup.py index 10fdaec810e96f0f1cbedb4a5ddf532c03f50dc4..252001700971fd796c106c47ce0ed62c2f7e8ee6 100644 --- a/setup.py +++ b/setup.py @@ -1,7 +1,7 @@ -# Copyright (C) 2023 ASTRON (Netherlands Institute for Radio Astronomy) +# Copyright (C) 2025 ASTRON (Netherlands Institute for Radio Astronomy) # SPDX-License-Identifier: Apache-2.0 -""" Setuptools entry point """ +"""Setuptools entry point""" import setuptools setuptools.setup() diff --git a/tests/dict/__init__.py b/tests/dict/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..7ddb7c536b22a31da6fa663fa981e47c92d4f030 --- /dev/null +++ b/tests/dict/__init__.py @@ -0,0 +1,2 @@ +# Copyright (C) 2025 ASTRON (Netherlands Institute for Radio Astronomy) +# SPDX-License-Identifier: Apache-2.0 diff --git a/tests/dict/test_case_insensitive_dict.py b/tests/dict/test_case_insensitive_dict.py new file mode 100644 index 0000000000000000000000000000000000000000..14003dc32d258f38a17f040f7e5fb3603e44847d --- /dev/null +++ b/tests/dict/test_case_insensitive_dict.py @@ -0,0 +1,172 @@ +# Copyright (C) 2025 ASTRON (Netherlands Institute for Radio Astronomy) +# SPDX-License-Identifier: Apache-2.0 + +"""CaseInsensitiveDict test classes""" + +from enum import Enum +from unittest import TestCase + +from lofar_lotus.dict import ( + CaseInsensitiveDict, + CaseInsensitiveString, + ReversibleKeysView, +) + + +class TestCaseInsensitiveDict(TestCase): + def test_set_get_item(self): + """Get and set an item with different casing""" + + t_value = "VALUE" + t_key = "KEY" + t_dict = CaseInsensitiveDict() + + t_dict[t_key] = t_value + + self.assertEqual(t_value, t_dict[t_key.lower()]) + self.assertEqual(t_value, t_dict.get(t_key.lower())) + + def test_set_overwrite(self): + """Overwrite a previous element with different casing""" + + t_value = "VALUE" + t_key = "KEY" + t_dict = CaseInsensitiveDict() + + t_dict[t_key] = t_value + t_dict[t_key.lower()] = t_value.lower() + + self.assertEqual(t_value.lower(), t_dict[t_key]) + self.assertEqual(t_value.lower(), t_dict.get(t_key)) + + class ConstructTestEnum(Enum): + """Test enum class""" + + DICT = "dict" + ITER = "iter" + KWARGS = "kwargs" + + def construct_base(self, test_type: ConstructTestEnum): + """Reusable test method""" + t_key1 = "KEY1" + t_key2 = "key2" + t_value1 = 123 + t_value2 = 456 + t_mapping = {t_key1: t_value1, t_key2: t_value2} + + t_dict = CaseInsensitiveDict() + if test_type is self.ConstructTestEnum.DICT: + t_dict = CaseInsensitiveDict(t_mapping) + elif test_type is self.ConstructTestEnum.ITER: + t_dict = CaseInsensitiveDict(t_mapping.items()) + elif test_type is self.ConstructTestEnum.KWARGS: + t_dict = CaseInsensitiveDict(KEY1=t_value1, key2=t_value2) + + self.assertEqual(t_value1, t_dict[t_key1.lower()]) + self.assertEqual(t_value2, t_dict.get(t_key2.upper())) + + def test_construct_mapping(self): + self.construct_base(self.ConstructTestEnum.DICT) + + def test_construct_iterable(self): + self.construct_base(self.ConstructTestEnum.ITER) + + def test_construct_kwargs(self): + self.construct_base(self.ConstructTestEnum.KWARGS) + + def test_setdefault(self): + t_key = "KEY" + t_value = "value" + t_dict = CaseInsensitiveDict() + + t_dict.setdefault(t_key, t_value) + + self.assertIn(t_key.lower(), t_dict.keys()) + + for key in t_dict.keys(): + self.assertEqual(t_key, key) + self.assertIsInstance(key, CaseInsensitiveString) + + def test_keys(self): + t_key = "KEY" + t_value = "value" + t_dict = CaseInsensitiveDict() + + t_dict[t_key] = t_value + + self.assertIn(t_key.lower(), t_dict.keys()) + self.assertIsInstance(t_dict.keys(), ReversibleKeysView) + + for key in t_dict.keys(): + self.assertEqual(t_key, key) + self.assertIsInstance(key, CaseInsensitiveString) + + def test_items(self): + t_key = "KEY" + t_value = "VALUE" + t_dict = CaseInsensitiveDict() + + t_dict[t_key] = t_value + + for key, value in t_dict.items(): + self.assertEqual(t_key, key) + self.assertIsInstance(key, CaseInsensitiveString) + self.assertNotEqual(t_value.casefold(), value) + + def test_values(self): + t_key = "KEY" + t_value = "VALUE" + t_dict = CaseInsensitiveDict() + + t_dict[t_key] = t_value + + for value in t_dict.values(): + self.assertEqual(t_value, value) + self.assertIsInstance(value, str) + + def test_in(self): + t_key = "KEY" + t_value = "VALUE" + t_dict = CaseInsensitiveDict() + + t_dict[t_key] = t_value + + self.assertIn(t_key.lower(), t_dict) + + def test_reverse(self): + t_key1 = "KEY1" + t_key2 = "KEY2" + t_value = "VALUE" + t_dict = CaseInsensitiveDict() + + t_dict[t_key1] = t_value + t_dict[t_key2] = t_value + + forward = [] + for key, _ in t_dict.items(): + forward.append(key) + + backward = [] + for key, _ in reversed(t_dict.items()): + backward.append(key) + + self.assertEqual(forward[0], backward[1]) + self.assertEqual(forward[1], backward[0]) + + backward = [] + for key in reversed(t_dict.keys()): + backward.append(key) + + self.assertEqual(forward[0], backward[1]) + self.assertEqual(forward[1], backward[0]) + + forward = [] + for item in t_dict.values(): + forward.append(item) + + backward = [] + for value in reversed(t_dict.values()): + backward.append(value) + + self.assertEqual(forward[0], backward[1]) + self.assertEqual(forward[1], backward[0]) diff --git a/tests/dict/test_case_insensitive_string.py b/tests/dict/test_case_insensitive_string.py new file mode 100644 index 0000000000000000000000000000000000000000..b4237540859e269c677b4fb600247fb66a93d550 --- /dev/null +++ b/tests/dict/test_case_insensitive_string.py @@ -0,0 +1,25 @@ +# Copyright (C) 2025 ASTRON (Netherlands Institute for Radio Astronomy) +# SPDX-License-Identifier: Apache-2.0 + +"""CaseInsensitiveString test classes""" + +from unittest import TestCase + +from lofar_lotus.dict import CaseInsensitiveString + + +class TestCaseInsensitiveString(TestCase): + def test_a_in_b(self): + """Get and set an item with different casing""" + + self.assertIn(CaseInsensitiveString("hba0"), CaseInsensitiveString("HBA0")) + + def test_b_in_a(self): + """Get and set an item with different casing""" + + self.assertIn(CaseInsensitiveString("HBA0"), CaseInsensitiveString("hba0")) + + def test_a_not_in_b(self): + """Get and set an item with different casing""" + + self.assertNotIn(CaseInsensitiveString("hba0"), CaseInsensitiveString("LBA0")) diff --git a/tests/file_access/SST_2022-11-15-14-21-39.h5 b/tests/file_access/SST_2022-11-15-14-21-39.h5 new file mode 100644 index 0000000000000000000000000000000000000000..c2ba81b674b6c13d2709cfece7a9fe55f1e34628 Binary files /dev/null and b/tests/file_access/SST_2022-11-15-14-21-39.h5 differ diff --git a/tests/file_access/__init__.py b/tests/file_access/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..479ef86750d1ff6ea6648d1722a270bb8903075d --- /dev/null +++ b/tests/file_access/__init__.py @@ -0,0 +1,2 @@ +# Copyright (C) 2022 ASTRON (Netherlands Institute for Radio Astronomy) +# SPDX-License-Identifier: Apache-2.0 diff --git a/tests/file_access/cal-test-dict.h5 b/tests/file_access/cal-test-dict.h5 new file mode 100644 index 0000000000000000000000000000000000000000..d5eb7cbc95d39fc68c0e81bbe5a1172d97cf84ad Binary files /dev/null and b/tests/file_access/cal-test-dict.h5 differ diff --git a/tests/file_access/cal-test.h5 b/tests/file_access/cal-test.h5 new file mode 100644 index 0000000000000000000000000000000000000000..a28f67eadd33bcec1fb795e4765f825fff3ffe73 Binary files /dev/null and b/tests/file_access/cal-test.h5 differ diff --git a/tests/file_access/test_file_reader.py b/tests/file_access/test_file_reader.py new file mode 100644 index 0000000000000000000000000000000000000000..02fde0adbfcb15fa90f8ed71f2671bb76246a675 --- /dev/null +++ b/tests/file_access/test_file_reader.py @@ -0,0 +1,214 @@ +# Copyright (C) 2025 ASTRON (Netherlands Institute for Radio Astronomy) +# SPDX-License-Identifier: Apache-2.0 + +# pylint: disable=too-few-public-methods +"""File writer tests""" + +from os.path import dirname +from typing import List, Dict +from unittest import TestCase + +from numpy import ndarray + +from lofar_lotus.file_access import member, read_hdf5, attribute + + +class DataSubSet: + """Class to test sub sets""" + + values: List[int] = member() + + +class DataSet: + """Class to test data sets""" + + nof_payload_errors: List[int] = member() + nof_valid_payloads: List[int] = member() + values: List[List[float]] = member() + non_existent: DataSubSet = member(optional=True) + + def __repr__(self): + return f"DataSet(nof_payload_errors={self.nof_payload_errors})" + + +class DataSet2(DataSet): + """Class to test derived data sets""" + + sub_set: DataSubSet = member(name="test") + + +class SimpleDataSet: + """Class to test simple data sets""" + + observation_station: str = attribute() + observation_station_optional: str = attribute(optional=True) + test_attr: str = attribute(from_member="calibration_data", name="test_attribute") + calibration_data: ndarray = member(name="data") + + +class AttrDataSet(SimpleDataSet): + """Class to test attributes""" + + observation_station_missing_none_optional: str = attribute(optional=False) + + +class CalData: + """Class to test attributes""" + + x_attr: str = attribute("test_attr", from_member="x") + y_attr: str = attribute("test_attr", from_member="y") + x: ndarray = member() + y: ndarray = member() + + +class CalTable(Dict[str, CalData]): + """Class to test dictionaries""" + + observation_station: str = attribute() + + +class CalTableDict(Dict[str, Dict[str, ndarray]]): + """Class to test multidimensional dictionaries""" + + +class TestHdf5FileReader(TestCase): + def test_file_reading(self): + with read_hdf5( + dirname(__file__) + "/SST_2022-11-15-14-21-39.h5", Dict[str, DataSet2] + ) as ds: + self.assertEqual(21, len(ds.keys())) + item = ds["SST_2022-11-15T14:21:59.000+00:00"] + self.assertEqual( + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + item.nof_payload_errors, + ) + # double read to check if (cached) value is the same + self.assertEqual( + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + item.nof_payload_errors, + ) + self.assertEqual( + [12, 12, 12, 12, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + item.nof_valid_payloads, + ) + self.assertIsNone(item.non_existent) + self.assertEqual(192, len(item.values)) + + with self.assertRaises(KeyError): + _ = ( + item.sub_set + # this item does not exist but is not marked as optional + ) + + item = ds["SST_2022-11-15T14:21:39.000+00:00"] + self.assertEqual(100, len(item.sub_set.values)) + + def test_read_attribute(self): + with read_hdf5(dirname(__file__) + "/cal-test.h5", AttrDataSet) as ds: + self.assertEqual("test-station", ds.observation_station) + self.assertEqual("dset_attr", ds.test_attr) + self.assertIsNone(ds.observation_station_optional) + self.assertEqual("dset_attr", ds.test_attr) # test caching + with self.assertRaises(KeyError): + _ = ( + ds.observation_station_missing_none_optional + # this attribute does not exist but is not marked as optional + ) + + def test_load_object(self): + hdf5_file = read_hdf5(dirname(__file__) + "/cal-test.h5", SimpleDataSet) + ds = hdf5_file.read() + hdf5_file.load(ds) + hdf5_file.close() + self.assertEqual("test-station", ds.observation_station) + self.assertIsNone(ds.observation_station_optional) + self.assertEqual("dset_attr", ds.test_attr) + d = ds.calibration_data + self.assertIsInstance(d, ndarray) + self.assertEqual(512, d.shape[0]) + self.assertEqual(96, d.shape[1]) + + def test_load_complex(self): + hdf5_file = read_hdf5( + dirname(__file__) + "/SST_2022-11-15-14-21-39.h5", Dict[str, DataSet] + ) + test = [] + with hdf5_file as ds: + for _, data in ds.items(): + hdf5_file.load(data) + test.append(data) + + def test_read_ndarray(self): + with read_hdf5(dirname(__file__) + "/cal-test.h5", AttrDataSet) as ds: + d = ds.calibration_data + self.assertIsInstance(d, ndarray) + self.assertEqual(512, d.shape[0]) + self.assertEqual(96, d.shape[1]) + + def test_read_derived_dict(self): + with read_hdf5(dirname(__file__) + "/cal-test-dict.h5", CalTable) as ds: + self.assertEqual(5, len(ds)) + self.assertEqual("test-station", ds.observation_station) + ant_2 = ds["ant_2"] + self.assertEqual(512, len(ant_2.x)) + self.assertEqual(512, len(ant_2.y)) + self.assertEqual("ant_2_x", ant_2.x_attr) + self.assertEqual("ant_2_y", ant_2.y_attr) + + def test_read_derived_double_dict(self): + with read_hdf5(dirname(__file__) + "/cal-test-dict.h5", CalTableDict) as ds: + self.assertEqual(5, len(ds)) + ant_2 = ds["ant_2"] + self.assertIn("x", ant_2) + self.assertIn("y", ant_2) + self.assertEqual(512, len(ant_2["x"])) + self.assertEqual(512, len(ant_2["y"])) + + def test_read_as_object(self): + class ObjectDataSet: + """Class to test object data sets""" + + item_1: Dict[str, List[int]] = member( + name="SST_2022-11-15T14:21:59.000+00:00" + ) + item_2: Dict[str, List[int]] = member( + name="SST_2022-11-15T14:21:39.000+00:00" + ) + + with read_hdf5( + dirname(__file__) + "/SST_2022-11-15-14-21-39.h5", ObjectDataSet + ) as ds: + self.assertEqual( + ["nof_payload_errors", "nof_valid_payloads", "values"], + list(ds.item_1.keys()), + ) + with self.assertRaises(TypeError): + _ = ( + ds.item_2["test"] + # item test is of type group and will raise an error + ) + + def test_malformed_data(self): + class BrokenDataSet: + """Class to test broken data sets""" + + nof_payload_errors: DataSubSet = member() + nof_valid_payloads: int = member() + sub_set: List[int] = member(name="test") + + with read_hdf5( + dirname(__file__) + "/SST_2022-11-15-14-21-39.h5", Dict[str, BrokenDataSet] + ) as ds: + item = ds["SST_2022-11-15T14:21:39.000+00:00"] + with self.assertRaises(TypeError): + _ = item.nof_payload_errors + with self.assertRaises(TypeError): + _ = item.nof_valid_payloads + with self.assertRaises(TypeError): + _ = item.sub_set + + def test_reader_close(self): + file_reader = read_hdf5( + dirname(__file__) + "/SST_2022-11-15-14-21-39.h5", Dict[str, DataSet] + ) + file_reader.close() diff --git a/tests/file_access/test_file_writer.py b/tests/file_access/test_file_writer.py new file mode 100644 index 0000000000000000000000000000000000000000..528c13e5f5e6c48dc31f897a3d95ca32ce1a271f --- /dev/null +++ b/tests/file_access/test_file_writer.py @@ -0,0 +1,250 @@ +# Copyright (C) 2025 ASTRON (Netherlands Institute for Radio Astronomy) +# SPDX-License-Identifier: Apache-2.0 + +# pylint: disable=too-few-public-methods +"""File writer tests""" + +from tempfile import TemporaryDirectory +from typing import List, Dict +from unittest import TestCase + +import numpy +from numpy import ndarray, array + +from lofar_lotus.file_access import ( + member, + attribute, + create_hdf5, + read_hdf5, + open_hdf5, +) + + +class SimpleSet: + """Class to test sets""" + + values: ndarray = member() + + +class DataSubSet: + """Class to test sub sets""" + + values: List[int] = member() + dict_test_ndarray: Dict[str, ndarray] = member() + dict_test_object: Dict[str, SimpleSet] = member() + + +class DataSet: + """Class to test data sets""" + + observation_station: str = attribute() + observation_source: str = attribute(from_member="sub_set") + nof_payload_errors: List[int] = member() + values: List[List[float]] = member() + sub_set: DataSubSet = member(name="test") + non_existent: DataSubSet = member(optional=True) + + +class SubArray(ndarray): + """Class to test derived ndarrays""" + + observation_station: str = attribute() + + +class SubArrayDerived(SubArray): + """Class to test derived derived ndarrays""" + + observation_source: str = attribute() + + +class SubDict(Dict[str, ndarray]): + """Class to test deribed dictionaries""" + + station_name: str = attribute() + station_version: str = attribute() + + +class TestHdf5FileWriter(TestCase): + def test_simple_writing(self): + with TemporaryDirectory() as tmpdir: + file_name = tmpdir + "/test_simple_writing.h5" + + with create_hdf5(file_name, DataSet) as ds: + ds.observation_station = "CS001" + ds.nof_payload_errors = [1, 2, 3, 4, 5, 6] + ds.values = [[2.0], [3.0], [4.0]] + ds.sub_set = DataSubSet() + ds.sub_set.values = [5, 4, 3, 2] + ds.observation_source = "CasA" + + with read_hdf5(file_name, DataSet) as ds: + self.assertEqual("CS001", ds.observation_station) + self.assertEqual([1, 2, 3, 4, 5, 6], ds.nof_payload_errors) + self.assertEqual([[2.0], [3.0], [4.0]], ds.values) + self.assertIsNotNone(ds.sub_set) + self.assertEqual([5, 4, 3, 2], ds.sub_set.values) + self.assertEqual("CasA", ds.observation_source) + + def test_list_writing(self): + with TemporaryDirectory() as tmpdir: + file_name = tmpdir + "/test_list_writing.h5" + + with create_hdf5(file_name, DataSubSet) as dss: + dss.values = [2, 3, 4, 5] + dss.values.append(1) + + with read_hdf5(file_name, DataSubSet) as dss: + self.assertEqual([2, 3, 4, 5, 1], dss.values) + + def test_dict_writing(self): + with TemporaryDirectory() as tmpdir: + file_name = tmpdir + "/test_dict_writing.h5" + + with create_hdf5(file_name, Dict[str, ndarray]) as d: + d["test_1"] = array([1, 2, 3, 4, 5, 6]) + d["test_2"] = array([6, 5, 4, 1]) + + with read_hdf5(file_name, Dict[str, ndarray]) as d: + self.assertFalse((array([1, 2, 3, 4, 5, 6]) - d["test_1"]).any()) + self.assertFalse((array([6, 5, 4, 1]) - d["test_2"]).any()) + + def test_derived_dict_writing(self): + with TemporaryDirectory() as tmpdir: + file_name = tmpdir + "/test_derived_dict_writing.h5" + + with create_hdf5(file_name, SubDict) as d: + d.station_name = "st1" + d.station_version = "999" + d["test_1"] = array([1, 2, 3, 4, 5, 6]) + d["test_2"] = array([6, 5, 4, 1]) + + with read_hdf5(file_name, SubDict) as d: + self.assertEqual("st1", d.station_name) + self.assertEqual("999", d.station_version) + self.assertFalse((array([1, 2, 3, 4, 5, 6]) - d["test_1"]).any()) + self.assertFalse((array([6, 5, 4, 1]) - d["test_2"]).any()) + + def test_derived_ndarray_writing(self): + with TemporaryDirectory() as tmpdir: + file_name = tmpdir + "/test_derived_ndarray_writing.h5" + + with create_hdf5(file_name, Dict[str, SubArray]) as d: + sa = numpy.zeros((8,), dtype=numpy.float64).view(SubArray) + sa.observation_station = "test1" + d["test_1"] = sa + sa2 = numpy.zeros((10,), dtype=numpy.float64).view(SubArray) + sa2.observation_station = "test2" + d["test_2"] = sa2 + + with read_hdf5(file_name, Dict[str, SubArray]) as dss: + self.assertIn("test_1", dss) + self.assertIn("test_2", dss) + self.assertEqual("test1", dss["test_1"].observation_station) + self.assertEqual("test2", dss["test_2"].observation_station) + + def test_doubly_derived_ndarray_writing(self): + with TemporaryDirectory() as tmpdir: + file_name = tmpdir + "/test_doubly_derived_ndarray_writing.h5" + + with create_hdf5(file_name, Dict[str, SubArrayDerived]) as d: + sa = numpy.zeros((8,), dtype=numpy.float64).view(SubArrayDerived) + sa.observation_station = "test1" + sa.observation_source = "source1" + d["test_1"] = sa + + with read_hdf5(file_name, Dict[str, SubArrayDerived]) as dss: + self.assertIn("test_1", dss) + self.assertEqual("test1", dss["test_1"].observation_station) + self.assertEqual("source1", dss["test_1"].observation_source) + + def test_dict_altering(self): + with TemporaryDirectory() as tmpdir: + file_name = tmpdir + "/test_dict_altering.h5" + + with create_hdf5(file_name, DataSubSet) as dss: + dss.dict_test_ndarray = { + "test_1": array([2, 4, 6]), + "test_2": array([1, 3, 5]), + } + dss.dict_test_ndarray["test_3"] = array([9, 8, 7]) + dss.dict_test_ndarray.pop("test_1") + ss = SimpleSet() + ss.values = array([4, 9, 3]) + dss.dict_test_object = {"test_99": ss} + dss.dict_test_object["test_99"].values[0] = 5 + dss.dict_test_object["test_98"] = SimpleSet() + dss.dict_test_object["test_98"].values = array([4, 9, 3]) + + with read_hdf5(file_name, DataSubSet) as dss: + self.assertIn("test_2", dss.dict_test_ndarray) + self.assertIn("test_3", dss.dict_test_ndarray) + self.assertFalse(([1, 3, 5] - dss.dict_test_ndarray["test_2"]).any()) + self.assertFalse(([9, 8, 7] - dss.dict_test_ndarray["test_3"]).any()) + self.assertIn("test_99", dss.dict_test_object) + self.assertIn("test_98", dss.dict_test_object) + self.assertFalse( + ([5, 9, 3] - dss.dict_test_object["test_99"].values).any() + ) + self.assertFalse( + ([4, 9, 3] - dss.dict_test_object["test_98"].values).any() + ) + + def test_object_access(self): + ds = DataSet() + ds.observation_station = "CS001" + ds.nof_payload_errors = [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0] + ds.values = [[1.0]] + ds.sub_set = DataSubSet() + ds.sub_set.values = [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0] + ds.observation_source = "CasA" + + self.assertEqual("CS001", ds.observation_station) + self.assertEqual( + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], ds.nof_payload_errors + ) + self.assertEqual([[1.0]], ds.values) + self.assertIsNotNone(ds.sub_set) + self.assertEqual( + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], ds.sub_set.values + ) + self.assertEqual("CasA", ds.observation_source) + + def test_attach_object(self): + with TemporaryDirectory() as tmpdir: + file_name = tmpdir + "/test_attach_object.h5" + + with create_hdf5(file_name, DataSet) as ds: + sub_set = DataSubSet() + sub_set.values = [7, 4, 9, 2, 9] + ds.sub_set = sub_set + ds.observation_source = "CasA" + + with read_hdf5(file_name, DataSet) as ds: + self.assertEqual([7, 4, 9, 2, 9], ds.sub_set.values) + self.assertEqual("CasA", ds.observation_source) + + def test_open_write(self): + with TemporaryDirectory() as tmpdir: + file_name = tmpdir + "/test_open_write.h5" + + with create_hdf5(file_name, DataSet) as ds: + ds.observation_station = "CS001" + ds.nof_payload_errors = [1, 2, 3, 4, 5, 6] + ds.values = [[2.0], [3.0], [4.0]] + ds.sub_set = DataSubSet() + ds.sub_set.values = [5, 4, 3, 2] + ds.observation_source = "CasA" + + with open_hdf5(file_name, DataSet) as ds: + ds.nof_payload_errors.append(7) + ds.values.append([5.0]) + ds.observation_source = "ACAS" + ds.sub_set.values = [1, 2, 3] + + with read_hdf5(file_name, DataSet) as ds: + self.assertEqual("CS001", ds.observation_station) + self.assertEqual([1, 2, 3, 4, 5, 6, 7], ds.nof_payload_errors) + self.assertEqual([[2.0], [3.0], [4.0], [5.0]], ds.values) + self.assertIsNotNone(ds.sub_set) + self.assertEqual([1, 2, 3], ds.sub_set.values) + self.assertEqual("ACAS", ds.observation_source) diff --git a/tests/file_access/test_lazy_dict.py b/tests/file_access/test_lazy_dict.py new file mode 100644 index 0000000000000000000000000000000000000000..d5c8355300d90b664d590f432765ef9e63368c4c --- /dev/null +++ b/tests/file_access/test_lazy_dict.py @@ -0,0 +1,60 @@ +# Copyright (C) 2025 ASTRON (Netherlands Institute for Radio Astronomy) +# SPDX-License-Identifier: Apache-2.0 + +"""LazyDict tests""" + +from typing import Dict +from unittest import TestCase + +from lofar_lotus.file_access._lazy_dict import lazy_dict + + +class TestLazyDict(TestCase): + def test_dict_read(self): + invocations = [] + data = {1: 99, 8: 55, 98: 3} + dict_type = Dict[int, int] + + def reader(key): + invocations.append(f"Invoked with {key}") + return data[key] + + d1 = lazy_dict(dict_type, reader) + d1[1] = lambda: reader(1) + d1[8] = lambda: reader(8) + d1[98] = lambda: reader(98) + + self.assertEqual(99, d1[1]) + self.assertEqual(55, d1[8]) + self.assertEqual(3, d1[98]) + self.assertEqual("Invoked with 1", invocations[0]) + self.assertEqual("Invoked with 8", invocations[1]) + self.assertEqual("Invoked with 98", invocations[2]) + + def test_dict_write(self): + invocations = [] + data = {} + dict_type = Dict[int, int] + + def reader(key): + invocations.append(f"Invoked with {key}") + return data[key] + + def writer(key, value): + invocations.append(f"Invoked with {key} = {value}") + data[key] = value + + d1 = lazy_dict(dict_type, reader) + d1.setup_write(writer) + + d1[1] = 2 + self.assertEqual("Invoked with 1 = 2", invocations[0]) + d1[2] = 3 + self.assertEqual("Invoked with 2 = 3", invocations[1]) + d1[1] = 4 + self.assertEqual("Invoked with 1 = 4", invocations[2]) + + self.assertEqual(4, d1[1]) + self.assertEqual(3, d1[2]) + self.assertEqual("Invoked with 1", invocations[3]) + self.assertEqual("Invoked with 2", invocations[4]) diff --git a/tests/file_access/test_monitored_wrapper.py b/tests/file_access/test_monitored_wrapper.py new file mode 100644 index 0000000000000000000000000000000000000000..a1db6d76bb6ac916bb3827fd0e914bf9430c3dc0 --- /dev/null +++ b/tests/file_access/test_monitored_wrapper.py @@ -0,0 +1,41 @@ +# Copyright (C) 2025 ASTRON (Netherlands Institute for Radio Astronomy) +# SPDX-License-Identifier: Apache-2.0 + +"""MonitoredWrapper tests""" + +from unittest import TestCase + +from numpy import array + +from lofar_lotus.file_access._monitoring import MonitoredWrapper + + +class TestMonitoredWrapper(TestCase): + def test_list(self): + invocations = [] + + def event(a): + invocations.append(f"Invoked with {a}") + + l1 = MonitoredWrapper(event, []) + l1.append(1) + self.assertEqual("Invoked with [1]", invocations[0]) + l1.append(2) + self.assertEqual("Invoked with [1, 2]", invocations[1]) + l1.pop() + self.assertEqual("Invoked with [1]", invocations[2]) + + l2 = MonitoredWrapper(event, [1, 2, 3, 4]) + l2.append(1) + self.assertEqual("Invoked with [1, 2, 3, 4, 1]", invocations[3]) + l2.append(2) + self.assertEqual("Invoked with [1, 2, 3, 4, 1, 2]", invocations[4]) + l2.pop() + self.assertEqual("Invoked with [1, 2, 3, 4, 1]", invocations[5]) + + l2[0] = 99 + self.assertEqual(99, l2[0]) + self.assertEqual("Invoked with [99, 2, 3, 4, 1]", invocations[6]) + + na = MonitoredWrapper(event, array([2, 3, 4])) + self.assertEqual((3,), na.shape) diff --git a/tests/requirements.txt b/tests/requirements.txt index b507faf8c6cc09660c41c58caa26d318e94cfd4d..706f3981a563bedc2fea74247d46ee35976eb2ff 100644 --- a/tests/requirements.txt +++ b/tests/requirements.txt @@ -1,2 +1,4 @@ pytest >= 7.0.0 # MIT pytest-cov >= 3.0.0 # MIT +timeout-decorator >= 0.5.0 # MIT +setuptools>=70.0 diff --git a/tests/test_cool_module.py b/tests/test_cool_module.py deleted file mode 100644 index 930557e5c252b8e29c430fd3cf3b9acb6c472153..0000000000000000000000000000000000000000 --- a/tests/test_cool_module.py +++ /dev/null @@ -1,16 +0,0 @@ -# Copyright (C) 2023 ASTRON (Netherlands Institute for Radio Astronomy) -# SPDX-License-Identifier: Apache-2.0 - -"""Testing of the Cool Module""" -from unittest import TestCase - -from lofar_lotus.cool_module import greeter - - -class TestCoolModule(TestCase): - """Test Case of the Cool Module""" - - def test_greeter(self): - """Testing that the greeter does not crash""" - greeter() - self.assertEqual(2 + 2, 4) diff --git a/tests/zeromq/__init__.py b/tests/zeromq/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/tests/zeromq/test_configdb.json b/tests/zeromq/test_configdb.json new file mode 100644 index 0000000000000000000000000000000000000000..88f5711101e911f0f5e91bdd197753d1500b940b --- /dev/null +++ b/tests/zeromq/test_configdb.json @@ -0,0 +1,1004 @@ +{ + "servers": { + "Boot": { + "STAT": { + "Boot": { + "STAT/Boot/1": { + "properties": { + "Initialise_Hardware": [ + "False" + ] + } + } + } + } + }, + "APSCT": { + "STAT": { + "APSCT": { + "STAT/APSCT/L0": { + "properties": { + "OPC_Server_Name": [ + "apsct-sim" + ], + "OPC_Server_Port": [ + "4843" + ], + "OPC_Time_Out": [ + "5.0" + ] + } + }, + "STAT/APSCT/L1": { + "properties": { + "OPC_Server_Name": [ + "apsct-sim" + ], + "OPC_Server_Port": [ + "4843" + ], + "OPC_Time_Out": [ + "5.0" + ] + } + }, + "STAT/APSCT/H0": { + "properties": { + "OPC_Server_Name": [ + "apsct-sim" + ], + "OPC_Server_Port": [ + "4843" + ], + "OPC_Time_Out": [ + "5.0" + ] + } + } + } + } + }, + "CCD": { + "STAT": { + "CCD": { + "STAT/CCD/1": { + "properties": { + "OPC_Server_Name": [ + "ccd-sim" + ], + "OPC_Server_Port": [ + "4843" + ], + "OPC_Time_Out": [ + "5.0" + ] + } + } + } + } + }, + "EC": { + "STAT": { + "EC": { + "STAT/EC/1": { + "properties": { + "OPC_Server_Name": [ + "ec-sim" + ], + "OPC_Server_Port": [ + "4840" + ], + "OPC_Time_Out": [ + "5.0" + ], + "OPC_Node_Path_Prefix": [ + "3:ServerInterfaces", + "4:Environmental_Control" + ], + "OPC_namespace": [ + "http://Environmental_Control" + ] + } + } + } + } + }, + "APSPU": { + "STAT": { + "APSPU": { + "STAT/APSPU/L0": { + "properties": { + "OPC_Server_Name": [ + "apspu-sim" + ], + "OPC_Server_Port": [ + "4842" + ], + "OPC_Time_Out": [ + "5.0" + ] + } + }, + "STAT/APSPU/L1": { + "properties": { + "OPC_Server_Name": [ + "apspu-sim" + ], + "OPC_Server_Port": [ + "4842" + ], + "OPC_Time_Out": [ + "5.0" + ] + } + }, + "STAT/APSPU/H0": { + "properties": { + "OPC_Server_Name": [ + "apspu-sim" + ], + "OPC_Server_Port": [ + "4842" + ], + "OPC_Time_Out": [ + "5.0" + ] + } + } + } + } + }, + "Beamlet": { + "STAT": { + "Beamlet": { + "STAT/Beamlet/LBA": { + "properties": { + "OPC_Server_Name": [ + "sdptr-sim" + ], + "OPC_Server_Port": [ + "4840" + ], + "OPC_Time_Out": [ + "5.0" + ], + "FPGA_beamlet_output_hdr_eth_destination_mac_RW_default": [ + "01:23:45:67:89:AB", + "01:23:45:67:89:AB", + "01:23:45:67:89:AB", + "01:23:45:67:89:AB", + "01:23:45:67:89:AB", + "01:23:45:67:89:AB", + "01:23:45:67:89:AB", + "01:23:45:67:89:AB", + "01:23:45:67:89:AB", + "01:23:45:67:89:AB", + "01:23:45:67:89:AB", + "01:23:45:67:89:AB", + "01:23:45:67:89:AB", + "01:23:45:67:89:AB", + "01:23:45:67:89:AB", + "01:23:45:67:89:AB" + ], + "FPGA_beamlet_output_hdr_ip_destination_address_RW_default": [ + "127.0.0.1", + "127.0.0.1", + "127.0.0.1", + "127.0.0.1", + "127.0.0.1", + "127.0.0.1", + "127.0.0.1", + "127.0.0.1", + "127.0.0.1", + "127.0.0.1", + "127.0.0.1", + "127.0.0.1", + "127.0.0.1", + "127.0.0.1", + "127.0.0.1", + "127.0.0.1" + ] + } + }, + "STAT/Beamlet/HBA0": { + "properties": { + "OPC_Server_Name": [ + "sdptr-sim" + ], + "OPC_Server_Port": [ + "4840" + ], + "OPC_Time_Out": [ + "5.0" + ], + "FPGA_beamlet_output_hdr_eth_destination_mac_RW_default": [ + "01:23:45:67:89:AB", + "01:23:45:67:89:AB", + "01:23:45:67:89:AB", + "01:23:45:67:89:AB", + "01:23:45:67:89:AB", + "01:23:45:67:89:AB", + "01:23:45:67:89:AB", + "01:23:45:67:89:AB", + "01:23:45:67:89:AB", + "01:23:45:67:89:AB", + "01:23:45:67:89:AB", + "01:23:45:67:89:AB", + "01:23:45:67:89:AB", + "01:23:45:67:89:AB", + "01:23:45:67:89:AB", + "01:23:45:67:89:AB" + ], + "FPGA_beamlet_output_hdr_ip_destination_address_RW_default": [ + "127.0.0.1", + "127.0.0.1", + "127.0.0.1", + "127.0.0.1", + "127.0.0.1", + "127.0.0.1", + "127.0.0.1", + "127.0.0.1", + "127.0.0.1", + "127.0.0.1", + "127.0.0.1", + "127.0.0.1", + "127.0.0.1", + "127.0.0.1", + "127.0.0.1", + "127.0.0.1" + ] + } + }, + "STAT/Beamlet/HBA1": { + "properties": { + "OPC_Server_Name": [ + "sdptr-sim" + ], + "OPC_Server_Port": [ + "4840" + ], + "OPC_Time_Out": [ + "5.0" + ], + "FPGA_beamlet_output_hdr_eth_destination_mac_RW_default": [ + "01:23:45:67:89:AB", + "01:23:45:67:89:AB", + "01:23:45:67:89:AB", + "01:23:45:67:89:AB", + "01:23:45:67:89:AB", + "01:23:45:67:89:AB", + "01:23:45:67:89:AB", + "01:23:45:67:89:AB", + "01:23:45:67:89:AB", + "01:23:45:67:89:AB", + "01:23:45:67:89:AB", + "01:23:45:67:89:AB", + "01:23:45:67:89:AB", + "01:23:45:67:89:AB", + "01:23:45:67:89:AB", + "01:23:45:67:89:AB" + ], + "FPGA_beamlet_output_hdr_ip_destination_address_RW_default": [ + "127.0.0.1", + "127.0.0.1", + "127.0.0.1", + "127.0.0.1", + "127.0.0.1", + "127.0.0.1", + "127.0.0.1", + "127.0.0.1", + "127.0.0.1", + "127.0.0.1", + "127.0.0.1", + "127.0.0.1", + "127.0.0.1", + "127.0.0.1", + "127.0.0.1", + "127.0.0.1" + ] + } + } + } + } + }, + "DigitalBeam": { + "STAT": { + "DigitalBeam": { + "STAT/DigitalBeam/LBA": { + "properties": { + } + }, + "STAT/DigitalBeam/HBA0": { + "properties": { + } + }, + "STAT/DigitalBeam/HBA1": { + "properties": { + } + } + } + } + }, + "ProtectionControl": { + "STAT": { + "ProtectionControl": { + "STAT/ProtectionControl/1": { + "properties": { + } + } + } + } + }, + "PCON": { + "STAT": { + "PCON": { + "STAT/PCON/1": { + "properties": { + "SNMP_use_simulators": [ + "True" + ] + } + } + } + } + }, + "PSOC": { + "STAT": { + "PSOC": { + "STAT/PSOC/1": { + "properties": { + "SNMP_use_simulators": [ + "True" + ] + } + } + } + } + }, + "RECVH": { + "STAT": { + "RECVH": { + "STAT/RECVH/H0": { + "properties": { + "OPC_Server_Name": [ + "recvh-sim" + ], + "OPC_Server_Port": [ + "4844" + ], + "OPC_Time_Out": [ + "5.0" + ] + } + } + } + } + }, + "RECVL": { + "STAT": { + "RECVL": { + "STAT/RECVL/L1": { + "properties": { + "OPC_Server_Name": [ + "recvl-sim" + ], + "OPC_Server_Port": [ + "4845" + ], + "OPC_Time_Out": [ + "5.0" + ] + } + }, + "STAT/RECVL/L0": { + "properties": { + "OPC_Server_Name": [ + "recvl-sim" + ], + "OPC_Server_Port": [ + "4845" + ], + "OPC_Time_Out": [ + "5.0" + ] + } + } + } + } + }, + "SDPFirmware": { + "STAT": { + "SDPFirmware": { + "STAT/SDPFirmware/LBA": { + "properties": { + "OPC_Server_Name": [ + "sdptr-sim" + ], + "OPC_Server_Port": [ + "4840" + ], + "OPC_Time_Out": [ + "5.0" + ] + } + }, + "STAT/SDPFirmware/HBA0": { + "properties": { + "OPC_Server_Name": [ + "sdptr-sim" + ], + "OPC_Server_Port": [ + "4840" + ], + "OPC_Time_Out": [ + "5.0" + ] + } + }, + "STAT/SDPFirmware/HBA1": { + "properties": { + "OPC_Server_Name": [ + "sdptr-sim" + ], + "OPC_Server_Port": [ + "4840" + ], + "OPC_Time_Out": [ + "5.0" + ] + } + } + } + } + }, + "SDP": { + "STAT": { + "SDP": { + "STAT/SDP/LBA": { + "properties": { + "OPC_Server_Name": [ + "sdptr-sim" + ], + "OPC_Server_Port": [ + "4840" + ], + "OPC_Time_Out": [ + "5.0" + ] + } + }, + "STAT/SDP/HBA0": { + "properties": { + "OPC_Server_Name": [ + "sdptr-sim" + ], + "OPC_Server_Port": [ + "4840" + ], + "OPC_Time_Out": [ + "5.0" + ] + } + }, + "STAT/SDP/HBA1": { + "properties": { + "OPC_Server_Name": [ + "sdptr-sim" + ], + "OPC_Server_Port": [ + "4840" + ], + "OPC_Time_Out": [ + "5.0" + ] + } + } + } + } + }, + "BST": { + "STAT": { + "BST": { + "STAT/BST/LBA": { + "properties": { + "OPC_Server_Name": [ + "sdptr-sim" + ], + "OPC_Server_Port": [ + "4840" + ], + "OPC_Time_Out": [ + "5.0" + ], + "FPGA_bst_offload_hdr_eth_destination_mac_RW_default": [ + "01:23:45:67:89:AB", + "01:23:45:67:89:AB", + "01:23:45:67:89:AB", + "01:23:45:67:89:AB", + "01:23:45:67:89:AB", + "01:23:45:67:89:AB", + "01:23:45:67:89:AB", + "01:23:45:67:89:AB", + "01:23:45:67:89:AB", + "01:23:45:67:89:AB", + "01:23:45:67:89:AB", + "01:23:45:67:89:AB", + "01:23:45:67:89:AB", + "01:23:45:67:89:AB", + "01:23:45:67:89:AB", + "01:23:45:67:89:AB" + ], + "FPGA_bst_offload_hdr_ip_destination_address_RW_default": [ + "127.0.0.1", + "127.0.0.1", + "127.0.0.1", + "127.0.0.1", + "127.0.0.1", + "127.0.0.1", + "127.0.0.1", + "127.0.0.1", + "127.0.0.1", + "127.0.0.1", + "127.0.0.1", + "127.0.0.1", + "127.0.0.1", + "127.0.0.1", + "127.0.0.1", + "127.0.0.1" + ] + } + }, + "STAT/BST/HBA0": { + "properties": { + "OPC_Server_Name": [ + "sdptr-sim" + ], + "OPC_Server_Port": [ + "4840" + ], + "OPC_Time_Out": [ + "5.0" + ], + "FPGA_bst_offload_hdr_eth_destination_mac_RW_default": [ + "01:23:45:67:89:AB", + "01:23:45:67:89:AB", + "01:23:45:67:89:AB", + "01:23:45:67:89:AB", + "01:23:45:67:89:AB", + "01:23:45:67:89:AB", + "01:23:45:67:89:AB", + "01:23:45:67:89:AB", + "01:23:45:67:89:AB", + "01:23:45:67:89:AB", + "01:23:45:67:89:AB", + "01:23:45:67:89:AB", + "01:23:45:67:89:AB", + "01:23:45:67:89:AB", + "01:23:45:67:89:AB", + "01:23:45:67:89:AB" + ], + "FPGA_bst_offload_hdr_ip_destination_address_RW_default": [ + "127.0.0.1", + "127.0.0.1", + "127.0.0.1", + "127.0.0.1", + "127.0.0.1", + "127.0.0.1", + "127.0.0.1", + "127.0.0.1", + "127.0.0.1", + "127.0.0.1", + "127.0.0.1", + "127.0.0.1", + "127.0.0.1", + "127.0.0.1", + "127.0.0.1", + "127.0.0.1" + ] + } + }, + "STAT/BST/HBA1": { + "properties": { + "OPC_Server_Name": [ + "sdptr-sim" + ], + "OPC_Server_Port": [ + "4840" + ], + "OPC_Time_Out": [ + "5.0" + ], + "FPGA_bst_offload_hdr_eth_destination_mac_RW_default": [ + "01:23:45:67:89:AB", + "01:23:45:67:89:AB", + "01:23:45:67:89:AB", + "01:23:45:67:89:AB", + "01:23:45:67:89:AB", + "01:23:45:67:89:AB", + "01:23:45:67:89:AB", + "01:23:45:67:89:AB", + "01:23:45:67:89:AB", + "01:23:45:67:89:AB", + "01:23:45:67:89:AB", + "01:23:45:67:89:AB", + "01:23:45:67:89:AB", + "01:23:45:67:89:AB", + "01:23:45:67:89:AB", + "01:23:45:67:89:AB" + ], + "FPGA_bst_offload_hdr_ip_destination_address_RW_default": [ + "127.0.0.1", + "127.0.0.1", + "127.0.0.1", + "127.0.0.1", + "127.0.0.1", + "127.0.0.1", + "127.0.0.1", + "127.0.0.1", + "127.0.0.1", + "127.0.0.1", + "127.0.0.1", + "127.0.0.1", + "127.0.0.1", + "127.0.0.1", + "127.0.0.1", + "127.0.0.1" + ] + } + } + } + } + }, + "SST": { + "STAT": { + "SST": { + "STAT/SST/LBA": { + "properties": { + "OPC_Server_Name": [ + "sdptr-sim" + ], + "OPC_Server_Port": [ + "4840" + ], + "OPC_Time_Out": [ + "5.0" + ], + "FPGA_sst_offload_hdr_eth_destination_mac_RW_default": [ + "01:23:45:67:89:AB", + "01:23:45:67:89:AB", + "01:23:45:67:89:AB", + "01:23:45:67:89:AB", + "01:23:45:67:89:AB", + "01:23:45:67:89:AB", + "01:23:45:67:89:AB", + "01:23:45:67:89:AB", + "01:23:45:67:89:AB", + "01:23:45:67:89:AB", + "01:23:45:67:89:AB", + "01:23:45:67:89:AB", + "01:23:45:67:89:AB", + "01:23:45:67:89:AB", + "01:23:45:67:89:AB", + "01:23:45:67:89:AB" + ], + "FPGA_sst_offload_hdr_ip_destination_address_RW_default": [ + "127.0.0.1", + "127.0.0.1", + "127.0.0.1", + "127.0.0.1", + "127.0.0.1", + "127.0.0.1", + "127.0.0.1", + "127.0.0.1", + "127.0.0.1", + "127.0.0.1", + "127.0.0.1", + "127.0.0.1", + "127.0.0.1", + "127.0.0.1", + "127.0.0.1", + "127.0.0.1" + ] + } + }, + "STAT/SST/HBA0": { + "properties": { + "OPC_Server_Name": [ + "sdptr-sim" + ], + "OPC_Server_Port": [ + "4840" + ], + "OPC_Time_Out": [ + "5.0" + ], + "FPGA_sst_offload_hdr_eth_destination_mac_RW_default": [ + "01:23:45:67:89:AB", + "01:23:45:67:89:AB", + "01:23:45:67:89:AB", + "01:23:45:67:89:AB", + "01:23:45:67:89:AB", + "01:23:45:67:89:AB", + "01:23:45:67:89:AB", + "01:23:45:67:89:AB", + "01:23:45:67:89:AB", + "01:23:45:67:89:AB", + "01:23:45:67:89:AB", + "01:23:45:67:89:AB", + "01:23:45:67:89:AB", + "01:23:45:67:89:AB", + "01:23:45:67:89:AB", + "01:23:45:67:89:AB" + ], + "FPGA_sst_offload_hdr_ip_destination_address_RW_default": [ + "127.0.0.1", + "127.0.0.1", + "127.0.0.1", + "127.0.0.1", + "127.0.0.1", + "127.0.0.1", + "127.0.0.1", + "127.0.0.1", + "127.0.0.1", + "127.0.0.1", + "127.0.0.1", + "127.0.0.1", + "127.0.0.1", + "127.0.0.1", + "127.0.0.1", + "127.0.0.1" + ] + } + }, + "STAT/SST/HBA1": { + "properties": { + "OPC_Server_Name": [ + "sdptr-sim" + ], + "OPC_Server_Port": [ + "4840" + ], + "OPC_Time_Out": [ + "5.0" + ], + "FPGA_sst_offload_hdr_eth_destination_mac_RW_default": [ + "01:23:45:67:89:AB", + "01:23:45:67:89:AB", + "01:23:45:67:89:AB", + "01:23:45:67:89:AB", + "01:23:45:67:89:AB", + "01:23:45:67:89:AB", + "01:23:45:67:89:AB", + "01:23:45:67:89:AB", + "01:23:45:67:89:AB", + "01:23:45:67:89:AB", + "01:23:45:67:89:AB", + "01:23:45:67:89:AB", + "01:23:45:67:89:AB", + "01:23:45:67:89:AB", + "01:23:45:67:89:AB", + "01:23:45:67:89:AB" + ], + "FPGA_sst_offload_hdr_ip_destination_address_RW_default": [ + "127.0.0.1", + "127.0.0.1", + "127.0.0.1", + "127.0.0.1", + "127.0.0.1", + "127.0.0.1", + "127.0.0.1", + "127.0.0.1", + "127.0.0.1", + "127.0.0.1", + "127.0.0.1", + "127.0.0.1", + "127.0.0.1", + "127.0.0.1", + "127.0.0.1", + "127.0.0.1" + ] + } + } + } + } + }, + "XST": { + "STAT": { + "XST": { + "STAT/XST/LBA": { + "properties": { + "OPC_Server_Name": [ + "sdptr-sim" + ], + "OPC_Server_Port": [ + "4840" + ], + "OPC_Time_Out": [ + "5.0" + ], + "FPGA_xst_offload_hdr_eth_destination_mac_RW_default": [ + "01:23:45:67:89:AB", + "01:23:45:67:89:AB", + "01:23:45:67:89:AB", + "01:23:45:67:89:AB", + "01:23:45:67:89:AB", + "01:23:45:67:89:AB", + "01:23:45:67:89:AB", + "01:23:45:67:89:AB", + "01:23:45:67:89:AB", + "01:23:45:67:89:AB", + "01:23:45:67:89:AB", + "01:23:45:67:89:AB", + "01:23:45:67:89:AB", + "01:23:45:67:89:AB", + "01:23:45:67:89:AB", + "01:23:45:67:89:AB" + ], + "FPGA_xst_offload_hdr_ip_destination_address_RW_default": [ + "127.0.0.1", + "127.0.0.1", + "127.0.0.1", + "127.0.0.1", + "127.0.0.1", + "127.0.0.1", + "127.0.0.1", + "127.0.0.1", + "127.0.0.1", + "127.0.0.1", + "127.0.0.1", + "127.0.0.1", + "127.0.0.1", + "127.0.0.1", + "127.0.0.1", + "127.0.0.1" + ] + } + }, + "STAT/XST/HBA0": { + "properties": { + "OPC_Server_Name": [ + "sdptr-sim" + ], + "OPC_Server_Port": [ + "4840" + ], + "OPC_Time_Out": [ + "5.0" + ], + "FPGA_xst_offload_hdr_eth_destination_mac_RW_default": [ + "01:23:45:67:89:AB", + "01:23:45:67:89:AB", + "01:23:45:67:89:AB", + "01:23:45:67:89:AB", + "01:23:45:67:89:AB", + "01:23:45:67:89:AB", + "01:23:45:67:89:AB", + "01:23:45:67:89:AB", + "01:23:45:67:89:AB", + "01:23:45:67:89:AB", + "01:23:45:67:89:AB", + "01:23:45:67:89:AB", + "01:23:45:67:89:AB", + "01:23:45:67:89:AB", + "01:23:45:67:89:AB", + "01:23:45:67:89:AB" + ], + "FPGA_xst_offload_hdr_ip_destination_address_RW_default": [ + "127.0.0.1", + "127.0.0.1", + "127.0.0.1", + "127.0.0.1", + "127.0.0.1", + "127.0.0.1", + "127.0.0.1", + "127.0.0.1", + "127.0.0.1", + "127.0.0.1", + "127.0.0.1", + "127.0.0.1", + "127.0.0.1", + "127.0.0.1", + "127.0.0.1", + "127.0.0.1" + ] + } + }, + "STAT/XST/HBA1": { + "properties": { + "OPC_Server_Name": [ + "sdptr-sim" + ], + "OPC_Server_Port": [ + "4840" + ], + "OPC_Time_Out": [ + "5.0" + ], + "FPGA_xst_offload_hdr_eth_destination_mac_RW_default": [ + "01:23:45:67:89:AB", + "01:23:45:67:89:AB", + "01:23:45:67:89:AB", + "01:23:45:67:89:AB", + "01:23:45:67:89:AB", + "01:23:45:67:89:AB", + "01:23:45:67:89:AB", + "01:23:45:67:89:AB", + "01:23:45:67:89:AB", + "01:23:45:67:89:AB", + "01:23:45:67:89:AB", + "01:23:45:67:89:AB", + "01:23:45:67:89:AB", + "01:23:45:67:89:AB", + "01:23:45:67:89:AB", + "01:23:45:67:89:AB" + ], + "FPGA_xst_offload_hdr_ip_destination_address_RW_default": [ + "127.0.0.1", + "127.0.0.1", + "127.0.0.1", + "127.0.0.1", + "127.0.0.1", + "127.0.0.1", + "127.0.0.1", + "127.0.0.1", + "127.0.0.1", + "127.0.0.1", + "127.0.0.1", + "127.0.0.1", + "127.0.0.1", + "127.0.0.1", + "127.0.0.1", + "127.0.0.1" + ] + } + } + } + } + }, + "UNB2": { + "STAT": { + "UNB2": { + "STAT/UNB2/L0": { + "properties": { + "OPC_Server_Name": [ + "unb2-sim" + ], + "OPC_Server_Port": [ + "4841" + ], + "OPC_Time_Out": [ + "5.0" + ] + } + }, + "STAT/UNB2/L1": { + "properties": { + "OPC_Server_Name": [ + "unb2-sim" + ], + "OPC_Server_Port": [ + "4841" + ], + "OPC_Time_Out": [ + "5.0" + ] + } + }, + "STAT/UNB2/H0": { + "properties": { + "OPC_Server_Name": [ + "unb2-sim" + ], + "OPC_Server_Port": [ + "4841" + ], + "OPC_Time_Out": [ + "5.0" + ] + } + } + } + } + } + } +} diff --git a/tests/zeromq/test_publisher.py b/tests/zeromq/test_publisher.py new file mode 100644 index 0000000000000000000000000000000000000000..d816cd1cf49bf4e5e78d11bae91b3e783a6e58f9 --- /dev/null +++ b/tests/zeromq/test_publisher.py @@ -0,0 +1,249 @@ +# Copyright (C) 2025 ASTRON (Netherlands Institute for Radio Astronomy) +# SPDX-License-Identifier: Apache-2.0 + +"""ZeroMQPublisher test classes""" + +import json +import logging +import queue +import time +from ctypes import c_int8 +from datetime import datetime +from importlib.resources import files +from multiprocessing.sharedctypes import Value +from threading import Thread +from typing import Any +from unittest import mock, TestCase + +import zmq +from timeout_decorator import timeout_decorator +from zmq.utils.monitor import recv_monitor_message + +from lofar_lotus.zeromq import ZeroMQPublisher + +logger = logging.getLogger() + + +class TestPublisher(TestCase): + DEFAULT_PUBLISH_ADDRESS = "tcp://*:6001" + DEFAULT_SUBSCRIBE_ADDRESS = "tcp://127.0.0.1:6001" + + @staticmethod + def event_monitor_loop(monitor: zmq.Socket, trigger: Value): + """Loop on monitor socket to count number of subscribers + + :param monitor: zmq monit socket, use `get_monitor_socket()` + :param trigger: multiprocessing shared value, must be incrementable / number + """ + while monitor.poll(): + evt: dict[str, Any] = {} + mon_evt: dict = recv_monitor_message(monitor) + evt.update(mon_evt) + evt["description"] = evt["event"] + logger.warning("Event: %s", evt) + if evt["event"] == zmq.EVENT_HANDSHAKE_SUCCEEDED: + logger.info("Setting connected to true") + trigger.value += 1 + elif evt["event"] == zmq.EVENT_DISCONNECTED: + logger.info("Dropping connection") + trigger.value -= 1 + elif evt["event"] == zmq.EVENT_MONITOR_STOPPED: + break + monitor.close() + + @staticmethod + def create_event_monitor(monitor: zmq.Socket, trigger: Value): + """Create a thread that uses an event monitor socket""" + t_monitor = Thread( + target=TestPublisher.event_monitor_loop, args=(monitor, trigger) + ) + t_monitor.start() + return t_monitor + + @staticmethod + def wait_for_start(publisher: ZeroMQPublisher): + """Spin until publisher is running""" + while not publisher.is_running: + logger.info("Waiting for publisher thread to start..") + time.sleep(0.1) + + @staticmethod + def load_test_json(): + """Load test_configdb into memory from disc""" + file_path = files(__package__).joinpath("test_configdb.json") + with file_path.open() as _file: + return json.dumps(json.load(_file)) + + def test_contstruct_bind_uri(self): + """Test that helper function creates proper strings""" + + self.assertEqual( + "tcp://0.0.0.0:1624", + ZeroMQPublisher.construct_bind_uri("tcp", "0.0.0.0", 1624), + ) + + self.assertEqual( + "udp://0.0.0.0:1624", + ZeroMQPublisher.construct_bind_uri("udp", "0.0.0.0", "1624"), + ) + + @timeout_decorator.timeout(5) + def test_topic_bytearray(self): + """Pass a list of topics as bytearray""" + t_topics = [b"A", b"B"] + + with ZeroMQPublisher(self.DEFAULT_PUBLISH_ADDRESS, t_topics) as t_publisher: + self.assertListEqual(t_topics, t_publisher.topics) + + @timeout_decorator.timeout(5) + def test_topic_str(self): + """Pass a list of topics as str""" + t_topics = ["A", "B"] + + with ZeroMQPublisher(self.DEFAULT_PUBLISH_ADDRESS, t_topics) as t_publisher: + self.assertNotEqual(t_topics, t_publisher.topics) + + t_topics = [b"A", b"B"] + + self.assertListEqual(t_topics, t_publisher.topics) + + @timeout_decorator.timeout(5) + def test_start_stop(self): + """Test the startup and shutdown sequence""" + t_topic = b"A" + with ZeroMQPublisher(self.DEFAULT_PUBLISH_ADDRESS, [t_topic]) as t_publisher: + self.wait_for_start(t_publisher) + + self.assertTrue(t_publisher.is_running) + self.assertFalse(t_publisher.is_stopping) + self.assertFalse(t_publisher.is_done) + + t_publisher.shutdown() + + self.assertTrue(t_publisher.is_stopping) + + while not t_publisher.is_done: + logger.info("Waiting for publisher thread to stop..") + time.sleep(0.1) + + self.assertIsNone(t_publisher.get_exception()) + self.assertFalse(t_publisher.is_running) + self.assertTrue(t_publisher.is_done) + + @timeout_decorator.timeout(5) + def test_publish(self): + """Test publishing a message and having a subscriber receive it""" + t_msg = "test" + t_topic = b"A" + + with ZeroMQPublisher(self.DEFAULT_PUBLISH_ADDRESS, [t_topic]) as t_publisher: + self.wait_for_start(t_publisher) + + ctx = zmq.Context.instance() + + t_connected = Value(c_int8, 0, lock=False) + + self.create_event_monitor( + t_publisher.publisher.get_monitor_socket(), t_connected + ) + + subscribe = ctx.socket(zmq.SUB) + subscribe.connect(self.DEFAULT_SUBSCRIBE_ADDRESS) + subscribe.setsockopt(zmq.SUBSCRIBE, t_topic) + + while t_connected.value < 1: + logger.info("Waiting for topic subscription..") + time.sleep(0.1) + + for _ in range( + 0, 5 + ): # check against accidental shutdown after first message + t_publisher.send(t_msg) + msg = subscribe.recv_multipart() + self.assertIsInstance(datetime.fromisoformat(msg[1].decode()), datetime) + self.assertEqual(t_msg.encode(), msg[2]) + + subscribe.close() + + while t_connected.value != 0: + logger.info("Waiting for subscriber to disconnect") + time.sleep(0.1) + + self.assertTrue(t_publisher.is_running) + self.assertFalse(t_publisher.is_stopping) + self.assertFalse(t_publisher.is_done) + + def test_publish_huge_message_multi_subscriber(self): + test_data = self.load_test_json() + t_topic = b"A" + + with ZeroMQPublisher(self.DEFAULT_PUBLISH_ADDRESS, [t_topic]) as t_publisher: + self.wait_for_start(t_publisher) + + ctx = zmq.Context.instance() + + t_connected = Value(c_int8, 0, lock=False) + + self.create_event_monitor( + t_publisher.publisher.get_monitor_socket(), t_connected + ) + + subscribers = [] + for _ in range(0, 2): + subscribe = ctx.socket(zmq.SUB) + subscribe.connect(self.DEFAULT_SUBSCRIBE_ADDRESS) + subscribe.setsockopt(zmq.SUBSCRIBE, t_topic) + subscribers.append(subscribe) + + while t_connected.value < 2: + logger.info("Waiting for topic subscriptions..") + time.sleep(0.1) + + t_publisher.send(test_data) + for i in range(0, 2): + msg = subscribers[i].recv_multipart() + self.assertIsInstance(datetime.fromisoformat(msg[1].decode()), datetime) + self.assertEqual(test_data.encode(), msg[2]) + + def test_callback(self): + """Test that triggering done callbacks works""" + + t_topic = b"A" + + with ZeroMQPublisher(self.DEFAULT_PUBLISH_ADDRESS, [t_topic]) as t_publisher: + self.wait_for_start(t_publisher) + + t_cb = mock.Mock() + + t_publisher.register_callback(t_cb) + + t_publisher.shutdown() + + while not t_publisher.is_stopping or not t_publisher.is_done: + logger.info("Waiting for publisher thread to stop..") + time.sleep(0.1) + + t_cb.assert_called_once() + + def test_queue(self): + """Test queuing of messages and full exception""" + t_queue_size = 10 + t_topic = b"A" + t_publisher = ZeroMQPublisher( + bind_uri=self.DEFAULT_PUBLISH_ADDRESS, + topics=[t_topic], + queue_size=t_queue_size, + ) + t_publisher.shutdown() + + while not t_publisher.is_stopping or not t_publisher.is_done: + logger.info("Waiting for publisher thread to stop..") + time.sleep(0.1) + + self.assertEqual(t_queue_size, t_publisher.queue_size) + + for i in range(1, t_queue_size + 1): + t_publisher.send("hello") + self.assertEqual(i, t_publisher.queue_fill) + + self.assertRaises(queue.Full, t_publisher.send, "hello") diff --git a/tests/zeromq/test_subscriber.py b/tests/zeromq/test_subscriber.py new file mode 100644 index 0000000000000000000000000000000000000000..1718242efb1e94ddb2c269344696c09ad22b7e97 --- /dev/null +++ b/tests/zeromq/test_subscriber.py @@ -0,0 +1,194 @@ +# Copyright (C) 2025 ASTRON (Netherlands Institute for Radio Astronomy) +# SPDX-License-Identifier: Apache-2.0 + +"""ZeroMQSubscriber test classes""" + +import asyncio +import logging +import time +from datetime import datetime +from threading import Thread +from unittest import TestCase + +from timeout_decorator import timeout_decorator + +from lofar_lotus.zeromq import ZeroMQPublisher +from lofar_lotus.zeromq import ( + ZeroMQSubscriber, + AsyncZeroMQSubscriber, +) + +logger = logging.getLogger() + + +class TestSubscriber(TestCase): + DEFAULT_PUBLISH_ADDRESS = "tcp://*:6001" + DEFAULT_SUBSCRIBE_ADDRESS = "tcp://127.0.0.1:6001" + + @staticmethod + def wait_for_start(publisher: ZeroMQPublisher): + """Spin until publisher is running""" + while not publisher.is_running: + logger.info("Waiting for publisher thread to start..") + time.sleep(0.1) + + @staticmethod + def wait_for_connect(subscriber: ZeroMQSubscriber): + """Spin until subscriber is running""" + while not subscriber.is_connected: + logger.info("Waiting for subscriber thread to start..") + time.sleep(0.1) + + @staticmethod + def wait_for_disconnect(subscriber: ZeroMQSubscriber): + """Spin until subscriber is running""" + while subscriber.is_connected: + logger.info("Waiting for subscriber thread to stop..") + time.sleep(0.1) + + @timeout_decorator.timeout(5) + def test_recv(self): + """Test receiving a message""" + t_msg = "test" + t_topic = "topic" + + with ZeroMQPublisher(self.DEFAULT_PUBLISH_ADDRESS, [t_topic]) as publisher: + self.wait_for_start(publisher) + + with ZeroMQSubscriber( + self.DEFAULT_SUBSCRIBE_ADDRESS, [t_topic] + ) as t_subscriber: + self.wait_for_connect(t_subscriber) + + for _ in range( + 0, 5 + ): # check against accidental shutdown after first message + publisher.send(t_msg) + _, timestamp, msg = t_subscriber.recv() + self.assertIsInstance(timestamp, datetime) + self.assertEqual(t_msg, msg) + + @timeout_decorator.timeout(5) + def test_connects_disconnects(self): + """Test connect/disconnect information.""" + t_topic = "topic" + + with ZeroMQSubscriber( + self.DEFAULT_SUBSCRIBE_ADDRESS, [t_topic] + ) as t_subscriber: + self.assertFalse(t_subscriber.is_connected) + self.assertEqual(0, t_subscriber.nr_connects) + self.assertEqual(0, t_subscriber.nr_disconnects) + + with ZeroMQPublisher(self.DEFAULT_PUBLISH_ADDRESS, [t_topic]) as publisher: + self.wait_for_start(publisher) + + self.wait_for_connect(t_subscriber) + + self.assertTrue(t_subscriber.is_connected) + self.assertEqual(1, t_subscriber.nr_connects) + self.assertEqual(0, t_subscriber.nr_disconnects) + + self.wait_for_disconnect(t_subscriber) + + self.assertFalse(t_subscriber.is_connected) + self.assertEqual(1, t_subscriber.nr_connects) + self.assertEqual(1, t_subscriber.nr_disconnects) + + +class TestAsyncSubscriber(TestCase): + """AsyncZeroMQSubscriber test class""" + + DEFAULT_PUBLISH_ADDRESS = "tcp://*:6001" + DEFAULT_SUBSCRIBE_ADDRESS = "tcp://127.0.0.1:6001" + + def setUp(self): + """Test setup""" + + def run_event_loop(loop): + asyncio.set_event_loop(loop) + loop.run_forever() + + self.event_loop = asyncio.new_event_loop() + thread = Thread(target=run_event_loop, args=(self.event_loop,), daemon=True) + thread.start() + + @staticmethod + def wait_for_start(publisher: ZeroMQPublisher): + """Spin until publisher is running""" + while not publisher.is_running: + logger.info("Waiting for publisher thread to start..") + time.sleep(0.1) + + @staticmethod + async def wait_for_connect(subscriber: ZeroMQSubscriber): + """Spin until subscriber is running""" + while not subscriber.is_connected: + logger.info("Waiting for subscriber thread to start..") + await asyncio.sleep(0.1) + + @staticmethod + async def wait_for_disconnect(subscriber: ZeroMQSubscriber): + """Spin until subscriber is running""" + while subscriber.is_connected: + logger.info("Waiting for subscriber thread to stop..") + await asyncio.sleep(0.1) + + @timeout_decorator.timeout(5) + def test_connects_disconnects(self): + """Test connect/disconnect information.""" + t_topic = "topic" + + async def run_subscriber(): + async with AsyncZeroMQSubscriber( + self.DEFAULT_SUBSCRIBE_ADDRESS, [t_topic] + ) as t_subscriber: + self.assertFalse(t_subscriber.is_connected) + self.assertEqual(0, t_subscriber.nr_connects) + self.assertEqual(0, t_subscriber.nr_disconnects) + + with ZeroMQPublisher( + self.DEFAULT_PUBLISH_ADDRESS, [t_topic] + ) as publisher: + self.wait_for_start(publisher) + + await self.wait_for_connect(t_subscriber) + + self.assertTrue(t_subscriber.is_connected) + self.assertEqual(1, t_subscriber.nr_connects) + self.assertEqual(0, t_subscriber.nr_disconnects) + + await self.wait_for_disconnect(t_subscriber) + + self.assertFalse(t_subscriber.is_connected) + self.assertEqual(1, t_subscriber.nr_connects) + self.assertEqual(1, t_subscriber.nr_disconnects) + + future = asyncio.run_coroutine_threadsafe(run_subscriber(), self.event_loop) + _ = future.result() + + @timeout_decorator.timeout(5) + def test_async_recv(self): + """Test receiving a message""" + t_msg = "test" + t_topic = "topic" + + with ZeroMQPublisher(self.DEFAULT_PUBLISH_ADDRESS, [t_topic]) as publisher: + self.wait_for_start(publisher) + + async def run_subscriber(): + async with AsyncZeroMQSubscriber( + self.DEFAULT_SUBSCRIBE_ADDRESS, [t_topic] + ) as t_subscriber: + await self.wait_for_connect(t_subscriber) + + for _ in range( + 0, 5 + ): # check against accidental shutdown after first message + publisher.send(t_msg) + _, timestamp, msg = await t_subscriber.async_recv() + self.assertIsInstance(timestamp, datetime) + self.assertEqual(t_msg, msg) + + future = asyncio.run_coroutine_threadsafe(run_subscriber(), self.event_loop) + _ = future.result()