From 15de2000aad5a33242cf262765f793e16fd4b7cf Mon Sep 17 00:00:00 2001
From: Auke Klazema <klazema@astron.nl>
Date: Thu, 11 Aug 2022 12:55:50 +0200
Subject: [PATCH] L2SS-873: Use proper scoping of variables in tests

---
 .../test/clients/test_opcua_client.py         | 24 +++++++++----------
 .../test/clients/test_snmp_client.py          | 21 +++++++---------
 2 files changed, 21 insertions(+), 24 deletions(-)

diff --git a/tangostationcontrol/tangostationcontrol/test/clients/test_opcua_client.py b/tangostationcontrol/tangostationcontrol/test/clients/test_opcua_client.py
index ab20b2382..0882411fe 100644
--- a/tangostationcontrol/tangostationcontrol/test/clients/test_opcua_client.py
+++ b/tangostationcontrol/tangostationcontrol/test/clients/test_opcua_client.py
@@ -1,5 +1,6 @@
 import numpy
 import asyncua
+import asyncio
 import io
 
 import asynctest
@@ -175,7 +176,7 @@ class TestOPCua(base.AsyncTestCase):
         This tests the read functions.
         """
 
-        async def get_flat_value():
+        async def get_flat_value(j, i):
             return self._get_test_value(j, i.numpy_type).flatten()
 
         for j in DIMENSION_TESTS:
@@ -188,7 +189,7 @@ class TestOPCua(base.AsyncTestCase):
                     test = opcua_client.ProtocolAttribute(m_node, j[0], 0, opcua_client.numpy_to_OPCua_dict[i.numpy_type])
                 else:
                     test = opcua_client.ProtocolAttribute(m_node, j[1], j[0], opcua_client.numpy_to_OPCua_dict[i.numpy_type])
-                m_node.get_value = get_flat_value
+                m_node.get_value.return_value = get_flat_value(j, i)
                 val = await test.read_function()
 
                 comp = val == self._get_test_value(j, i.numpy_type)
@@ -198,14 +199,14 @@ class TestOPCua(base.AsyncTestCase):
         """
         Test whether unicode characters are replaced by '?'.
         """
-        async def get_unicode_value():
+        async def get_unicode_value(dims):
             return self._wrap_dims(b'foo \xef\xbf\xbd bar'.decode('utf-8'), dims)
         
         # test 0-2 dimensions of strings
         for dims in range(0,2):
 
             m_node = asynctest.asynctest.CoroutineMock()
-            m_node.get_value = get_unicode_value
+            m_node.get_value.return_value = get_unicode_value(dims)
 
             # create the ProtocolAttribute to test
             test = opcua_client.ProtocolAttribute(m_node, 1, 0, opcua_client.numpy_to_OPCua_dict[str])
@@ -247,8 +248,6 @@ class TestOPCua(base.AsyncTestCase):
             if numpy_type not in [str, bool]:
                 self.assertEqual(numpy_type().itemsize, getattr(asyncua.ua.ua_binary.Primitives, opcua_type.name).size, msg=f"Conversion {numpy_type} -> {opcua_type} failed: precision mismatch")
 
-
-
     async def test_write(self):
         """
         Test the writing of values by instantiating a ProtocolAttribute attribute, and calling the write function.
