From 7e0cc2d268c77d36c1a2e34fa9388da51215072f Mon Sep 17 00:00:00 2001
From: lukken <lukken@astron.nl>
Date: Thu, 22 Sep 2022 11:24:12 +0000
Subject: [PATCH] L2SS-876: Fix proxy.write_attribute by numpy array casting

---
 .../common/type_checking.py                   | 12 +++++-
 .../devices/antennafield.py                   | 12 ++++--
 .../test/devices/test_antennafield_device.py  | 39 ++++++++++++++++++-
 3 files changed, 55 insertions(+), 8 deletions(-)

diff --git a/tangostationcontrol/tangostationcontrol/common/type_checking.py b/tangostationcontrol/tangostationcontrol/common/type_checking.py
index 25ed79556..e896e708c 100644
--- a/tangostationcontrol/tangostationcontrol/common/type_checking.py
+++ b/tangostationcontrol/tangostationcontrol/common/type_checking.py
@@ -8,8 +8,16 @@ from collections.abc import Sequence
 import numpy
 
 
+def is_sequence(obj):
+    """Identify sequences / collections"""
+    return isinstance(obj, Sequence) or isinstance(obj, numpy.ndarray)
+
+
 def sequence_not_str(obj):
     """Separate sequences / collections from str, byte or bytearray"""
+    return is_sequence(obj) and not isinstance(obj, (str, bytes, bytearray))
+
 
-    return (isinstance(obj, Sequence) or isinstance(obj, numpy.ndarray)) and not \
-        isinstance(obj, (str, bytes, bytearray))
+def type_not_sequence(obj):
+    """Separate sequences / collections from types"""
+    return not is_sequence(obj) and isinstance(obj, type)
diff --git a/tangostationcontrol/tangostationcontrol/devices/antennafield.py b/tangostationcontrol/tangostationcontrol/devices/antennafield.py
index b31ccdca8..93e1a8c67 100644
--- a/tangostationcontrol/tangostationcontrol/devices/antennafield.py
+++ b/tangostationcontrol/tangostationcontrol/devices/antennafield.py
@@ -17,6 +17,7 @@ from tango.server import device_property, attribute, command
 
 # Additional import
 from tangostationcontrol.common.type_checking import sequence_not_str
+from tangostationcontrol.common.type_checking import type_not_sequence
 from tangostationcontrol.common.entrypoint import entry
 from tangostationcontrol.devices.lofar_device import lofar_device
 from tangostationcontrol.common.lofar_logging import device_logging_to_python, log_exceptions
@@ -49,7 +50,10 @@ class mapped_attribute(attribute):
         if access == AttrWriteType.READ_WRITE:
             @fault_on_error()
             def write_func_wrapper(device, value):
-                write_func = device.set_mapped_attribute(mapping_attribute, value)
+                cast_type = dtype
+                while not type_not_sequence(cast_type):
+                    cast_type = cast_type[0]
+                write_func = device.set_mapped_attribute(mapping_attribute, value, cast_type)
 
             self.fset = write_func_wrapper
 
@@ -298,7 +302,7 @@ class AntennaField(lofar_device):
         antennas_auto_on   = numpy.logical_and(use == AntennaUse.AUTO, quality <= AntennaQuality.SUSPICIOUS)
 
         return numpy.logical_or(antennas_forced_on, antennas_auto_on)
-    
+
     def read_nr_antennas_R(self):
         # The number of antennas should be equal to:
         # * the number of elements in the Control_to_RECV_mapping (after reshaping),
@@ -397,7 +401,7 @@ class AntennaField(lofar_device):
 
         return mapped_values
     
-    def set_mapped_attribute(self, mapped_point: str, value):
+    def set_mapped_attribute(self, mapped_point: str, value, cast_type: type):
         """Set the attribute to new value only for controlled points
 
         :warning: This method is susceptible to a lost update race condition if the
@@ -414,7 +418,7 @@ class AntennaField(lofar_device):
             # TODO(Corne): Resolve potential lost update race condition
             current_values = recv_proxy.read_attribute(mapped_point).value
             self.__mapper.merge_write(new_values, current_values)
-            recv_proxy.write_attribute(mapped_point, new_values)
+            recv_proxy.write_attribute(mapped_point, new_values.astype(dtype=cast_type))
     
     # --------
     # Overloaded functions
diff --git a/tangostationcontrol/tangostationcontrol/test/devices/test_antennafield_device.py b/tangostationcontrol/tangostationcontrol/test/devices/test_antennafield_device.py
index a703e0860..556ddd93f 100644
--- a/tangostationcontrol/tangostationcontrol/test/devices/test_antennafield_device.py
+++ b/tangostationcontrol/tangostationcontrol/test/devices/test_antennafield_device.py
@@ -11,6 +11,9 @@ import time
 import statistics
 import logging
 
+import unittest
+from unittest import mock
+
 import numpy
 
 from tango.test_context import DeviceTestContext
@@ -388,8 +391,8 @@ class TestAntennafieldDevice(device_base.DeviceTestCase):
         'OPC_Server_Name': 'example.com',
         'OPC_Server_Port': 4840,
         'OPC_Time_Out': 5.0,
-        'Antenna_Field_Reference_ITRF' : [3.0, 3.0, 3.0],
-        'Antenna_Field_Reference_ETRS' : [7.0, 7.0, 7.0],
+        'Antenna_Field_Reference_ITRF': [3.0, 3.0, 3.0],
+        'Antenna_Field_Reference_ETRS': [7.0, 7.0, 7.0],
     }
 
     def setUp(self):
@@ -441,3 +444,35 @@ class TestAntennafieldDevice(device_base.DeviceTestCase):
        with DeviceTestContext(antennafield.AntennaField, properties={**self.AT_PROPERTIES, **antenna_properties}, process=True) as proxy:
         for i in range(len(antenna_names)):
             self.assertTrue(proxy.Antenna_Names_R[i]==f"C{i}")
+
+    @unittest.skip("Test for manual use, enable at most one (process=false)")
+    @mock.patch.object(antennafield, "DeviceProxy")
+    def test_set_mapped_attribute(self, m_proxy):
+        """Verify set_mapped_attribute only modifies controlled inputs"""
+
+        antenna_properties = {
+            'RECV_devices': ['stat/RECV/1'],
+        }
+
+        data = numpy.array([[False] * 32] * 96)
+
+        m_proxy.return_value = mock.Mock(
+            read_attribute=mock.Mock(
+                return_value=mock.Mock(value=data)
+            )
+        )
+
+        with DeviceTestContext(
+            antennafield.AntennaField, process=False,
+            properties={**self.AT_PROPERTIES, **antenna_properties}
+        ) as proxy:
+            proxy.boot()
+
+            import pdb; pdb.set_trace()
+            proxy.read_attribute('Antenna_Usage_Mask_R')
+            proxy.write_attribute("HBAT_PWR_on_RW", numpy.array([[False] * 32] * 48))
+
+            numpy.testing.assert_equal(
+                m_proxy.return_value.write_attribute.call_args[0][1],
+                data
+            )
-- 
GitLab