From e91e4b6bc33cded0d9c04514fdd55db060553f09 Mon Sep 17 00:00:00 2001
From: lukken <lukken@astron.nl>
Date: Mon, 12 Sep 2022 18:05:41 +0000
Subject: [PATCH] L2SS-877: Implement masked value merge for digital beam

---
 .../devices/lofar_device.py                   |  75 ++++++++++++-
 .../devices/sdp/digitalbeam.py                |  19 +++-
 .../devices/test_device_digitalbeam.py        |  62 ++++++++++-
 .../test/devices/test_digitalbeam_device.py   | 103 ++++++++++++++++++
 4 files changed, 250 insertions(+), 9 deletions(-)
 create mode 100644 tangostationcontrol/tangostationcontrol/test/devices/test_digitalbeam_device.py

diff --git a/tangostationcontrol/tangostationcontrol/devices/lofar_device.py b/tangostationcontrol/tangostationcontrol/devices/lofar_device.py
index c04033f34..baf2c9abc 100644
--- a/tangostationcontrol/tangostationcontrol/devices/lofar_device.py
+++ b/tangostationcontrol/tangostationcontrol/devices/lofar_device.py
@@ -11,14 +11,18 @@
 
 """
 
-# PyTango imports
-from tango.server import attribute, command, Device, DeviceMeta
-from tango import AttrWriteType, DevState, DebugIt, Attribute, DeviceProxy, AttrDataFormat, DevSource, DevDouble
+from collections.abc import Sequence
 import time
 import math
+from typing import List
+
 import numpy
 import textwrap
 
+# PyTango imports
+from tango.server import attribute, command, Device, DeviceMeta
+from tango import AttrWriteType, DevState, DebugIt, Attribute, DeviceProxy, AttrDataFormat, DevSource, DevDouble
+
 # Additional import
 from tangostationcontrol import __version__ as version
 from tangostationcontrol.clients.attribute_wrapper import attribute_wrapper
@@ -34,6 +38,14 @@ import logging
 logger = logging.getLogger()
 
 
+# TODO(Corne): Remove this in L2SS-940
+def sequence_not_str(obj):
+    """Separate sequences / collections from str, byte or bytearray"""
+
+    return (isinstance(obj, Sequence) or isinstance(obj, numpy.ndarray)) and not \
+        isinstance(obj, (str, bytes, bytearray))
+
+
 class lofar_device(Device, metaclass=DeviceMeta):
     """
 
@@ -92,6 +104,63 @@ class lofar_device(Device, metaclass=DeviceMeta):
 
         return self.get_state() in INITIALISED_STATES
 
+    # TODO(Corne): Actually implement this in L2SS-940
+    def atomic_read_modify_write_attribute(
+        self, values: List[any], proxy: DeviceProxy, attribute: str, sparse=None
+    ):
+        """Atomatically read-modify-write the attribute on the given proxy"""
+
+        current_values = proxy.read_attribute(attribute).value
+        logger.info("current_values")
+        logger.info(values)
+        self.merge_write(values, current_values, sparse)
+        # import pdb; pdb.set_trace()
+        proxy.write_attribute(values)
+
+    # TODO(Corne): Update docstring in L2SS-940
+    def merge_write(
+        self, merge_values: List[any], current_values: List[any], mask_or_sparse=None
+    ):
+        """Merge values as retrieved from :py:func:`~map_write` with current_values
+
+        This method will modify the contents of merge_values.
+
+        To be used by the :py:class:`~AntennaField` device to remove sparse fields
+        from mapped_values with recently retrieved current_values from RECV device.
+
+        :param merge_values: values as retrieved from :py:func:`~map_write`
+        :param current_values: values retrieved from RECV device on specific attribute
+        :param sparse: The value to identify sparse entries
+        """
+
+        if mask_or_sparse is not None and sequence_not_str(mask_or_sparse):
+            self._merge_write_mask(
+                merge_values, current_values, mask_or_sparse
+            )
+        else:
+            self._merge_write_delimiter(
+                merge_values, current_values, mask_or_sparse
+            )
+
+    def _merge_write_delimiter(
+        self, merge_values: List[any], current_values: List[any], sparse=None
+    ):
+        for idx, value in enumerate(merge_values):
+            if sequence_not_str(value):
+                self._merge_write_delimiter(merge_values[idx], current_values[idx], sparse)
+            elif value == sparse:
+                merge_values[idx] = current_values[idx]
+
+    def _merge_write_mask(
+        self, merge_values: List[any], current_values: List[any], mask: List[any]
+    ):
+        # import pdb; pdb.set_trace()
+        for idx, value in enumerate(merge_values):
+            if sequence_not_str(value):
+                self._merge_write_mask(merge_values[idx], current_values[idx], mask[idx])
+            elif not mask[idx]:
+                merge_values[idx] = current_values[idx]
+
     @log_exceptions()
     def init_device(self):
         """ Instantiates the device in the OFF state. """
