From 7e25e5e0ea40c478528d71d7933509fe79c4a670 Mon Sep 17 00:00:00 2001
From: lukken <lukken@astron.nl>
Date: Fri, 23 Sep 2022 15:00:35 +0000
Subject: [PATCH] L2SS-940: Consolidate duplicate code from 877 and 876

---
 tangostationcontrol/VERSION                   |  2 +-
 .../common/type_checking.py                   |  8 +-
 .../devices/antennafield.py                   | 26 +----
 .../tangostationcontrol/devices/boot.py       |  1 +
 .../devices/lofar_device.py                   | 98 +++++++++++++------
 .../default/statistics/test_writer_sst.py     |  2 +-
 .../test/common/test_type_checking.py         | 77 +++++++++++++++
 .../test/devices/test_antennafield_device.py  | 74 +++++++-------
 .../test/devices/test_lofar_device.py         | 14 +++
 9 files changed, 210 insertions(+), 92 deletions(-)
 create mode 100644 tangostationcontrol/tangostationcontrol/test/common/test_type_checking.py

diff --git a/tangostationcontrol/VERSION b/tangostationcontrol/VERSION
index ceab6e11e..6da28dde7 100644
--- a/tangostationcontrol/VERSION
+++ b/tangostationcontrol/VERSION
@@ -1 +1 @@
-0.1
\ No newline at end of file
+0.1.1
\ No newline at end of file
diff --git a/tangostationcontrol/tangostationcontrol/common/type_checking.py b/tangostationcontrol/tangostationcontrol/common/type_checking.py
index e896e708c..f6686b54b 100644
--- a/tangostationcontrol/tangostationcontrol/common/type_checking.py
+++ b/tangostationcontrol/tangostationcontrol/common/type_checking.py
@@ -9,15 +9,17 @@ import numpy
 
 
 def is_sequence(obj):
-    """Identify sequences / collections"""
+    """True for sequences, positionally ordered collections
+    See https://www.pythontutorial.net/advanced-python/python-sequences/
+    """
     return isinstance(obj, Sequence) or isinstance(obj, numpy.ndarray)
 
 
 def sequence_not_str(obj):
-    """Separate sequences / collections from str, byte or bytearray"""
+    """True for sequences that are not str, bytes or bytearray"""
     return is_sequence(obj) and not isinstance(obj, (str, bytes, bytearray))
 
 
 def type_not_sequence(obj):
-    """Separate sequences / collections from types"""
+    """True for types that are not sequences"""
     return not is_sequence(obj) and isinstance(obj, type)
