From 4a166cf2bee49920b9e8445b1aae33bf6822d382 Mon Sep 17 00:00:00 2001 From: Hannes Feldt <feldt@astron.nl> Date: Wed, 16 Apr 2025 18:51:26 +0000 Subject: [PATCH] Resolve L2SS-2050 "Add station state transition interface" --- .gitignore | 2 +- .gitlab-ci.yml | 4 +- README.md | 1 + infra/jobs/station/ec-sim.levant.nomad | 2 +- integration_tests/default/rpc/__init__.py | 2 + integration_tests/default/rpc/test_server.py | 36 ++ .../default/rpc/test_station_rpc.py | 168 +++++++++ pyproject.toml | 1 + requirements.txt | 3 +- sbin/install-hooks/submodule-and-lfs.sh | 5 +- tangostationcontrol/VERSION | 2 +- .../observation_field_settings.py | 1 + .../devices/base_device_classes/mapper.py | 6 +- tangostationcontrol/rpc/common.py | 53 ++- .../rpc/proxy/antennadeviceproxyfactory.py | 1 - tangostationcontrol/rpc/server.py | 30 +- tangostationcontrol/rpc/station.py | 165 +++++++++ tests/requirements.txt | 1 + tests/rpc/test_antenna.py | 54 ++- tests/rpc/test_server.py | 52 --- tests/rpc/test_station.py | 325 ++++++++++++++++++ 21 files changed, 809 insertions(+), 105 deletions(-) create mode 100644 integration_tests/default/rpc/__init__.py create mode 100644 integration_tests/default/rpc/test_server.py create mode 100644 integration_tests/default/rpc/test_station_rpc.py create mode 100644 tangostationcontrol/rpc/station.py delete mode 100644 tests/rpc/test_server.py create mode 100644 tests/rpc/test_station.py diff --git a/.gitignore b/.gitignore index 57c7e7cb0..43e4f25ae 100644 --- a/.gitignore +++ b/.gitignore @@ -55,4 +55,4 @@ infra/dev/nomad/tmp/* Lib/* Scripts/* pyvenv.cfg -bin/* \ No newline at end of file +bin/* diff --git a/.gitlab-ci.yml b/.gitlab-ci.yml index 632779d94..5b3c88ad5 100644 --- a/.gitlab-ci.yml +++ b/.gitlab-ci.yml @@ -115,7 +115,7 @@ run_unit_tests: run_unit_tests_coverage: extends: .run_unit_test_version_base - needs: + needs: - trigger_prepare stage: test script: @@ -136,7 +136,7 @@ package_files: stage: package needs: - trigger_prepare - + artifacts: expire_in: 1w paths: diff --git a/README.md b/README.md index f596a833c..a1c1bce9e 100644 --- a/README.md +++ b/README.md @@ -150,6 +150,7 @@ Next change the version in the following places: through [https://git.astron.nl/lofar2.0/tango/-/tags](Deploy Tags) # Release Notes +* 0.49.0 Add Station service to control station state to gRPC server * 0.48.2 rename antennafield_id * 0.48.1 Fix exposing correct triangle of XSTs in gRPC service * 0.48.0 Add Antennafield to gRPC server diff --git a/infra/jobs/station/ec-sim.levant.nomad b/infra/jobs/station/ec-sim.levant.nomad index 0435b8654..ec0932fea 100644 --- a/infra/jobs/station/ec-sim.levant.nomad +++ b/infra/jobs/station/ec-sim.levant.nomad @@ -24,7 +24,7 @@ job "ec-sim" { } resources { cpu = 100 - memory = 100 + memory = 256 } } diff --git a/integration_tests/default/rpc/__init__.py b/integration_tests/default/rpc/__init__.py new file mode 100644 index 000000000..7ddb7c536 --- /dev/null +++ b/integration_tests/default/rpc/__init__.py @@ -0,0 +1,2 @@ +# Copyright (C) 2025 ASTRON (Netherlands Institute for Radio Astronomy) +# SPDX-License-Identifier: Apache-2.0 diff --git a/integration_tests/default/rpc/test_server.py b/integration_tests/default/rpc/test_server.py new file mode 100644 index 000000000..61ae6b4a8 --- /dev/null +++ b/integration_tests/default/rpc/test_server.py @@ -0,0 +1,36 @@ +# Copyright (C) 2025 ASTRON (Netherlands Institute for Radio Astronomy) +# SPDX-License-Identifier: Apache-2.0 + +import grpc +from grpc_reflection.v1alpha.proto_reflection_descriptor_database import ( + ProtoReflectionDescriptorDatabase, +) +from lofar_sid.interface.stationcontrol import antenna_pb2 +from lofar_sid.interface.stationcontrol import antenna_pb2_grpc + +from integration_tests import base + + +class TestServer(base.IntegrationTestCase): + def test_api(self): + """Check whether we actually expose the expected API.""" + + with grpc.insecure_channel("rpc.service.consul:50051") as channel: + reflection_db = ProtoReflectionDescriptorDatabase(channel) + services = reflection_db.get_services() + + self.assertIn("Observation", services) + self.assertIn("Antennafield", services) + self.assertIn("Statistics", services) + + def test_call(self): + """Test a basic gRPC call to the server.""" + + with grpc.insecure_channel("rpc.service.consul:50051") as channel: + stub = antenna_pb2_grpc.AntennaStub(channel) + + identifier = antenna_pb2.Identifier( + antennafield_name="lba0", + antenna_name="LBA00", + ) + _ = stub.GetAntenna(antenna_pb2.GetAntennaRequest(identifier=identifier)) diff --git a/integration_tests/default/rpc/test_station_rpc.py b/integration_tests/default/rpc/test_station_rpc.py new file mode 100644 index 000000000..3d101da9e --- /dev/null +++ b/integration_tests/default/rpc/test_station_rpc.py @@ -0,0 +1,168 @@ +# Copyright (C) 2025 ASTRON (Netherlands Institute for Radio Astronomy) +# SPDX-License-Identifier: Apache-2.0 + +""" +Station RPC module integration test +""" + +import logging + +import grpc +from lofar_sid.interface.stationcontrol import station_pb2 +from lofar_sid.interface.stationcontrol.station_pb2_grpc import StationStub +from tango._tango import DevState + +from integration_tests import base +from integration_tests.device_proxy import TestDeviceProxy + +logger = logging.getLogger() + + +class StationRPCTests(base.IntegrationTestCase): + """Integration Test class for station RPC methods""" + + stationmanager_name = "STAT/StationManager/1" + ec_name = "STAT/EC/1" + aps_l0_name = "STAT/APS/L0" + aps_l1_name = "STAT/APS/L1" + aps_h0_name = "STAT/APS/H0" + apsct_name = "STAT/APSCT/H0" + apspu_h0_name = "STAT/APSPU/H0" + apspu_l0_name = "STAT/APSPU/L0" + apspu_l1_name = "STAT/APSPU/L1" + ccd_name = "STAT/CCD/1" + pcon_name = "STAT/PCON/1" + sdp_name = "STAT/SDP/HBA0" + unb2_h0_name = "STAT/UNB2/H0" + unb2_l0_name = "STAT/UNB2/L0" + recvh_name = "STAT/RECVH/H0" + recvl_name = "STAT/RECVL/L0" + + def setUp(self): + self.station_name = "CS001" + + self.antennafield_name = "STAT/AFH/HBA0" + self.sdp_name = "STAT/SDP/HBA0" + self.sdpfirmware_name = "STAT/sdpfirmware/HBA0" + + self.setup_all_devices() + + host = "rpc.service.consul:50051" + try: + # connect to gRPC endpoint + channel = grpc.insecure_channel(host) + self._control_endpoint = StationStub(channel) + + except Exception as e: + self._control_endpoint = None + + logger.warning( + "Failed to connect to device on host %s: %s: %s", + host, + e.__class__.__name__, + e, + ) + + def setup_stationmanager_proxy(self): + """Initialise StationManager device""" + stationmanager_proxy = TestDeviceProxy(self.stationmanager_name) + # extend timeout for running commands, as state transitions can take a long time + stationmanager_proxy.set_timeout_millis(60000) + + stationmanager_proxy.off() + stationmanager_proxy.initialise() + stationmanager_proxy.on() + self.assertEqual(stationmanager_proxy.state(), DevState.ON) + return stationmanager_proxy + + def setup_proxy_off(self, device_name: str): + """Initialise proxy and turn off device""" + proxy = TestDeviceProxy(device_name) + proxy.off() + return proxy + + def setup_all_devices(self): + """Initialise all Tango devices needed for state transitions""" + self.stationmanager_proxy = self.setup_stationmanager_proxy() + + self.ec_proxy = self.setup_proxy_off(self.ec_name) + self.aps_l0_proxy = self.setup_proxy_off(self.aps_l0_name) + self.pcon_proxy = self.setup_proxy_off(self.pcon_name) + self.ccd_proxy = self.setup_proxy_off(self.ccd_name) + self.apspu_h0_proxy = self.setup_proxy_off(self.apspu_h0_name) + self.apspu_l0_proxy = self.setup_proxy_off(self.apspu_l0_name) + self.apsct_proxy = self.setup_proxy_off(self.apsct_name) + self.unb2_h0_proxy = self.setup_proxy_off(self.unb2_h0_name) + self.unb2_l0_proxy = self.setup_proxy_off(self.unb2_l0_name) + self.recvh_proxy = self.setup_proxy_off(self.recvh_name) + self.recvl_proxy = self.setup_proxy_off(self.recvl_name) + self.sdpfirmware_proxy = self.setup_proxy_off(self.sdpfirmware_name) + self.sdp_proxy = self.setup_proxy_off(self.sdp_name) + self.antennafield_proxy = self.setup_proxy_off(self.antennafield_name) + + async def test_station_state_off_to_on(self): + reply: station_pb2.StationStateReply = self._control_endpoint.GetStationState( + station_pb2.GetStationStateRequest() + ) + self.assertEqual(station_pb2.Station_State.OFF, reply.result.station_state) + replies = await self._control_endpoint.SetStationState( + station_pb2.SetStationStateRequest( + station_state=station_pb2.Station_State.ON + ) + ) + self.assertListEqual( + [ + station_pb2.Station_State.HIBERNATE, + station_pb2.Station_State.STANDBY, + station_pb2.Station_State.HIBERNATE, + station_pb2.Station_State.STANDBY, + station_pb2.Station_State.ON, + ], + [r.result.station_state for r in replies], + ) + reply: station_pb2.StationStateReply = self._control_endpoint.GetStationState( + station_pb2.GetStationStateRequest() + ) + self.assertEqual(station_pb2.Station_State.ON, reply.result.station_state) + + async def test_soft_station_reset(self): + reply: station_pb2.StationStateReply = self._control_endpoint.GetStationState( + station_pb2.GetStationStateRequest() + ) + self.assertEqual(station_pb2.Station_State.ON, reply.result.station_state) + replies = await self._control_endpoint.SoftStationReset( + station_pb2.SoftStationResetRequest() + ) + self.assertListEqual( + [ + station_pb2.Station_State.STANDBY, + station_pb2.Station_State.ON, + ], + [r.result.station_state for r in replies], + ) + reply: station_pb2.StationStateReply = self._control_endpoint.GetStationState( + station_pb2.GetStationStateRequest() + ) + self.assertEqual(station_pb2.Station_State.ON, reply.result.station_state) + + async def test_hard_station_reset(self): + reply: station_pb2.StationStateReply = self._control_endpoint.GetStationState( + station_pb2.GetStationStateRequest() + ) + self.assertEqual(station_pb2.Station_State.ON, reply.result.station_state) + replies = await self._control_endpoint.HardStationReset( + station_pb2.HardStationResetRequest() + ) + self.assertListEqual( + [ + station_pb2.Station_State.STANDBY, + station_pb2.Station_State.HIBERNATE, + station_pb2.Station_State.STANDBY, + station_pb2.Station_State.ON, + ], + [r.result.station_state for r in replies], + ) + reply: station_pb2.StationStateReply = self._control_endpoint.GetStationState( + station_pb2.GetStationStateRequest() + ) + self.assertEqual(station_pb2.Station_State.ON, reply.result.station_state) diff --git a/pyproject.toml b/pyproject.toml index 832377cfe..3db43c9eb 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -35,6 +35,7 @@ ignore = ["E203"] addopts = "--forked" markers = [ "timeout", + "parametrize", ] [tool.tox] diff --git a/requirements.txt b/requirements.txt index f7cfdf5c3..34e1a817b 100644 --- a/requirements.txt +++ b/requirements.txt @@ -30,4 +30,5 @@ grpcio-tools # Apache 2 parse # MIT mergedeep # MIT getmac # MIT -python-dateutil # remove when on python 3.12 \ No newline at end of file +bidict # MPL 2 +python-dateutil # remove when on python 3.12 diff --git a/sbin/install-hooks/submodule-and-lfs.sh b/sbin/install-hooks/submodule-and-lfs.sh index afe39e2e3..7b7e90dba 100644 --- a/sbin/install-hooks/submodule-and-lfs.sh +++ b/sbin/install-hooks/submodule-and-lfs.sh @@ -1,10 +1,13 @@ #!/bin/bash -# Copyright (C) 2024 ASTRON (Netherlands Institute for Radio Astronomy) +# +# Copyright (C) 2025 ASTRON (Netherlands Institute for Radio Astronomy) # SPDX-License-Identifier: Apache-2.0 +# if [ ! -f "setup.sh" ]; then echo "submodule-and-lfs.sh must be executed with repository root as working directory!" exit 1 fi +mkdir -p .git/hooks/ cp bin/hooks/* .git/hooks/ diff --git a/tangostationcontrol/VERSION b/tangostationcontrol/VERSION index e85205d87..5c4503b70 100644 --- a/tangostationcontrol/VERSION +++ b/tangostationcontrol/VERSION @@ -1 +1 @@ -0.48.2 +0.49.0 diff --git a/tangostationcontrol/configuration/observation_field_settings.py b/tangostationcontrol/configuration/observation_field_settings.py index 91f693d01..0261355e7 100644 --- a/tangostationcontrol/configuration/observation_field_settings.py +++ b/tangostationcontrol/configuration/observation_field_settings.py @@ -12,6 +12,7 @@ from tangostationcontrol.configuration.sap import Sap from tangostationcontrol.configuration.sst import SST from tangostationcontrol.configuration.xst import XST + class ObservationFieldSettings(_ConfigurationBase): def __init__( self, diff --git a/tangostationcontrol/devices/base_device_classes/mapper.py b/tangostationcontrol/devices/base_device_classes/mapper.py index e7f6bc16f..6c984cb17 100644 --- a/tangostationcontrol/devices/base_device_classes/mapper.py +++ b/tangostationcontrol/devices/base_device_classes/mapper.py @@ -717,9 +717,9 @@ class RecvDeviceWalker: if recv <= 0: continue - recv_ant_masks[recv - 1][rcu_input // N_rcu_inp][ - rcu_input % N_rcu_inp - ] = True + recv_ant_masks[recv - 1][rcu_input // N_rcu_inp][rcu_input % N_rcu_inp] = ( + True + ) return recv_ant_masks diff --git a/tangostationcontrol/rpc/common.py b/tangostationcontrol/rpc/common.py index d99fa7f92..d15ceaf69 100644 --- a/tangostationcontrol/rpc/common.py +++ b/tangostationcontrol/rpc/common.py @@ -1,6 +1,7 @@ -# Copyright (C) 2025 ASTRON (Netherlands Institute for Radio Astronomy) -# SPDX-License-Identifier: Apache-2.0 +# Copyright (C) 2025 ASTRON (Netherlands Institute for Radio Astronomy) +# SPDX-License-Identifier: Apache-2.0 +import inspect import logging from functools import wraps @@ -38,24 +39,46 @@ def call_exception_metrics( metric_class=Counter, ) - @wraps(func) - def inner(*args, **kwargs): - try: - logger.info(f"gRPC function called: {func.__name__}") + if inspect.isasyncgenfunction(func): - call_count_metric.get_metric().inc() + @wraps(func) + async def inner(*args, **kwargs): + try: + logger.info(f"gRPC function called: {func.__name__}") - return func(*args, **kwargs) - except Exception as e: - exception_count_metric.get_metric().inc() + call_count_metric.get_metric().inc() + result = func(*args, **kwargs) + async for r in result: + yield r + except Exception as e: + exception_count_metric.get_metric().inc() - logger.exception( - f"gRPC function failed: {func.__name__} raised {e.__class__.__name__}: {e}" - ) + logger.exception( + f"gRPC function failed: {func.__name__} raised {e.__class__.__name__}: {e}" + ) - raise + raise - return inner + return inner + else: + + @wraps(func) + def inner(*args, **kwargs): + try: + logger.info(f"gRPC function called: {func.__name__}") + + call_count_metric.get_metric().inc() + return func(*args, **kwargs) + except Exception as e: + exception_count_metric.get_metric().inc() + + logger.exception( + f"gRPC function failed: {func.__name__} raised {e.__class__.__name__}: {e}" + ) + + raise + + return inner return wrapper diff --git a/tangostationcontrol/rpc/proxy/antennadeviceproxyfactory.py b/tangostationcontrol/rpc/proxy/antennadeviceproxyfactory.py index e7caa6a77..914b8fde2 100644 --- a/tangostationcontrol/rpc/proxy/antennadeviceproxyfactory.py +++ b/tangostationcontrol/rpc/proxy/antennadeviceproxyfactory.py @@ -8,7 +8,6 @@ logger = logging.getLogger() class AntennaDeviceProxyFactory: - @staticmethod def create_device_proxy_for_antennafield( antennafield_name: str, write_access: bool = False diff --git a/tangostationcontrol/rpc/server.py b/tangostationcontrol/rpc/server.py index 1647ba64b..f3343d9a9 100644 --- a/tangostationcontrol/rpc/server.py +++ b/tangostationcontrol/rpc/server.py @@ -1,7 +1,8 @@ -# Copyright (C) 2024 ASTRON (Netherlands Institute for Radio Astronomy) +# Copyright (C) 2025 ASTRON (Netherlands Institute for Radio Astronomy) # SPDX-License-Identifier: Apache-2.0 import argparse +import asyncio from concurrent import futures import logging import sys @@ -10,6 +11,8 @@ import grpc from grpc_reflection.v1alpha import reflection from lofar_sid.interface.stationcontrol import observation_pb2 from lofar_sid.interface.stationcontrol import observation_pb2_grpc +from lofar_sid.interface.stationcontrol import station_pb2 +from lofar_sid.interface.stationcontrol import station_pb2_grpc from lofar_sid.interface.stationcontrol import statistics_pb2 from lofar_sid.interface.stationcontrol import statistics_pb2_grpc from lofar_sid.interface.stationcontrol import antennafield_pb2 @@ -18,6 +21,7 @@ from lofar_sid.interface.stationcontrol import antenna_pb2 from lofar_sid.interface.stationcontrol import antenna_pb2_grpc from tangostationcontrol.rpc.observation import Observation +from tangostationcontrol.rpc.station import Station from tangostationcontrol.rpc.statistics import Statistics from tangostationcontrol.rpc.messagehandler import MultiEndpointZMQMessageHandler from tangostationcontrol.common.lofar_logging import configure_logger @@ -35,13 +39,14 @@ class Server: # Initialise gRPC server logger.info("Initialising grpc server") - self.server = grpc.server(futures.ThreadPoolExecutor(max_workers=10)) + self.server = grpc.aio.server(futures.ThreadPoolExecutor(max_workers=10)) observation_pb2_grpc.add_ObservationServicer_to_server( Observation(), self.server ) antennafield_pb2_grpc.add_AntennafieldServicer_to_server( Antennafield(), self.server ) + station_pb2_grpc.add_StationServicer_to_server(Station(), self.server) statistics_pb2_grpc.add_StatisticsServicer_to_server( self.statistics_servicer, self.server ) @@ -51,6 +56,7 @@ class Server: observation_pb2.DESCRIPTOR.services_by_name["Observation"].full_name, antennafield_pb2.DESCRIPTOR.services_by_name["Antennafield"].full_name, antenna_pb2.DESCRIPTOR.services_by_name["Antenna"].full_name, + station_pb2.DESCRIPTOR.services_by_name["Station"].full_name, statistics_pb2.DESCRIPTOR.services_by_name["Statistics"].full_name, reflection.SERVICE_NAME, # reflection is required by innius-gpc-datasource ) @@ -62,14 +68,14 @@ class Server: def handle_statistics_message(self, topic, timestamp, message): self.statistics_servicer.handle_statistics_message(topic, timestamp, message) - def run(self): - self.server.start() + async def run(self): + await self.server.start() logger.info(f"Server running on port {self.port}") - self.server.wait_for_termination() + await self.server.wait_for_termination() - def stop(self): + async def stop(self): logger.info("Server stopping.") - self.server.stop(grace=1.0) + await self.server.stop(grace=1.0) logger.info("Server stopped.") @@ -108,9 +114,9 @@ def _create_parser(): return parser -def main(argv=None): +async def async_main(argv): parser = _create_parser() - args = parser.parse_args(argv or sys.argv[1:]) + args = parser.parse_args(argv) # Initialise simple subsystems configure_logger() @@ -132,7 +138,11 @@ def main(argv=None): last_message_cache.add_receiver(endpoint, [""]) # Serve indefinitely - server.run() + await server.run() + + +def main(argv=None): + asyncio.run(async_main(argv or sys.argv[1:])) if __name__ == "__main__": diff --git a/tangostationcontrol/rpc/station.py b/tangostationcontrol/rpc/station.py new file mode 100644 index 000000000..1a3a4ad5e --- /dev/null +++ b/tangostationcontrol/rpc/station.py @@ -0,0 +1,165 @@ +# Copyright (C) 2025 ASTRON (Netherlands Institute for Radio Astronomy) +# SPDX-License-Identifier: Apache-2.0 +import asyncio +import logging + +from bidict import bidict +from lofar_sid.interface.stationcontrol import station_pb2 +from lofar_sid.interface.stationcontrol import station_pb2_grpc + +from tangostationcontrol.states.station_state_enum import StationStateEnum +from tangostationcontrol.common.proxies.proxy import create_device_proxy +from tangostationcontrol.rpc.common import ( + call_exception_metrics, + reply_on_exception, +) + +logger = logging.getLogger() + + +class Station(station_pb2_grpc.StationServicer): + TRANSITION_SLEEP_S = 1 + STATE_MAP = bidict( + { + StationStateEnum.OFF: station_pb2.Station_State.OFF, + StationStateEnum.HIBERNATE: station_pb2.Station_State.HIBERNATE, + StationStateEnum.STANDBY: station_pb2.Station_State.STANDBY, + StationStateEnum.ON: station_pb2.Station_State.ON, + } + ) + VALID_PREDECESSOR_STATES = { + station_pb2.Station_State.OFF: [StationStateEnum.HIBERNATE], + station_pb2.Station_State.HIBERNATE: [ + StationStateEnum.OFF, + StationStateEnum.STANDBY, + ], + station_pb2.Station_State.STANDBY: [ + StationStateEnum.HIBERNATE, + StationStateEnum.ON, + ], + station_pb2.Station_State.ON: [StationStateEnum.STANDBY], + } + DESIRED_PREDECESSOR_STATE = { + station_pb2.Station_State.OFF: station_pb2.Station_State.HIBERNATE, + station_pb2.Station_State.HIBERNATE: station_pb2.Station_State.STANDBY, + station_pb2.Station_State.STANDBY: station_pb2.Station_State.HIBERNATE, + station_pb2.Station_State.ON: station_pb2.Station_State.STANDBY, + } + + @reply_on_exception(station_pb2.StationStateReply) + @call_exception_metrics("Station") + def GetStationState(self, request: station_pb2.GetStationStateRequest, context): + station_manager = create_device_proxy("STAT/StationManager/1") + return station_pb2.StationStateReply( + result=station_pb2.StationStateResult( + station_state=self.STATE_MAP[station_manager.station_state_r] + ) + ) + + @call_exception_metrics("Station") + async def SetStationState( + self, request: station_pb2.SetStationStateRequest, context + ): + station_manager = create_device_proxy("STAT/StationManager/1") + + while station_manager.station_state_transitioning_R: + await asyncio.sleep(self.TRANSITION_SLEEP_S) + + if ( + station_manager.station_state_r + == self.STATE_MAP.inverse[request.station_state] + ): + yield self.GetStationState(station_pb2.GetStationStateRequest(), context) + return + + if ( + station_manager.station_state_r == StationStateEnum.OFF + and request.station_state == station_pb2.Station_State.HIBERNATE + ): + """ + OFF -> HIBERNATE needs an additional HIBERNATE -> STANDBY -> HIBERNATE sequence, + since OFF usually means that station control was restarted and the actual hardware state is unknown. + If we transition to hibernate first, the next step will transition to standby and back to hibernate, + since hibernate is not a valid predecessor state. + """ + station_manager.set_timeout_millis( + station_manager.hibernate_transition_timeout_RW * 1000 + ) + station_manager.station_hibernate() + while station_manager.station_state_transitioning_R: + await asyncio.sleep(self.TRANSITION_SLEEP_S) + yield self.GetStationState(station_pb2.GetStationStateRequest(), context) + + if ( + station_manager.station_state_r + not in self.VALID_PREDECESSOR_STATES[request.station_state] + ): + async for transition in self.SetStationState( + station_pb2.SetStationStateRequest( + station_state=self.DESIRED_PREDECESSOR_STATE[request.station_state] + ), + context, + ): + yield transition + + match request.station_state: + case station_pb2.Station_State.OFF: + station_manager.station_off() + case station_pb2.Station_State.HIBERNATE: + station_manager.set_timeout_millis( + station_manager.hibernate_transition_timeout_RW * 1000 + ) + station_manager.station_hibernate() + case station_pb2.Station_State.STANDBY: + station_manager.set_timeout_millis( + station_manager.standby_transition_timeout_RW * 1000 + ) + station_manager.station_standby() + case station_pb2.Station_State.ON: + station_manager.set_timeout_millis( + station_manager.on_transition_timeout_RW * 1000 + ) + station_manager.station_on() + + while station_manager.station_state_transitioning_R: + await asyncio.sleep(self.TRANSITION_SLEEP_S) + + yield self.GetStationState(station_pb2.GetStationStateRequest(), context) + + @call_exception_metrics("Station") + async def SoftStationReset( + self, request: station_pb2.SoftStationResetRequest, context + ): + async for transition in self.SetStationState( + station_pb2.SetStationStateRequest( + station_state=station_pb2.Station_State.STANDBY + ), + context, + ): + yield transition + async for transition in self.SetStationState( + station_pb2.SetStationStateRequest( + station_state=station_pb2.Station_State.ON + ), + context, + ): + yield transition + + @call_exception_metrics("Station") + async def HardStationReset( + self, request: station_pb2.SoftStationResetRequest, context + ): + async for transition in self.SetStationState( + station_pb2.SetStationStateRequest( + station_state=station_pb2.Station_State.HIBERNATE + ), + context, + ): + yield transition + async for transition in self.SetStationState( + station_pb2.SetStationStateRequest( + station_state=station_pb2.Station_State.ON + ), + context, + ): + yield transition diff --git a/tests/requirements.txt b/tests/requirements.txt index a3026ac04..8aca66534 100644 --- a/tests/requirements.txt +++ b/tests/requirements.txt @@ -26,3 +26,4 @@ pytest>=7.3.0 # MIT pytest-forked>=1.6.0 # MIT pytest-cov >= 3.0.0 # MIT pytest-timeout # MIT +pytest-asyncio # Apache-2.0 diff --git a/tests/rpc/test_antenna.py b/tests/rpc/test_antenna.py index b08c8f95a..b14cdda62 100644 --- a/tests/rpc/test_antenna.py +++ b/tests/rpc/test_antenna.py @@ -9,7 +9,9 @@ from lofar_sid.interface.stationcontrol.antenna_pb2 import ( ) from tangostationcontrol.rpc.antenna import Antenna, AntennaNotFoundException -from tangostationcontrol.rpc.proxy.antennadeviceproxyfactory import AntennaDeviceProxyFactory +from tangostationcontrol.rpc.proxy.antennadeviceproxyfactory import ( + AntennaDeviceProxyFactory, +) from tests import base @@ -24,7 +26,9 @@ class TestAntenna(base.TestCase): ] return mock_db - @patch("tangostationcontrol.rpc.proxy.antennadeviceproxyfactory.create_device_proxy") + @patch( + "tangostationcontrol.rpc.proxy.antennadeviceproxyfactory.create_device_proxy" + ) @patch("tangostationcontrol.rpc.proxy.antennadeviceproxyfactory.tango.Database") def test_create_antenna_device_proxy_success( self, mock_tango_database, mock_create_device_proxy @@ -35,7 +39,7 @@ class TestAntenna(base.TestCase): self.mock_tango_db_response(mock_tango_database) - result = AntennaDeviceProxyFactory.create_device_proxy_for_antennafield( + result = AntennaDeviceProxyFactory.create_device_proxy_for_antennafield( identifier.antennafield_name, write_access=True ) @@ -45,8 +49,12 @@ class TestAntenna(base.TestCase): ) self.assertEqual(result, mock_device) - @patch("tangostationcontrol.rpc.proxy.antennadeviceproxyfactory.create_device_proxy") - @patch("tangostationcontrol.rpc.proxy.antennadeviceproxyfactory.tango.Database") # Mock tango Database + @patch( + "tangostationcontrol.rpc.proxy.antennadeviceproxyfactory.create_device_proxy" + ) + @patch( + "tangostationcontrol.rpc.proxy.antennadeviceproxyfactory.tango.Database" + ) # Mock tango Database def test_create_antenna_device_proxy_failure( self, mock_tango_database, mock_create_device_proxy ): @@ -56,7 +64,9 @@ class TestAntenna(base.TestCase): self.mock_tango_db_response(mock_tango_database) with self.assertRaises(IOError): - AntennaDeviceProxyFactory.create_device_proxy_for_antennafield(identifier.antennafield_name) + AntennaDeviceProxyFactory.create_device_proxy_for_antennafield( + identifier.antennafield_name + ) def test_get_antenna_index_found(self): # Arrange @@ -103,10 +113,14 @@ class TestAntenna(base.TestCase): self.assertEqual(reply.result.antenna_status, True) self.assertEqual(reply.result.antenna_use, 1) - #@patch("tangostationcontrol.rpc.antenna.create_device_proxy") - @patch("tangostationcontrol.rpc.proxy.antennadeviceproxyfactory.create_device_proxy") + # @patch("tangostationcontrol.rpc.antenna.create_device_proxy") + @patch( + "tangostationcontrol.rpc.proxy.antennadeviceproxyfactory.create_device_proxy" + ) @patch("tangostationcontrol.rpc.proxy.antennadeviceproxyfactory.tango.Database") - def test_set_antenna_status(self, mock_tango_database, mock_create_device_proxy): # ,mock_create_device_proxy_antenna + def test_set_antenna_status( + self, mock_tango_database, mock_create_device_proxy + ): # ,mock_create_device_proxy_antenna # Arrange self.mock_tango_db_response(mock_tango_database) @@ -129,19 +143,25 @@ class TestAntenna(base.TestCase): {"Antenna_Status": [1, 1]} ) - - @patch("tangostationcontrol.rpc.proxy.antennadeviceproxyfactory.create_device_proxy") + @patch( + "tangostationcontrol.rpc.proxy.antennadeviceproxyfactory.create_device_proxy" + ) @patch("tangostationcontrol.rpc.proxy.antennadeviceproxyfactory.tango.Database") - - def test_set_antenna_use(self, mock_tango_database, mock_create_device_proxy_factory ): # ,mock_create_device_proxy_antenna + def test_set_antenna_use( + self, mock_tango_database, mock_create_device_proxy_factory + ): # ,mock_create_device_proxy_antenna self.mock_tango_db_response(mock_tango_database) - - mock_create_device_proxy_factory.return_value.read_attribute.return_value.value = [1, 0] + mock_create_device_proxy_factory.return_value.read_attribute.return_value.value = [ + 1, + 0, + ] mock_create_device_proxy_factory.return_value.Antenna_Status_R = [1, 0] mock_create_device_proxy_factory.return_value.Antenna_Use_R = [1, 0] - mock_create_device_proxy_factory.return_value.Antenna_Names_R = ["Antenna0", "Antenna1"] - + mock_create_device_proxy_factory.return_value.Antenna_Names_R = [ + "Antenna0", + "Antenna1", + ] request = SetAntennaUseRequest( antenna_use=1, diff --git a/tests/rpc/test_server.py b/tests/rpc/test_server.py deleted file mode 100644 index 35d4ad3d5..000000000 --- a/tests/rpc/test_server.py +++ /dev/null @@ -1,52 +0,0 @@ -# Copyright (C) 2025 ASTRON (Netherlands Institute for Radio Astronomy) -# SPDX-License-Identifier: Apache-2.0 - -from threading import Thread - -import grpc -from grpc_reflection.v1alpha.proto_reflection_descriptor_database import ( - ProtoReflectionDescriptorDatabase, -) - -from lofar_sid.interface.stationcontrol import antenna_pb2 -from lofar_sid.interface.stationcontrol import antenna_pb2_grpc -from tangostationcontrol.rpc.server import Server - -from tests import base - - -class TestServer(base.TestCase): - def setUp(self): - # Start gRPC server in a separate thread - self.server = Server(["LBA"], port=0) - self.server_thread = Thread(target=self.server.run) - self.server_thread.start() - - # Cleanup in the correct order (LIFO) - self.addCleanup(self.server_thread.join) - self.addCleanup(self.server.stop) - - def test_api(self): - """Check whether we actually expose the expected API.""" - - with grpc.insecure_channel(f"localhost:{self.server.port}") as channel: - reflection_db = ProtoReflectionDescriptorDatabase(channel) - services = reflection_db.get_services() - - self.assertIn("Observation", services) - self.assertIn("Antennafield", services) - self.assertIn("Statistics", services) - - def test_call(self): - """Test a basic gRPC call to the server.""" - - with grpc.insecure_channel(f"localhost:{self.server.port}") as channel: - stub = antenna_pb2_grpc.AntennaStub(channel) - - identifier = antenna_pb2.Identifier( - antennafield_name="lba", - antenna_name="LBA00", - ) - _ = stub.GetAntenna( - antenna_pb2.GetAntennaRequest(identifier=identifier) - ) diff --git a/tests/rpc/test_station.py b/tests/rpc/test_station.py new file mode 100644 index 000000000..5742ab7e0 --- /dev/null +++ b/tests/rpc/test_station.py @@ -0,0 +1,325 @@ +# Copyright (C) 2025 ASTRON (Netherlands Institute for Radio Astronomy) +# SPDX-License-Identifier: Apache-2.0 + +from unittest.mock import patch + +import pytest +from lofar_sid.interface.stationcontrol import station_pb2 + +from tangostationcontrol.states.station_state_enum import StationStateEnum +from tangostationcontrol.rpc.station import Station + + +class TestStation: + @pytest.mark.parametrize( + "station_state,expected_state", + [ + (StationStateEnum.OFF, station_pb2.Station_State.OFF), + (StationStateEnum.HIBERNATE, station_pb2.Station_State.HIBERNATE), + (StationStateEnum.STANDBY, station_pb2.Station_State.STANDBY), + (StationStateEnum.ON, station_pb2.Station_State.ON), + ], + ) + @patch("tangostationcontrol.rpc.station.create_device_proxy") + def test_get_state(self, m_create_device_proxy, station_state, expected_state): + sut = Station() + + m_create_device_proxy.return_value.station_state_r = station_state + + # request/response + request = station_pb2.GetStationStateRequest() + reply = sut.GetStationState(request, None) + + # validate output + assert reply.result.station_state == expected_state + + @pytest.mark.asyncio + @pytest.mark.parametrize( + "desired_state,current_state,expected_transitions", + [ + ( + station_pb2.Station_State.OFF, + StationStateEnum.OFF, + [station_pb2.Station_State.OFF], + ), + ( + station_pb2.Station_State.OFF, + StationStateEnum.HIBERNATE, + [station_pb2.Station_State.OFF], + ), + ( + station_pb2.Station_State.OFF, + StationStateEnum.STANDBY, + [station_pb2.Station_State.HIBERNATE, station_pb2.Station_State.OFF], + ), + ( + station_pb2.Station_State.OFF, + StationStateEnum.ON, + [ + station_pb2.Station_State.STANDBY, + station_pb2.Station_State.HIBERNATE, + station_pb2.Station_State.OFF, + ], + ), + ( + station_pb2.Station_State.HIBERNATE, + StationStateEnum.OFF, + [ + station_pb2.Station_State.HIBERNATE, + StationStateEnum.STANDBY, + station_pb2.Station_State.HIBERNATE, + ], + ), + ( + station_pb2.Station_State.HIBERNATE, + StationStateEnum.HIBERNATE, + [station_pb2.Station_State.HIBERNATE], + ), + ( + station_pb2.Station_State.HIBERNATE, + StationStateEnum.STANDBY, + [station_pb2.Station_State.HIBERNATE], + ), + ( + station_pb2.Station_State.HIBERNATE, + StationStateEnum.ON, + [ + station_pb2.Station_State.STANDBY, + station_pb2.Station_State.HIBERNATE, + ], + ), + ( + station_pb2.Station_State.STANDBY, + StationStateEnum.OFF, + [ + station_pb2.Station_State.HIBERNATE, + StationStateEnum.STANDBY, + station_pb2.Station_State.HIBERNATE, + station_pb2.Station_State.STANDBY, + ], + ), + ( + station_pb2.Station_State.STANDBY, + StationStateEnum.HIBERNATE, + [station_pb2.Station_State.STANDBY], + ), + ( + station_pb2.Station_State.STANDBY, + StationStateEnum.STANDBY, + [station_pb2.Station_State.STANDBY], + ), + ( + station_pb2.Station_State.STANDBY, + StationStateEnum.ON, + [station_pb2.Station_State.STANDBY], + ), + ( + station_pb2.Station_State.ON, + StationStateEnum.OFF, + [ + station_pb2.Station_State.HIBERNATE, + StationStateEnum.STANDBY, + station_pb2.Station_State.HIBERNATE, + station_pb2.Station_State.STANDBY, + station_pb2.Station_State.ON, + ], + ), + ( + station_pb2.Station_State.ON, + StationStateEnum.HIBERNATE, + [station_pb2.Station_State.STANDBY, station_pb2.Station_State.ON], + ), + ( + station_pb2.Station_State.ON, + StationStateEnum.STANDBY, + [station_pb2.Station_State.ON], + ), + ( + station_pb2.Station_State.ON, + StationStateEnum.ON, + [station_pb2.Station_State.ON], + ), + ], + ) + @patch("tangostationcontrol.rpc.station.create_device_proxy", autospec=True) + async def test_set_state( + self, m_create_device_proxy, desired_state, current_state, expected_transitions + ): + sut = Station() + + def transition_patch(wanted): + def wrp(): + m_create_device_proxy.return_value.station_state_r = wanted + + return wrp + + m_create_device_proxy.return_value.station_state_r = current_state + m_create_device_proxy.return_value.station_state_transitioning_R = False + m_create_device_proxy.return_value.station_off = transition_patch( + StationStateEnum.OFF + ) + m_create_device_proxy.return_value.station_hibernate = transition_patch( + StationStateEnum.HIBERNATE + ) + m_create_device_proxy.return_value.station_standby = transition_patch( + StationStateEnum.STANDBY + ) + m_create_device_proxy.return_value.station_on = transition_patch( + StationStateEnum.ON + ) + + # request/response + request = station_pb2.SetStationStateRequest(station_state=desired_state) + reply = sut.SetStationState(request, None) + + # validate output + transitions = [r.result.station_state async for r in reply] + assert transitions == expected_transitions + + @pytest.mark.asyncio + @pytest.mark.parametrize( + "current_state,expected_transitions", + [ + ( + StationStateEnum.ON, + [ + StationStateEnum.STANDBY, + station_pb2.Station_State.ON, + ], + ), + ( + StationStateEnum.STANDBY, + [ + StationStateEnum.STANDBY, + station_pb2.Station_State.ON, + ], + ), + ( + StationStateEnum.HIBERNATE, + [ + station_pb2.Station_State.STANDBY, + station_pb2.Station_State.ON, + ], + ), + ( + StationStateEnum.OFF, + [ + station_pb2.Station_State.HIBERNATE, + station_pb2.Station_State.STANDBY, + station_pb2.Station_State.HIBERNATE, + station_pb2.Station_State.STANDBY, + station_pb2.Station_State.ON, + ], + ), + ], + ) + @patch("tangostationcontrol.rpc.station.create_device_proxy", autospec=True) + async def test_soft_reset( + self, m_create_device_proxy, current_state, expected_transitions + ): + sut = Station() + + def transition_patch(wanted): + def wrp(): + m_create_device_proxy.return_value.station_state_r = wanted + + return wrp + + m_create_device_proxy.return_value.station_state_r = current_state + m_create_device_proxy.return_value.station_state_transitioning_R = False + m_create_device_proxy.return_value.station_off = transition_patch( + StationStateEnum.OFF + ) + m_create_device_proxy.return_value.station_hibernate = transition_patch( + StationStateEnum.HIBERNATE + ) + m_create_device_proxy.return_value.station_standby = transition_patch( + StationStateEnum.STANDBY + ) + m_create_device_proxy.return_value.station_on = transition_patch( + StationStateEnum.ON + ) + + # request/response + request = station_pb2.SoftStationResetRequest() + reply = sut.SoftStationReset(request, None) + + # validate output + transitions = [r.result.station_state async for r in reply] + assert transitions == expected_transitions + + @pytest.mark.asyncio + @pytest.mark.parametrize( + "current_state,expected_transitions", + [ + ( + StationStateEnum.ON, + [ + StationStateEnum.STANDBY, + station_pb2.Station_State.HIBERNATE, + station_pb2.Station_State.STANDBY, + station_pb2.Station_State.ON, + ], + ), + ( + StationStateEnum.STANDBY, + [ + station_pb2.Station_State.HIBERNATE, + station_pb2.Station_State.STANDBY, + station_pb2.Station_State.ON, + ], + ), + ( + StationStateEnum.HIBERNATE, + [ + station_pb2.Station_State.HIBERNATE, + station_pb2.Station_State.STANDBY, + station_pb2.Station_State.ON, + ], + ), + ( + StationStateEnum.OFF, + [ + station_pb2.Station_State.HIBERNATE, + station_pb2.Station_State.STANDBY, + station_pb2.Station_State.HIBERNATE, + station_pb2.Station_State.STANDBY, + station_pb2.Station_State.ON, + ], + ), + ], + ) + @patch("tangostationcontrol.rpc.station.create_device_proxy", autospec=True) + async def test_hard_reset( + self, m_create_device_proxy, current_state, expected_transitions + ): + sut = Station() + + def transition_patch(wanted): + def wrp(): + m_create_device_proxy.return_value.station_state_r = wanted + + return wrp + + m_create_device_proxy.return_value.station_state_r = current_state + m_create_device_proxy.return_value.station_state_transitioning_R = False + m_create_device_proxy.return_value.station_off = transition_patch( + StationStateEnum.OFF + ) + m_create_device_proxy.return_value.station_hibernate = transition_patch( + StationStateEnum.HIBERNATE + ) + m_create_device_proxy.return_value.station_standby = transition_patch( + StationStateEnum.STANDBY + ) + m_create_device_proxy.return_value.station_on = transition_patch( + StationStateEnum.ON + ) + + # request/response + request = station_pb2.SoftStationResetRequest() + reply = sut.HardStationReset(request, None) + + # validate output + transitions = [r.result.station_state async for r in reply] + assert transitions == expected_transitions -- GitLab