# Copyright (C) 2025 ASTRON (Netherlands Institute for Radio Astronomy)
# SPDX-License-Identifier: Apache-2.0

import copy
from datetime import datetime, timedelta, timezone
from typing import Dict
from typing import List
from typing import Optional
from unittest import mock

from tango import DevFailed

from tangostationcontrol.observation import observation_controller as obs_module
from tangostationcontrol.test.dummy_observation_settings import (
    get_observation_settings_two_fields_core,
)

from tests import base


def _future_stop_time() -> str:
    """Return an ISO-8601, timezone-aware (UTC) timestamp one day from now.

    All tests use the same aware timestamp format so none of them mixes
    naive and aware datetimes.
    """
    return (datetime.now(timezone.utc) + timedelta(days=1)).isoformat()


@mock.patch("tango.Util.instance")
class TestObservationController(base.TestCase):
    """Test Observation Controller main operations"""

    def generate_observation(
        self, running: bool, antenna_fields: Optional[List[str]] = None
    ) -> obs_module.Observation:
        """Return an independent mocked Observation.

        :param running: value the mock's ``is_running()`` returns
        :param antenna_fields: antenna fields assigned to the mock,
            ``["HBA"]`` when not given
        """
        if antenna_fields is None:
            antenna_fields = ["HBA"]

        m_obs_field = self.m_observation(tango_domain="DMR", parameters={})
        m_obs_field.is_running.return_value = running
        m_obs_field.antenna_fields = antenna_fields

        # Every call to the patched class returns the same return_value mock,
        # so take a snapshot to get a distinct object per observation.
        return copy.deepcopy(m_obs_field)

    def setUp(self):
        super().setUp()

        proxy_patcher = self.proxy_patch(obs_module, "Observation", autospec=True)
        self.m_observation = proxy_patcher["mock"]

    def observation_setup(
        self,
        observations: Dict[int, obs_module.Observation],
    ) -> obs_module.ObservationController:
        """Create a controller for domain "DMR" and register ``observations``.

        Also asserts that a freshly created controller reports no running
        observations.
        """
        sut = obs_module.ObservationController("DMR")

        self.assertListEqual([], sut.running_observations)

        for obs_id, observation in observations.items():
            sut[obs_id] = observation

        return sut

    def observations_running(
        self,
        observations: Dict[int, obs_module.Observation],
        expected_result: List[int],
    ):
        """Assert ``running_observations`` equals ``expected_result``."""
        sut = self.observation_setup(observations)

        self.assertListEqual(expected_result, sut.running_observations)

    def test_observations_running(self, _m_tango_util):
        self.observations_running(
            observations={1: self.generate_observation(True)}, expected_result=[1]
        )

    def test_observations_not_running(self, _m_tango_util):
        self.observations_running({2: self.generate_observation(False)}, [])

    def test_observations_running_mix(self, _m_tango_util):
        self.observations_running(
            {
                1: self.generate_observation(True),
                2: self.generate_observation(False),
                3: self.generate_observation(True),
                5: self.generate_observation(False),
            },
            [1, 3],
        )

    def test_observations_running_exception(self, _m_tango_util):
        """An observation whose is_running() raises DevFailed is skipped."""
        m_obs_field1 = self.m_observation(tango_domain="DMR", parameters={})
        m_obs_field2 = self.m_observation(tango_domain="DMR", parameters={})

        # Both names refer to the same return_value mock, so this side effect
        # is consumed across both: the first call raises, the second is True.
        m_obs_field1.is_running.side_effect = [DevFailed, True]

        sut = obs_module.ObservationController("DMR")
        sut[1] = m_obs_field1
        sut[2] = m_obs_field2

        self.assertListEqual([2], sut.running_observations)

    def observations_antenna_fields(
        self,
        observations: Dict[int, obs_module.Observation],
        expected_result: List[str],
    ):
        """Assert ``active_antenna_fields`` equals ``expected_result``."""
        sut = self.observation_setup(observations)

        self.assertListEqual(expected_result, sut.active_antenna_fields)

    def test_active_field_single(self, _m_tango_util):
        self.observations_antenna_fields(
            {1: self.generate_observation(True, ["HBA"])}, ["HBA"]
        )

    def test_active_field_mix_multi(self, _m_tango_util):
        self.observations_antenna_fields(
            {
                1: self.generate_observation(False, ["HBA"]),
                2: self.generate_observation(True, ["HBA0"]),
                3: self.generate_observation(True, ["LBA"]),
                5: self.generate_observation(False, ["LBA"]),
            },
            ["HBA0", "LBA"],
        )

    def test_active_field_exception(self, _m_tango_util):
        """Fields of an observation whose is_running() fails are excluded."""
        m_obs_field1 = self.m_observation(tango_domain="DMR", parameters={})
        m_obs_field1.antenna_fields = ["HBA"]
        m_obs_field2 = self.m_observation(tango_domain="DMR", parameters={})
        m_obs_field2.antenna_fields = ["LBA"]

        # Will influence both obs_field objects: they are the same
        # return_value mock, so the first call raises and the second is True.
        m_obs_field1.is_running.side_effect = [DevFailed, True]

        sut = obs_module.ObservationController("DMR")
        sut[1] = m_obs_field1
        sut[2] = m_obs_field2

        self.assertListEqual(["LBA"], sut.active_antenna_fields)

    def test_add_observation(self, _m_tango_util):
        settings = get_observation_settings_two_fields_core()
        for antenna_field in settings.antenna_fields:
            antenna_field.stop_time = _future_stop_time()

        sut = obs_module.ObservationController("DMR")
        self.m_observation.return_value.observation_id = 5

        sut.add_observation(settings)

        self.m_observation.assert_called_once_with(
            tango_domain="DMR",
            parameters=settings,
            start_callback=sut._start_callback,
            stop_callback=sut._internal_stop_callback,
        )
        self.m_observation.return_value.create_devices.assert_called_once()
        self.m_observation.return_value.initialise_observation.assert_called_once()
        self.m_observation.return_value.update.assert_called_once()

        self.assertEqual(
            sut[self.m_observation.return_value.observation_id],
            self.m_observation.return_value,
        )

    def test_stop_callback(self, _m_tango_util):
        """Test that the _stop_callback correctly cleans up observations"""
        settings = get_observation_settings_two_fields_core()
        for antenna_field in settings.antenna_fields:
            antenna_field.stop_time = _future_stop_time()

        self.m_observation.return_value.observation_id = 5

        sut = obs_module.ObservationController("DMR")
        sut.add_observation(settings)

        self.assertEqual(1, len(sut))

        sut._internal_stop_callback(settings.antenna_fields[0].observation_id)

        self.assertEqual(0, len(sut))

    def test_add_observation_failed(self, _m_tango_util):
        """A DevFailed during device creation surfaces as RuntimeError and
        triggers device cleanup."""
        settings = get_observation_settings_two_fields_core()
        for antenna_field in settings.antenna_fields:
            antenna_field.stop_time = _future_stop_time()

        sut = obs_module.ObservationController("DMR")
        self.m_observation.return_value.observation_id = 5
        self.m_observation.return_value.create_devices.side_effect = [DevFailed]

        self.assertRaises(RuntimeError, sut.add_observation, settings)

        self.m_observation.return_value.destroy_devices.assert_called_once()

    def test_start_observation(self, _m_tango_util):
        sut = obs_module.ObservationController("DMR")
        sut[5] = mock.Mock()

        sut.start_observation(5)

        sut[5].start.assert_called_once()

    def test_start_observation_key_error(self, _m_tango_util):
        sut = obs_module.ObservationController("DMR")

        self.assertRaises(KeyError, sut.start_observation, "12554812435")

    def test_stop_observation(self, _m_tango_util):
        sut = obs_module.ObservationController("DMR")
        m_observation = mock.Mock()
        sut[5] = m_observation

        self.assertTrue(5 in sut)

        sut.stop_observation_now(5)

        m_observation.stop.assert_called_once()
        self.assertFalse(5 in sut)

    def test_stop_observation_key_error(self, _m_tango_util):
        sut = obs_module.ObservationController("DMR")

        self.assertRaises(KeyError, sut.stop_observation_now, "12554812435")

    def test_stop_all_observations_now_no_running(self, _m_tango_util):
        sut = obs_module.ObservationController("DMR")
        sut._destroy_all_observation_field_devices = mock.Mock()

        sut.stop_all_observations_now()

        sut._destroy_all_observation_field_devices.assert_called_once()

    def test_stop_all_observations_now_running(self, _m_tango_util):
        sut = obs_module.ObservationController("DMR")
        sut._destroy_all_observation_field_devices = mock.Mock()
        sut.stop_observation_now = mock.Mock()
        sut[5] = mock.Mock()
        sut[9] = mock.Mock()

        sut.stop_all_observations_now()

        sut._destroy_all_observation_field_devices.assert_called_once()
        self.assertEqual(((5,),), sut.stop_observation_now.call_args_list[0])
        self.assertEqual(((9,),), sut.stop_observation_now.call_args_list[1])

    @mock.patch.object(obs_module, "Database")
    def test_destroy_all_observation_field_devices_errors(
        self, m_database, _m_tango_util
    ):
        """Test that all exported devices for the class are destroyed,
        even when deleting one of them raises."""
        devices = [mock.Mock(), mock.Mock()]
        m_database.return_value.get_device_exported_for_class.return_value = devices
        # First deletion fails, second succeeds. A one-element list would make
        # the second call raise StopIteration instead of returning normally.
        m_database.return_value.delete_device.side_effect = [Exception, None]

        obs_module.ObservationController._destroy_all_observation_field_devices()

        self.assertEqual(
            ((devices[0],),), m_database.return_value.delete_device.call_args_list[0]
        )
        self.assertEqual(
            ((devices[1],),), m_database.return_value.delete_device.call_args_list[1]
        )