diff --git a/tangostationcontrol/tangostationcontrol/devices/antennafield.py b/tangostationcontrol/tangostationcontrol/devices/antennafield.py
index 93e1a8c67..6479abf92 100644
--- a/tangostationcontrol/tangostationcontrol/devices/antennafield.py
+++ b/tangostationcontrol/tangostationcontrol/devices/antennafield.py
@@ -410,15 +410,15 @@ class AntennaField(lofar_device):
 
         """
 
+        # returns sparse multidimensional array, uncontrolled values set to None
         mapped_value = self.__mapper.map_write(mapped_point, value)
 
         for idx, recv_proxy in enumerate(self.recv_proxies):
             new_values = mapped_value[idx]
 
-            # TODO(Corne): Resolve potential lost update race condition
-            current_values = recv_proxy.read_attribute(mapped_point).value
-            self.__mapper.merge_write(new_values, current_values)
-            recv_proxy.write_attribute(mapped_point, new_values.astype(dtype=cast_type))
+            self.atomic_read_modify_write_attribute(
+                new_values, recv_proxy, mapped_point, cast_type=cast_type
+            )
     
     # --------
     # Overloaded functions
@@ -556,24 +556,6 @@ class AntennaToRecvMapper(object):
 
         return mapped_values
 
-    def merge_write(self, merge_values: List[any], current_values: List[any]):
-        """Merge values as retrieved from :py:func:`~map_write` with current_values
-
-        This method will modify the contents of merge_values.
-
-        To be used by the :py:class:`~AntennaField` device to remove None fields
-        from mapped_values with recently retrieved current_values from RECV device.
-
-        :param merge_values: values as retrieved from :py:func:`~map_write`
-        :param current_values: values retrieved from RECV device on specific attribute
-        """
-
-        for idx, value in enumerate(merge_values):
-            if sequence_not_str(value):
-                self.merge_write(merge_values[idx], current_values[idx])
-            elif value is None:
-                merge_values[idx] = current_values[idx]
-
     def _mapped_r_values(self, recv_results: List[any], default_values: List[any]):
         """Mapping for read using :py:attribute:`~_control_mapping` and shallow copy"""
 
diff --git a/tangostationcontrol/tangostationcontrol/devices/boot.py b/tangostationcontrol/tangostationcontrol/devices/boot.py
index 8d51d269f..42ce74d74 100644
--- a/tangostationcontrol/tangostationcontrol/devices/boot.py
+++ b/tangostationcontrol/tangostationcontrol/devices/boot.py
@@ -18,6 +18,7 @@ from tango import DebugIt
 from tango.server import command
 from tango.server import device_property, attribute
 from tango import AttrWriteType, DeviceProxy, DevState, DevSource
+
 # Additional import
 import numpy
 
diff --git a/tangostationcontrol/tangostationcontrol/devices/lofar_device.py b/tangostationcontrol/tangostationcontrol/devices/lofar_device.py
index 068c9cee6..f950ad142 100644
--- a/tangostationcontrol/tangostationcontrol/devices/lofar_device.py
+++ b/tangostationcontrol/tangostationcontrol/devices/lofar_device.py
@@ -10,8 +10,6 @@
 """Hardware Device Server for LOFAR2.0
 
 """
-
-from collections.abc import Sequence
 import time
 import math
 from typing import List
@@ -26,8 +24,10 @@ from tango import AttrWriteType, DevState, DebugIt, Attribute, DeviceProxy, Attr
 # Additional import
 from tangostationcontrol import __version__ as version
 from tangostationcontrol.clients.attribute_wrapper import attribute_wrapper
+
 from tangostationcontrol.common.lofar_logging import log_exceptions
 from tangostationcontrol.common.states import DEFAULT_COMMAND_STATES, INITIALISED_STATES
+from tangostationcontrol.common.type_checking import sequence_not_str
 from tangostationcontrol.devices.device_decorators import only_in_states, fault_on_error
 from tangostationcontrol.toolkit.archiver import Archiver
 
@@ -38,14 +38,6 @@ import logging
 logger = logging.getLogger()
 
 
-# TODO(Corne): Remove this in L2SS-940
-def sequence_not_str(obj):
-    """Separate sequences / collections from str, byte or bytearray"""
-
-    return (isinstance(obj, Sequence) or isinstance(obj, numpy.ndarray)) and not \
-        isinstance(obj, (str, bytes, bytearray))
-
-
 class lofar_device(Device, metaclass=DeviceMeta):
     """
 
@@ -106,54 +98,104 @@ class lofar_device(Device, metaclass=DeviceMeta):
 
     # TODO(Corne): Actually implement locking in L2SS-940
     def atomic_read_modify_write_attribute(
-        self, values: numpy.ndarray, proxy: DeviceProxy, attribute: str, sparse=None
+        self, values: numpy.ndarray, proxy: DeviceProxy, attribute: str,
+        mask_or_sparse=None, cast_type=None
     ):
-        """Atomatically read-modify-write the attribute on the given proxy"""
+        """Automatically read-modify-write the attribute on the given proxy
 
+        :param values: New values to write
+        :param proxy: Device to write the values to
+        :param attribute: Attribute of the device to write
+        :param mask_or_sparse: The value or mask used to replace elements in
+                               values with current attribute values
+        :param cast_type: type to cast numpy array to for delimited merge_writes
+
+        """
+
+        # proxy.lock()
         current_values = proxy.read_attribute(attribute).value
