#  Copyright (C) 2025 ASTRON (Netherlands Institute for Radio Astronomy)
#  SPDX-License-Identifier: Apache-2.0

import copy
from datetime import datetime, timedelta, timezone
from typing import Dict
from typing import List
from unittest import mock

from tango import DevFailed
from tangostationcontrol.observation import observation_controller as obs_module
from tangostationcontrol.test.dummy_observation_settings import (
    get_observation_settings_two_fields_core,
)

from tests import base


@mock.patch("tango.Util.instance")
class TestObservationController(base.TestCase):
    """Test Observation Controller main operations"""

    def setUp(self):
        super().setUp()
        # Replace the Observation class used by the controller with an
        # autospec'd mock; tests inspect/configure it via self.m_observation.
        proxy_patcher = self.proxy_patch(obs_module, "Observation", autospec=True)
        self.m_observation = proxy_patcher["mock"]

    def generate_observation(
        self, running: bool, antenna_fields: list[str] = None
    ) -> obs_module.Observation:
        """Create an independent mocked observation.

        :param running: value the mock's ``is_running()`` returns.
        :param antenna_fields: value of the mock's ``antenna_fields``
            attribute; ``["HBA"]`` is used when omitted or empty.
        :return: a deep copy of the configured observation mock.
        """
        if not antenna_fields:
            antenna_fields = ["HBA"]

        m_obs_field = self.m_observation(tango_domain="DMR", parameters={})
        m_obs_field.is_running.return_value = running
        m_obs_field.antenna_fields = antenna_fields
        # Calling the class mock hands back its shared ``return_value`` every
        # time; deep-copy it so each generated observation keeps its own
        # ``is_running``/``antenna_fields`` configuration.
        return copy.deepcopy(m_obs_field)

    @staticmethod
    def _settings_ending_tomorrow():
        """Return two-field core settings whose antenna fields all stop
        one day from now, expressed as a timezone-aware UTC ISO string."""
        settings = get_observation_settings_two_fields_core()
        stop_time = (datetime.now(timezone.utc) + timedelta(days=1)).isoformat()
        for antenna_field in settings.antenna_fields:
            antenna_field.stop_time = stop_time
        return settings

    def observation_setup(
        self,
        observations: Dict[int, obs_module.Observation],
    ) -> obs_module.ObservationController:
        """Build a controller pre-populated with ``observations`` by id."""
        sut = obs_module.ObservationController("DMR")
        # A fresh controller must not report any running observation.
        self.assertListEqual([], sut.running_observations)
        for obs_id, observation in observations.items():
            sut[obs_id] = observation
        return sut

    def observations_running(
        self,
        observations: Dict[int, obs_module.Observation],
        expected_result: List[int],
    ):
        """Assert ``running_observations`` equals ``expected_result``."""
        sut = self.observation_setup(observations)
        self.assertListEqual(expected_result, sut.running_observations)

    def test_observations_running(self, _m_tango_util):
        self.observations_running(
            observations={1: self.generate_observation(True)}, expected_result=[1]
        )

    def test_observations_not_running(self, _m_tango_util):
        self.observations_running({2: self.generate_observation(False)}, [])

    def test_observations_running_mix(self, _m_tango_util):
        self.observations_running(
            {
                1: self.generate_observation(True),
                2: self.generate_observation(False),
                3: self.generate_observation(True),
                5: self.generate_observation(False),
            },
            [1, 3],
        )

    def test_observations_running_exception(self, _m_tango_util):
        m_obs_field1 = self.m_observation(tango_domain="DMR", parameters={})
        m_obs_field2 = self.m_observation(tango_domain="DMR", parameters={})

        # Both names refer to the same shared mock instance, so this
        # side_effect applies to both: the first is_running() call raises,
        # the second returns True.
        m_obs_field1.is_running.side_effect = [DevFailed, True]

        sut = obs_module.ObservationController("DMR")
        sut[1] = m_obs_field1
        sut[2] = m_obs_field2

        self.assertListEqual([2], sut.running_observations)

    def observations_antenna_fields(
        self,
        observations: Dict[int, obs_module.Observation],
        expected_result: List[str],
    ):
        """Assert ``active_antenna_fields`` equals ``expected_result``."""
        sut = self.observation_setup(observations)
        self.assertListEqual(expected_result, sut.active_antenna_fields)

    def test_active_field_single(self, _m_tango_util):
        self.observations_antenna_fields(
            {1: self.generate_observation(True, ["HBA"])}, ["HBA"]
        )

    def test_active_field_mix_multi(self, _m_tango_util):
        self.observations_antenna_fields(
            {
                1: self.generate_observation(False, ["HBA"]),
                2: self.generate_observation(True, ["HBA0"]),
                3: self.generate_observation(True, ["LBA"]),
                5: self.generate_observation(False, ["LBA"]),
            },
            ["HBA0", "LBA"],
        )

    def test_active_field_exception(self, _m_tango_util):
        m_obs_field1 = self.m_observation(tango_domain="DMR", parameters={})
        m_obs_field1.antenna_fields = ["HBA"]
        m_obs_field2 = self.m_observation(tango_domain="DMR", parameters={})
        m_obs_field2.antenna_fields = ["LBA"]

        # Both names refer to the same shared mock instance, so this
        # side_effect applies to both: the first is_running() call raises,
        # the second returns True.
        m_obs_field1.is_running.side_effect = [DevFailed, True]

        sut = obs_module.ObservationController("DMR")
        sut[1] = m_obs_field1
        sut[2] = m_obs_field2

        self.assertListEqual(["LBA"], sut.active_antenna_fields)

    def test_add_observation(self, _m_tango_util):
        settings = self._settings_ending_tomorrow()

        sut = obs_module.ObservationController("DMR")

        self.m_observation.return_value.observation_id = 5

        sut.add_observation(settings)

        self.m_observation.assert_called_once_with(
            tango_domain="DMR",
            parameters=settings,
            start_callback=sut._start_callback,
            stop_callback=sut._internal_stop_callback,
        )
        self.m_observation.return_value.create_devices.assert_called_once()
        self.m_observation.return_value.initialise_observation.assert_called_once()
        self.m_observation.return_value.update.assert_called_once()

        self.assertEqual(
            sut[self.m_observation.return_value.observation_id],
            self.m_observation.return_value,
        )

    def test_stop_callback(self, _m_tango_util):
        """Test that the _stop_callback correctly cleans up observations"""

        settings = self._settings_ending_tomorrow()

        self.m_observation.return_value.observation_id = 5

        sut = obs_module.ObservationController("DMR")
        sut.add_observation(settings)

        self.assertEqual(1, len(sut))

        sut._internal_stop_callback(settings.antenna_fields[0].observation_id)

        self.assertEqual(0, len(sut))

    def test_add_observation_failed(self, _m_tango_util):
        settings = self._settings_ending_tomorrow()

        sut = obs_module.ObservationController("DMR")

        self.m_observation.return_value.observation_id = 5
        self.m_observation.return_value.create_devices.side_effect = [DevFailed]

        self.assertRaises(RuntimeError, sut.add_observation, settings)

        # A failed creation must clean up whatever devices were made.
        self.m_observation.return_value.destroy_devices.assert_called_once()

    def test_start_observation(self, _m_tango_util):
        sut = obs_module.ObservationController("DMR")

        sut[5] = mock.Mock()

        sut.start_observation(5)

        sut[5].start.assert_called_once()

    def test_start_observation_key_error(self, _m_tango_util):
        sut = obs_module.ObservationController("DMR")

        self.assertRaises(KeyError, sut.start_observation, "12554812435")

    def test_stop_observation(self, _m_tango_util):
        sut = obs_module.ObservationController("DMR")

        m_observation = mock.Mock()
        sut[5] = m_observation

        self.assertIn(5, sut)

        sut.stop_observation_now(5)

        m_observation.stop.assert_called_once()

        self.assertNotIn(5, sut)

    def test_stop_observation_key_error(self, _m_tango_util):
        sut = obs_module.ObservationController("DMR")

        self.assertRaises(KeyError, sut.stop_observation_now, "12554812435")

    def test_stop_all_observations_now_no_running(self, _m_tango_util):
        sut = obs_module.ObservationController("DMR")
        sut._destroy_all_observation_field_devices = mock.Mock()

        sut.stop_all_observations_now()

        sut._destroy_all_observation_field_devices.assert_called_once()

    def test_stop_all_observations_now_running(self, _m_tango_util):
        sut = obs_module.ObservationController("DMR")
        sut._destroy_all_observation_field_devices = mock.Mock()
        sut.stop_observation_now = mock.Mock()

        sut[5] = mock.Mock()
        sut[9] = mock.Mock()

        sut.stop_all_observations_now()

        sut._destroy_all_observation_field_devices.assert_called_once()
        calls = sut.stop_observation_now.call_args_list
        self.assertEqual(mock.call(5), calls[0])
        self.assertEqual(mock.call(9), calls[1])

    @mock.patch.object(obs_module, "Database")
    def test_destroy_all_observation_field_devices_errors(
        self, m_database, _m_tango_util
    ):
        """Test that all exported devices for the class are destroyed,
        even when every individual deletion raises"""
        devices = [mock.Mock(), mock.Mock()]
        m_database.return_value.get_device_exported_for_class.return_value = devices

        # Raise on every call; a one-element list would only raise Exception
        # once and then StopIteration by accident on the second call.
        m_database.return_value.delete_device.side_effect = Exception

        obs_module.ObservationController._destroy_all_observation_field_devices()

        self.assertListEqual(
            [mock.call(devices[0]), mock.call(devices[1])],
            m_database.return_value.delete_device.call_args_list,
        )