diff --git a/tangostationcontrol/tangostationcontrol/devices/sdp/digitalbeam.py b/tangostationcontrol/tangostationcontrol/devices/sdp/digitalbeam.py
index 93c91ef79..376a98343 100644
--- a/tangostationcontrol/tangostationcontrol/devices/sdp/digitalbeam.py
+++ b/tangostationcontrol/tangostationcontrol/devices/sdp/digitalbeam.py
@@ -224,9 +224,6 @@ class DigitalBeam(beam_device):
         beam_weights = self.beamlet_proxy.calculate_bf_weights(fpga_delays.flatten())
         beam_weights = beam_weights.reshape((Beamlet.N_PN, Beamlet.A_PN * Beamlet.N_POL * Beamlet.N_BEAMLETS_CTRL))
 
-        # Filter out unwanted antennas (they get a weight of 0)
-        beam_weights *= self._map_inputs_on_polarised_inputs(self._input_select)
-
         return beam_weights
 
     @TimeIt()
@@ -234,8 +231,20 @@ class DigitalBeam(beam_device):
         """
         Uploads beam weights based on a given pointing direction 2D array (96 tiles x 3 parameters)
         """
-        # Write weights to SDP
-        self.beamlet_proxy.FPGA_bf_weights_xx_yy_RW = beam_weights
+
+        # import pdb; pdb.set_trace()
+
+        logger.info("beam weights")
+        logger.info(beam_weights)
+
+        logger.info("inputs")
+        logger.info(self._map_inputs_on_polarised_inputs(self._input_select))
+        self.atomic_read_modify_write_attribute(
+            beam_weights,
+            self.beamlet_proxy,
+            "FPGA_bf_weights_xx_yy_RW",
+            self._map_inputs_on_polarised_inputs(self._input_select)
+        )
 
         # Record where we now point to, now that we've updated the weights.
         # Only record pointings per beamlet, not which antennas took part
diff --git a/tangostationcontrol/tangostationcontrol/integration_test/default/devices/test_device_digitalbeam.py b/tangostationcontrol/tangostationcontrol/integration_test/default/devices/test_device_digitalbeam.py
index b06b60ce8..e15214ea4 100644
--- a/tangostationcontrol/tangostationcontrol/integration_test/default/devices/test_device_digitalbeam.py
+++ b/tangostationcontrol/tangostationcontrol/integration_test/default/devices/test_device_digitalbeam.py
@@ -75,10 +75,70 @@ class TestDeviceDigitalBeam(AbstractTestBases.TestDeviceBase):
         self.proxy.on()
 
         # Point to Zenith
-        self.proxy.set_pointing(numpy.array([["AZELGEO","0deg","90deg"]] * 488).flatten())
+        self.proxy.set_pointing(numpy.array([["AZELGEO", "0deg", "90deg"]] * 488).flatten())
 
         # beam weights should now be non-zero, we don't actually check their values for correctness
         self.assertNotEqual(0, sum(self.beamlet_proxy.FPGA_bf_weights_xx_yy_RW.flatten()))
+
+    def test_set_pointing_masked_enable(self):
+        """Verify that only selected inputs are written"""
+
+        self.setup_antennafield_proxy(self.antenna_qualities_ok, self.antenna_use_ok)
+        self.setup_sdp_proxy()
+        self.setup_recv_proxy()
+        # Setup beamlet configuration
+        self.beamlet_proxy.clock_RW = 200 * 1000000
+        self.beamlet_proxy.subband_select = list(range(488))
+
+        self.proxy.initialise()
+        self.proxy.Tracking_enabled_RW = False
+        self.proxy.on()
+
+        all_zeros = numpy.array([[0] * 5856] * 16)
+        self.beamlet_proxy.FPGA_bf_weights_xx_yy_RW = all_zeros
+
+        # Enable all inputs
+        self.proxy.input_select_RW = numpy.array([[True] * 488] * 96)
+
+        self.proxy.set_pointing(
+            numpy.array([["AZELGEO", "0deg", "90deg"]] * 488).flatten()
+        )
+
+        # Verify all zeros are replaced with other values for all inputs
+        self.assertTrue(numpy.any(numpy.not_equal(
+            all_zeros, self.beamlet_proxy.FPGA_bf_weights_xx_yy_RW
+        )))
+
+    def test_set_pointing_masked_disable(self):
+        """Verify that only diabled inputs are unchanged"""
+
+        self.setup_antennafield_proxy(self.antenna_qualities_ok, self.antenna_use_ok)
+        self.setup_sdp_proxy()
+        self.setup_recv_proxy()
+        # Setup beamlet configuration
+        self.beamlet_proxy.clock_RW = 200 * 1000000
+        self.beamlet_proxy.subband_select = list(range(488))
+
+        self.proxy.initialise()
+        self.proxy.Tracking_enabled_RW = False
+        self.proxy.on()
+
+        non_zeros = numpy.array([[16] * 5856] * 16)
+        self.beamlet_proxy.FPGA_bf_weights_xx_yy_RW = non_zeros
+
+        # Disable all inputs
+        # import pdb;
+        # pdb.set_trace()
+        self.proxy.input_select_RW = numpy.array([[False] * 488] * 96)
+
+        self.proxy.set_pointing(
+            numpy.array([["AZELGEO", "0deg", "90deg"]] * 488).flatten()
+        )
+
+        # Verify all zeros are replaced with other values for all inputs
+        numpy.testing.assert_equal(
+            non_zeros, self.beamlet_proxy.FPGA_bf_weights_xx_yy_RW
+        )
     
     def test_input_select_with_all_antennas_ok(self):
         """ Verify if input and antenna select are correctly calculated following Antennafield.Antenna_Usage_Mask """