@@ -256,7 +255,7 @@ class TestOPCua(base.AsyncTestCase):
         This allows the code to compare what values we want to write and what values would be given to a server.
         """
 
-        async def compare_values(val):
+        async def compare_values(val, j, i):
             """ comparison function that replaces `set_data_value` inside the attributes write function """
             # test valuest
             val = val.tolist() if type(val) == numpy.ndarray else val
@@ -268,22 +267,23 @@ class TestOPCua(base.AsyncTestCase):
                 comp = val == self._get_mock_value(self._get_test_value(j, i.numpy_type), i.numpy_type)
                 self.assertTrue(comp, "value attempting to write unequal to expected value: \n\tgot: {} \n\texpected: {}".format(val, self._get_mock_value(self._get_test_value(j, i.numpy_type), i.numpy_type)))
 
+        m_node = asynctest.asynctest.CoroutineMock()
+        m_node.set_data_value.return_value = asyncio.Future()
+        m_node.set_data_value.return_value.set_result(None)
+
         # for all dimensionalities
         for j in DIMENSION_TESTS:
 
             #for all datatypes
             for i in ATTR_TEST_TYPES:
 
-                m_node = asynctest.asynctest.CoroutineMock()
-
                 # create the protocolattribute
                 if len(j) == 1:
                     test = opcua_client.ProtocolAttribute(m_node, j[0], 0, opcua_client.numpy_to_OPCua_dict[i.numpy_type])
                 else:
                     test = opcua_client.ProtocolAttribute(m_node, j[1], j[0], opcua_client.numpy_to_OPCua_dict[i.numpy_type])
 
-                # replace the `set_data_value`, usualy responsible for communicating with the server with the `compare_values` function.
-                m_node.set_data_value = compare_values
-
                 # call the write function with the test values
                 await test.write_function(self._get_test_value(j, i.numpy_type))
+
+                compare_values(m_node.call_args, j, i)
diff --git a/tangostationcontrol/tangostationcontrol/test/clients/test_snmp_client.py b/tangostationcontrol/tangostationcontrol/test/clients/test_snmp_client.py
index a1bbeb8d1..2b2866246 100644
--- a/tangostationcontrol/tangostationcontrol/test/clients/test_snmp_client.py
+++ b/tangostationcontrol/tangostationcontrol/test/clients/test_snmp_client.py
@@ -143,21 +143,14 @@ class TestSNMP(base.TestCase):
         """
         Attempts to write a value to an SNMP server, but instead intercepts it and compared whether the values is as expected.
         """
-
-        def loop_test(*value):
-            res_lst.append(value[1])
-            return None, None, None, server.get_return_val(i, server.DIM_LIST[j])
-
         server = server_imitator()
         obj_type = hlapi.ObjectType
 
-
         for j in server.DIM_LIST:
             for i in server.SNMP_TO_NUMPY_DICT:
                 # mocks the return value of the next function in snmp_client.SNMP_comm.setter
                 m_next.return_value = (None, None, None, server.get_return_val(i, server.DIM_LIST[j]))
 
-
                 def __fakeInit__(self):
                     pass
 
@@ -170,21 +163,25 @@ class TestSNMP(base.TestCase):
                     hlapi.ObjectType = obj_type
                     snmp_attr = snmp_attribute(comm=m_comms, mib="test", name="test", idx=0, dtype=server.SNMP_TO_NUMPY_DICT[i], dim_x=server.DIM_LIST[j][0], dim_y=server.DIM_LIST[j][1])
 
-                    res_lst = []
+                    hlapi.ObjectType = mock.MagicMock()
 
-                    hlapi.ObjectType = loop_test
+                    hlapi.ObjectType.return_value = (None, None, None, server.get_return_val(i, server.DIM_LIST[j]))
 
                     # call the write function. This function should now call m_ObjectType itself.
                     snmp_attr.write_function(set_val)
 
+                    # get a value to compare the value we got against
+                    checkval = server.val_check(i, server.DIM_LIST[j])
+
+                    res_lst = [call.args[1] for call in hlapi.ObjectType.call_args_list if call.args]
+
                     if len(res_lst) == 1:
                         res_lst = res_lst[0]
 
+                    self.assertEqual(checkval, res_lst)
 
-                    # get a value to compare the value we got against
-                    checkval = server.val_check(i, server.DIM_LIST[j])
+        hlapi.ObjectType = obj_type
 
-                    self.assertEqual(checkval, res_lst, f"During test {j} {i}; Expected: {checkval}, got: {res_lst}")
 
     @mock.patch('tangostationcontrol.clients.snmp_client.SNMP_comm.getter')
     def test_named_value(self, m_next):
-- 
GitLab