-        self.merge_write(values, current_values, sparse)
-        proxy.write_attribute(attribute, values)
+        merged_values = self.merge_write(
+            values, current_values, mask_or_sparse, cast_type
+        )
+        proxy.write_attribute(attribute, merged_values)
+        # proxy.unlock()
 
-    # TODO(Corne): Update docstring in L2SS-940
     def merge_write(
-        self, merge_values: List[any], current_values: List[any], mask_or_sparse=None
-    ):
-        """Merge values as retrieved from :py:func:`~map_write` with current_values
+        self, merge_values: numpy.ndarray, current_values: List[any],
+        mask_or_sparse=None, cast_type=None
+    ) -> numpy.ndarray:
+        """Merge values with current_values retrieved from attribute by mask / sparse
 
-        This method will modify the contents of merge_values.
+        To be used by the :py:func:`~atomic_read_modify_write_attribute`
 
-        To be used by the :py:class:`~AntennaField` device to remove sparse fields
-        from mapped_values with recently retrieved current_values from RECV device.
+        :param merge_values: New values to write and results of merge
+        :param current_values: values retrieved from an attribute
+        :param mask_or_sparse: The value or mask used to replace elements in
+                               merge_values with current_values
+        :param cast_type: type to cast numpy array to for delimited merge_writes
+        :return:
 