diff --git a/tangostationcontrol/tangostationcontrol/test/devices/test_digitalbeam_device.py b/tangostationcontrol/tangostationcontrol/test/devices/test_digitalbeam_device.py
new file mode 100644
index 000000000..7bc2a5a7b
--- /dev/null
+++ b/tangostationcontrol/tangostationcontrol/test/devices/test_digitalbeam_device.py
@@ -0,0 +1,103 @@
+# -*- coding: utf-8 -*-
+#
+# This file is part of the LOFAR 2.0 Station Software
+#
+#
+#
+# Distributed under the terms of the APACHE license.
+# See LICENSE.txt for more info.
+
+import copy
+
+import numpy
+
+from tango.test_context import DeviceTestContext
+
+from tangostationcontrol.devices.sdp import digitalbeam
+
+from unittest import mock
+
+from tangostationcontrol.test.devices import device_base
+
+
+class TestDigitalBeamDevice(device_base.DeviceTestCase):
+
+    def setUp(self):
+        # DeviceTestCase setUp patches lofar_device DeviceProxy
+        super(TestDigitalBeamDevice, self).setUp()
+
+    @mock.patch.object(digitalbeam.DigitalBeam, "_wait_to_apply_weights")
+    @mock.patch.object(digitalbeam.DigitalBeam, "_compute_weights")
+    @mock.patch.object(digitalbeam, "DeviceProxy")
+    def test_apply_weights(self, m_proxy, m_compute, m_wait):
+        """Verify can overwrite digitalbeam data if input_selected"""
+
+        input_data = numpy.array([["AZELGEO", "0deg", "90deg"]] * 488).flatten()
+        current_data = numpy.array([[16384] * 5856] * 16)
+
+        m_proxy.return_value = mock.Mock(
+            read_attribute=mock.Mock(
+                return_value=mock.Mock(value=copy.copy(current_data))
+            ),
+            Antenna_Usage_Mask_R=numpy.array([0] * 96),
+            Antenna_Field_Reference_ITRF_R=mock.MagicMock(),
+            HBAT_reference_ITRF_R=numpy.array([[0] * 3] * 96)
+        )
+
+        new_data = numpy.array(
+            [[16384] * 2928 + [0] * 2928] * 16
+        )
+        m_compute.return_value = copy.copy(new_data)
+
+        with DeviceTestContext(
+            digitalbeam.DigitalBeam, process=False,
+        ) as proxy:
+            proxy.initialise()
+            proxy.Tracking_enabled_RW = False
+            proxy.input_select_RW = numpy.array([[False] * 488] * 96)
+
+            proxy.set_pointing(input_data)
+
+            # import pdb;
+            # pdb.set_trace()
+            numpy.testing.assert_equal(
+                m_proxy.return_value.write_attribute.call_args[0][0],
+                current_data
+            )
+
+    # @mock.patch.object(digitalbeam.DigitalBeam, "_wait_to_apply_weights")
+    # @mock.patch.object(digitalbeam.DigitalBeam, "_compute_weights")
+    # @mock.patch.object(digitalbeam, "DeviceProxy")
+    # def test_apply_weights(self, m_proxy, m_compute, m_wait):
+    #     """Verify can overwrite digitalbeam data if input_selected"""
+    #
+    #     input_data = numpy.array([["AZELGEO", "0deg", "90deg"]] * 488).flatten()
+    #     current_data = numpy.array([[16384] * 5856] * 16)
+    #
+    #     m_proxy.return_value = mock.Mock(
+    #         read_attribute=mock.Mock(
+    #             return_value=mock.Mock(value=current_data)
+    #         ),
+    #         Antenna_Usage_Mask_R=numpy.array([0] * 96),
+    #         Antenna_Field_Reference_ITRF_R=mock.MagicMock(),
+    #         HBAT_reference_ITRF_R=numpy.array([[0] * 3] * 96)
+    #     )
+    #
+    #     new_data = numpy.array(
+    #         [[16384] * 2928 + [0] * 2928] * 16
+    #     )
+    #     m_compute.return_value = copy.copy(new_data)
+    #
+    #     with DeviceTestContext(
+    #         digitalbeam.DigitalBeam, process=False,
+    #     ) as proxy:
+    #         proxy.initialise()
+    #         proxy.Tracking_enabled_RW = False
+    #         proxy.input_select_RW = numpy.array([[True] * 488] * 96)
+    #
+    #         proxy.set_pointing(input_data)
+    #
+    #         numpy.testing.assert_equal(
+    #             m_proxy.return_value.write_attribute.call_args[0][0],
+    #             new_data
+    #         )
-- 
GitLab