-        :param merge_values: values as retrieved from :py:func:`~map_write`
-        :param current_values: values retrieved from RECV device on specific attribute
-        :param sparse: The value to identify sparse entries
         """
 
+        # Create shallow copy of merge_values, use native numpy copy as it works
+        # for N dimensionality. (native copy.copy() only copies outermost dim)
+        merge_values = merge_values.copy()
+
         if mask_or_sparse is not None and sequence_not_str(mask_or_sparse):
             self._merge_write_mask(
                 merge_values, current_values, mask_or_sparse
             )
+            return merge_values
         else:
+            if cast_type is None:
+                raise AttributeError(
+                    "dtype can not be None for sparse merge_write"
+                )
+
             self._merge_write_delimiter(
                 merge_values, current_values, mask_or_sparse
             )
+            return merge_values.astype(dtype=cast_type)
 
     def _merge_write_delimiter(
-        self, merge_values: List[any], current_values: List[any], sparse=None
+        self, merge_values: numpy.ndarray, current_values: List[any],
+        sparse=None,
     ):
+        """Merge merge_values and current_values by replacing elements by sparse
+
+        For every element in merge_values where the value is equal to sparse
+        replace it by the element in current_values.
+
+        The result can be obtained in merge_values as the list is modified
+        in-place (and passed by reference).
+        """
+
         for idx, value in enumerate(merge_values):
             if sequence_not_str(value):
-                self._merge_write_delimiter(merge_values[idx], current_values[idx], sparse)
+                self._merge_write_delimiter(
+                    merge_values[idx], current_values[idx], sparse
+                )
             elif value == sparse:
                 merge_values[idx] = current_values[idx]
 
     def _merge_write_mask(
-        self, merge_values: List[any], current_values: List[any], mask: List[any]
+        self, merge_values: List[any], current_values: List[any],
+        mask: List[any]
     ):
+        """Merge merge_values and current_values by replacing elements by mask
+
+        For every element in merge_values where the element in mask is false
+        replace it by the element in current_values.
+
+        The result can be obtained in merge_values as the list is modified
+        in-place (and passed by reference).
+        """
+
         for idx, value in enumerate(merge_values):
             if sequence_not_str(value):
-                self._merge_write_mask(merge_values[idx], current_values[idx], mask[idx])
+                self._merge_write_mask(
+                    merge_values[idx], current_values[idx], mask[idx]
+                )
             elif not mask[idx]:
                 merge_values[idx] = current_values[idx]
 
diff --git a/tangostationcontrol/tangostationcontrol/integration_test/default/statistics/test_writer_sst.py b/tangostationcontrol/tangostationcontrol/integration_test/default/statistics/test_writer_sst.py
index dfb548d3e..0f12a113a 100644
--- a/tangostationcontrol/tangostationcontrol/integration_test/default/statistics/test_writer_sst.py
+++ b/tangostationcontrol/tangostationcontrol/integration_test/default/statistics/test_writer_sst.py
@@ -82,7 +82,7 @@ class TestStatisticsWriterSST(BaseIntegrationTestCase):
                     '2021-09-20T12:17:40.000+00:00'
                 )
                 self.assertIsNotNone(stat)
-                self.assertEqual("0.1", stat.station_version_id)
+                self.assertEqual("0.1.1", stat.station_version_id)
                 self.assertEqual("0.1", stat.writer_version_id)
     
     def test_insert_tango_SST_statistics(self):
diff --git a/tangostationcontrol/tangostationcontrol/test/common/test_type_checking.py b/tangostationcontrol/tangostationcontrol/test/common/test_type_checking.py
new file mode 100644
index 000000000..051001fc8
--- /dev/null
+++ b/tangostationcontrol/tangostationcontrol/test/common/test_type_checking.py
@@ -0,0 +1,77 @@
+# -*- coding: utf-8 -*-
+#
+# This file is part of the LOFAR 2.0 Station Software
+#
+#
+#
+# Distributed under the terms of the APACHE license.
+# See LICENSE.txt for more info.
+import numpy
+
+from tangostationcontrol.common import type_checking
+
+from tangostationcontrol.test import base
+
+
+class TestTypeChecking(base.TestCase):
+
+    @staticmethod
+    def subscriptable(obj):
+        return hasattr(obj, '__getitem__')
+
+    @staticmethod
+    def iterable(obj):
+        return hasattr(obj, '__iter__')
+
+    @staticmethod
+    def positional_ordering(obj):
+        try:
+            obj[0]
+            return True
+        except Exception:
+            return False
+
+    def sequence_test(self, obj):
+        """Test object is sequence based on properties and verify is_sequence"""
+
+        result = (
+            self.subscriptable(obj) & self.iterable(obj) &
+            self.positional_ordering(obj)
+        )
+
+        self.assertEqual(
+            result, type_checking.is_sequence(obj),
+            F"Test failed for type {type(obj)}"
+        )
+
+    def test_is_sequence_for_types(self):
+        """Types to be tested by is_sequence"""
+
+        test_types = [
+            (False,),
+            {"test": "test"},
+            [False],
+            {"test"},
+            numpy.array([1, 2, 3]),
+        ]
+
+        for test in test_types:
+            self.sequence_test(test)
+
+    def test_is_sequence_not_str(self):
+        """Types test for sequence_not_str, must be false"""
+
+        t_bytearray = bytearray([0, 5, 255])
+        test_types = [
+            str(""),
+            bytes(t_bytearray),
+            t_bytearray
+        ]
+
+        for test in test_types:
+            self.assertFalse(type_checking.sequence_not_str(test))
+
+    def test_type_not_sequence(self):
+        test = [str]
+        self.asserFalse(type_checking.type_not_sequence(test))
+        self.asserTrue(type_checking.type_not_sequence(test[0]))
diff --git a/tangostationcontrol/tangostationcontrol/test/devices/test_antennafield_device.py b/tangostationcontrol/tangostationcontrol/test/devices/test_antennafield_device.py
index 85d10cfab..465db52aa 100644
--- a/tangostationcontrol/tangostationcontrol/test/devices/test_antennafield_device.py
+++ b/tangostationcontrol/tangostationcontrol/test/devices/test_antennafield_device.py
@@ -345,43 +345,43 @@ class TestAntennaToRecvMapper(base.TestCase):
         actual = mapper.map_write("HBAT_PWR_on_RW", set_values)
         numpy.testing.assert_equal(expected, actual)
 
-    def test_merge_write(self):
-        """Verify all None fields are replaced by merge_write if no control"""
-
-        mapper = AntennaToRecvMapper(
-            self.CONTROL_NOT_CONNECTED, self.POWER_NOT_CONNECTED, 1
-        )
-
-        merge_values = [[None] * 32] * 96
-        current_values = [[False] * 32] * 96
-
-        mapper.merge_write(merge_values, current_values)
-        numpy.testing.assert_equal(merge_values, current_values)
-
-        results = []
-        for _i in range(25):
-            start_time = time.monotonic_ns()
-            mapper.merge_write(merge_values, current_values)
-            stop_time = time.monotonic_ns()
-            results.append(stop_time - start_time)
-
-        logging.error(
-            f"Merge write performance: Median {statistics.median(results) / 1.e9} "
-            f"Stdev {statistics.stdev(results) / 1.e9}"
-        )
-
-    def test_merge_write_values(self):
-        """Verify all fields with values are retained by merge_write"""
-
-        mapper = AntennaToRecvMapper(
-            self.CONTROL_NOT_CONNECTED, self.POWER_NOT_CONNECTED, 1
-        )
-
-        merge_values = [[True] * 32] * 2 + [[None] * 32] * 94
-        current_values = [[True] * 32] * 2 + [[False] * 32] * 94
-
-        mapper.merge_write(merge_values, current_values)
-        numpy.testing.assert_equal(merge_values, current_values)
+    # def test_merge_write(self):
+    #     """Verify all None fields are replaced by merge_write if no control"""
+    #
+    #     mapper = AntennaToRecvMapper(
+    #         self.CONTROL_NOT_CONNECTED, self.POWER_NOT_CONNECTED, 1
+    #     )
+    #
+    #     merge_values = [[None] * 32] * 96
+    #     current_values = [[False] * 32] * 96
+    #
+    #     mapper.merge_write(merge_values, current_values)
+    #     numpy.testing.assert_equal(merge_values, current_values)
+    #
+    #     results = []
+    #     for _i in range(25):
+    #         start_time = time.monotonic_ns()
+    #         mapper.merge_write(merge_values, current_values)
+    #         stop_time = time.monotonic_ns()
+    #         results.append(stop_time - start_time)
+    #
+    #     logging.error(
+    #         f"Merge write performance: Median {statistics.median(results) / 1.e9} "
+    #         f"Stdev {statistics.stdev(results) / 1.e9}"
+    #     )
+    #
+    # def test_merge_write_values(self):
+    #     """Verify all fields with values are retained by merge_write"""
+    #
+    #     mapper = AntennaToRecvMapper(
+    #         self.CONTROL_NOT_CONNECTED, self.POWER_NOT_CONNECTED, 1
+    #     )
+    #
+    #     merge_values = [[True] * 32] * 2 + [[None] * 32] * 94
+    #     current_values = [[True] * 32] * 2 + [[False] * 32] * 94
+    #
+    #     mapper.merge_write(merge_values, current_values)
+    #     numpy.testing.assert_equal(merge_values, current_values)
 
 
 class TestAntennafieldDevice(device_base.DeviceTestCase):
diff --git a/tangostationcontrol/tangostationcontrol/test/devices/test_lofar_device.py b/tangostationcontrol/tangostationcontrol/test/devices/test_lofar_device.py
index 38d599f9f..3c66e6176 100644
--- a/tangostationcontrol/tangostationcontrol/test/devices/test_lofar_device.py
+++ b/tangostationcontrol/tangostationcontrol/test/devices/test_lofar_device.py
@@ -66,3 +66,17 @@ class TestLofarDevice(device_base.DeviceTestCase):
             with self.assertRaises(DevFailed):
                 proxy.disable_hardware()
 
+    def test_atomic_read_modify_write(self):
+        """Test atomic read modify write for attribute"""
+
+        class AttributeLofarDevice(lofar_device.lofar_device):
+
+            BOOL_ARRAY_DIM = 32
+
+            # Just for demo, do not use class variables to store attribute state
+            _bool_array = [False] * BOOL_ARRAY_DIM
+
+
+            @attribute(dtype=(bool,), max_dim_x=BOOL_ARRAY_DIM)
+            def bool_array(self):
+                return self._bool_array
-- 
GitLab