diff --git a/tangostationcontrol/tangostationcontrol/clients/snmp_client.py b/tangostationcontrol/tangostationcontrol/clients/snmp_client.py index b161d9b42281b6ef981eee31ebfec6473cb701a9..0d9535134050d5b7fdd94e338fadf86d3a1de631 100644 --- a/tangostationcontrol/tangostationcontrol/clients/snmp_client.py +++ b/tangostationcontrol/tangostationcontrol/clients/snmp_client.py @@ -179,55 +179,67 @@ class snmp_attribute: errorIndication, errorStatus, errorIndex, *varBinds = self.comm.setter(write_obj) - def convert(self, varBinds): + def convert(self, var_binds): """ get all the values in a list, make sure to convert specific types that dont want to play nicely """ + values = [] - vals = [] + varBinds = var_binds[0] - varBinds = varBinds[0] + for var_bind in varBinds: + value = self._convert_var_bind_value(var_bind[1]) + values.append(value) - for varBind in varBinds: + if self.is_scalar: + values = values[0] + + return values + + def _convert_var_bind_value(self, value): + def is_an_hlapi_integer(value): + return isinstance(value, (hlapi.Integer32, hlapi.Integer)) - # Some MIB's used custom types, some dont. Custom types are merely wrapped base types. - varbind_types = varBind[1].__class__.__bases__ + (type(varBind[1]),) + def is_an_hlapi_number_type(value): + return isinstance(value, (hlapi.TimeTicks, hlapi.Counter32, hlapi.Gauge32, + hlapi.Integer32, hlapi.Integer)) - snmp_type = None + def is_an_hlapi_string_type(value): + return isinstance(value, (hlapi.OctetString, hlapi.ObjectIdentity)) - # find if one of the base types is present. - for i in varbind_types: - if i in snmp_to_numpy_dict.keys(): - snmp_type = i + def needs_conversion_from_integer_to_str(value): + return is_an_hlapi_integer(value) and self.dtype == str - if snmp_type is None: - raise TypeError(f"Error: did not find a valid snmp type. Got: {varbind_types}, expected one of: '{snmp_to_numpy_dict.keys()}'") + def needs_conversion_from_ipaddress_to_str(value): + return isinstance(value, hlapi.IpAddress) and self.dtype == str - if snmp_type is hlapi.IpAddress: - # IpAddress values get printed as their raw value but in hex (7F 20 20 01 for 127.0.0.1 for example) - vals.append(varBind[1].prettyPrint()) + def needs_conversion_from_number_to_int64(value): + return is_an_hlapi_number_type(value) and self.dtype == numpy.int64 - elif (snmp_type is hlapi.Integer32 or snmp_type is hlapi.Integer) and self.dtype == str: - # Integers can have 'named values', Where a value can be translated to a specific name. A dict basically - # Example: {1: "other", 2: "invalid", 3: "dynamic", 4: "static",} + def needs_conversion_from_string_to_str(value): + return is_an_hlapi_string_type(value) and self.dtype == str - if varBind[1].namedValues == {}: - # An empty dict {} means no namedValue's are present. - vals.append(snmp_to_numpy_dict[snmp_type](varBind[1])) - else: - # append the named values string instead of the raw number. - vals.append(varBind[1].prettyPrint()) + def convert_integer_to_str(value): + if value.namedValues: + result = value.prettyPrint() else: - # convert from the funky pysnmp types to numpy types and then append - value = snmp_to_numpy_dict[snmp_type](varBind[1]) + result = numpy.int64(value) - # scale the value correctly and append. - vals.append(value * self.scaling_factor) + return result - if self.is_scalar: - vals = vals[0] + if needs_conversion_from_ipaddress_to_str(value): + result = value.prettyPrint() + elif needs_conversion_from_integer_to_str(value): + result = convert_integer_to_str(value) + elif needs_conversion_from_number_to_int64(value): + result = numpy.int64(value) * self.scaling_factor + elif needs_conversion_from_string_to_str(value): + result = str(value) + else: + raise TypeError(f"Error: did not find a valid snmp type. Got: {type(value)}, expected one of: '{snmp_to_numpy_dict.keys()}'") + + return result - return vals class mib_loader: diff --git a/tangostationcontrol/tangostationcontrol/statistics/writer.py b/tangostationcontrol/tangostationcontrol/statistics/writer.py index d824151962aafa3d1834ceb7236f27b5a68007ec..23d3869959ccad60d0fa2d9c098aef259b6b9f88 100644 --- a/tangostationcontrol/tangostationcontrol/statistics/writer.py +++ b/tangostationcontrol/tangostationcontrol/statistics/writer.py @@ -94,25 +94,7 @@ def _start_loop(receiver, writer, reconnect, filename, device): """Main loop""" try: while True: - try: - packet = receiver.get_packet() - writer.next_packet(packet, device) - except EOFError: - if reconnect and not filename: - logger.warning("Connection lost, attempting to reconnect") - while True: - try: - receiver.reconnect() - except Exception as e: - logger.warning(f"Could not reconnect: {e.__class__.__name__}: {e}") - time.sleep(10) - else: - break - logger.warning("Reconnected! Resuming operations") - else: - logger.info("End of input.") - raise SystemExit - + _receive_packets(receiver, writer, reconnect, filename, device) except KeyboardInterrupt: # user abort, don't complain logger.warning("Received keyboard interrupt. Stopping.") @@ -120,6 +102,27 @@ def _start_loop(receiver, writer, reconnect, filename, device): writer.close_writer() +def _receive_packets(receiver, writer, reconnect, filename, device): + try: + packet = receiver.get_packet() + writer.next_packet(packet, device) + except EOFError: + if reconnect and not filename: + logger.warning("Connection lost, attempting to reconnect") + while True: + try: + receiver.reconnect() + except Exception as e: + logger.warning(f"Could not reconnect: {e.__class__.__name__}: {e}") + time.sleep(10) + else: + break + logger.warning("Reconnected! Resuming operations") + else: + logger.info("End of input.") + raise SystemExit + + def main(): parser = _create_parser() diff --git a/tangostationcontrol/tangostationcontrol/test/clients/test_opcua_client.py b/tangostationcontrol/tangostationcontrol/test/clients/test_opcua_client.py index 63daef88819ad97ebc7b0cb6ddf2c7bda6c86a75..055a01430bb31617967bbacafe97ec3ede81a809 100644 --- a/tangostationcontrol/tangostationcontrol/test/clients/test_opcua_client.py +++ b/tangostationcontrol/tangostationcontrol/test/clients/test_opcua_client.py @@ -1,5 +1,6 @@ import numpy import asyncua +import asyncio import io import asynctest @@ -175,7 +176,7 @@ class TestOPCua(base.AsyncTestCase): This tests the read functions. """ - async def get_flat_value(): + async def get_flat_value(j, i): return self._get_test_value(j, i.numpy_type).flatten() for j in DIMENSION_TESTS: @@ -188,7 +189,7 @@ class TestOPCua(base.AsyncTestCase): test = opcua_client.ProtocolAttribute(m_node, j[0], 0, opcua_client.numpy_to_OPCua_dict[i.numpy_type]) else: test = opcua_client.ProtocolAttribute(m_node, j[1], j[0], opcua_client.numpy_to_OPCua_dict[i.numpy_type]) - m_node.get_value = get_flat_value + m_node.get_value.return_value = get_flat_value(j, i) val = await test.read_function() comp = val == self._get_test_value(j, i.numpy_type) @@ -198,14 +199,14 @@ class TestOPCua(base.AsyncTestCase): """ Test whether unicode characters are replaced by '?'. """ - async def get_unicode_value(): + async def get_unicode_value(dims): return self._wrap_dims(b'foo \xef\xbf\xbd bar'.decode('utf-8'), dims) # test 0-2 dimensions of strings for dims in range(0,2): m_node = asynctest.asynctest.CoroutineMock() - m_node.get_value = get_unicode_value + m_node.get_value.return_value = get_unicode_value(dims) # create the ProtocolAttribute to test test = opcua_client.ProtocolAttribute(m_node, 1, 0, opcua_client.numpy_to_OPCua_dict[str]) @@ -247,8 +248,6 @@ class TestOPCua(base.AsyncTestCase): if numpy_type not in [str, bool]: self.assertEqual(numpy_type().itemsize, getattr(asyncua.ua.ua_binary.Primitives, opcua_type.name).size, msg=f"Conversion {numpy_type} -> {opcua_type} failed: precision mismatch") - - async def test_write(self): """ Test the writing of values by instantiating a ProtocolAttribute attribute, and calling the write function. @@ -256,7 +255,7 @@ class TestOPCua(base.AsyncTestCase): This allows the code to compare what values we want to write and what values would be given to a server. """ - async def compare_values(val): + async def compare_values(val, j, i): """ comparison function that replaces `set_data_value` inside the attributes write function """ # test valuest val = val.tolist() if type(val) == numpy.ndarray else val @@ -268,22 +267,23 @@ class TestOPCua(base.AsyncTestCase): comp = val == self._get_mock_value(self._get_test_value(j, i.numpy_type), i.numpy_type) self.assertTrue(comp, "value attempting to write unequal to expected value: \n\tgot: {} \n\texpected: {}".format(val, self._get_mock_value(self._get_test_value(j, i.numpy_type), i.numpy_type))) + m_node = asynctest.asynctest.CoroutineMock() + m_node.set_data_value.return_value = asyncio.Future() + m_node.set_data_value.return_value.set_result(None) + # for all dimensionalities for j in DIMENSION_TESTS: #for all datatypes for i in ATTR_TEST_TYPES: - m_node = asynctest.asynctest.CoroutineMock() - # create the protocolattribute if len(j) == 1: test = opcua_client.ProtocolAttribute(m_node, j[0], 0, opcua_client.numpy_to_OPCua_dict[i.numpy_type]) else: test = opcua_client.ProtocolAttribute(m_node, j[1], j[0], opcua_client.numpy_to_OPCua_dict[i.numpy_type]) - # replace the `set_data_value`, usualy responsible for communicating with the server with the `compare_values` function. - m_node.set_data_value = compare_values - # call the write function with the test values await test.write_function(self._get_test_value(j, i.numpy_type)) + + compare_values(m_node.call_args, j, i) diff --git a/tangostationcontrol/tangostationcontrol/test/clients/test_snmp_client.py b/tangostationcontrol/tangostationcontrol/test/clients/test_snmp_client.py index a1bbeb8d1699f23b5f960d252451bf81770797b9..ee04a1b3dfffab25d012ba81c0a80d3e61c1de50 100644 --- a/tangostationcontrol/tangostationcontrol/test/clients/test_snmp_client.py +++ b/tangostationcontrol/tangostationcontrol/test/clients/test_snmp_client.py @@ -31,39 +31,46 @@ class server_imitator: """ if dims == self.DIM_LIST["scalar"]: - if snmp_type is hlapi.ObjectIdentity: - read_val = [(snmp_type("1.3.6.1.2.1.1.1.0"),)] - elif snmp_type is hlapi.IpAddress: - read_val = [(None, snmp_type("1.1.1.1"),)] - elif snmp_type is hlapi.OctetString: - read_val = [(None, snmp_type("1"),)] - else: - read_val = [(None, snmp_type(1),)] - + read_val = self._get_return_val_for_scalar(snmp_type) elif dims == self.DIM_LIST["spectrum"]: - if snmp_type is hlapi.ObjectIdentity: - read_val = [] - for _i in range(dims[0]): - read_val.append((None, snmp_type(f"1.3.6.1.2.1.1.1.0.1"))) - elif snmp_type is hlapi.IpAddress: - read_val = [] - for _i in range(dims[0]): - read_val.append((None, snmp_type(f"1.1.1.1"))) - elif snmp_type is hlapi.OctetString: - read_val = [] - for _i in range(dims[0]): - read_val.append((None, snmp_type("1"))) - else: - read_val = [] - for _i in range(dims[0]): - read_val.append((None, snmp_type(1))) - + read_val = self._get_return_val_for_spectrum(snmp_type, dims) else: raise Exception("Image not supported :(") return read_val + def _get_return_val_for_scalar(self, snmp_type : type): + if snmp_type is hlapi.ObjectIdentity: + read_val = [(snmp_type("1.3.6.1.2.1.1.1.0"),)] + elif snmp_type is hlapi.IpAddress: + read_val = [(None, snmp_type("1.1.1.1"),)] + elif snmp_type is hlapi.OctetString: + read_val = [(None, snmp_type("1"),)] + else: + read_val = [(None, snmp_type(1),)] + + return read_val + def _get_return_val_for_spectrum(self, snmp_type : type, dims : tuple): + if snmp_type is hlapi.ObjectIdentity: + read_val = [] + for _i in range(dims[0]): + read_val.append((None, snmp_type(f"1.3.6.1.2.1.1.1.0.1"))) + elif snmp_type is hlapi.IpAddress: + read_val = [] + for _i in range(dims[0]): + read_val.append((None, snmp_type(f"1.1.1.1"))) + elif snmp_type is hlapi.OctetString: + read_val = [] + for _i in range(dims[0]): + read_val.append((None, snmp_type("1"))) + else: + read_val = [] + for _i in range(dims[0]): + read_val.append((None, snmp_type(1))) + + return read_val + def val_check(self, snmp_type : type, dims : tuple): """ provides the values we expect and would provide to the attribute after converting the @@ -143,21 +150,14 @@ class TestSNMP(base.TestCase): """ Attempts to write a value to an SNMP server, but instead intercepts it and compared whether the values is as expected. """ - - def loop_test(*value): - res_lst.append(value[1]) - return None, None, None, server.get_return_val(i, server.DIM_LIST[j]) - server = server_imitator() obj_type = hlapi.ObjectType - for j in server.DIM_LIST: for i in server.SNMP_TO_NUMPY_DICT: # mocks the return value of the next function in snmp_client.SNMP_comm.setter m_next.return_value = (None, None, None, server.get_return_val(i, server.DIM_LIST[j])) - def __fakeInit__(self): pass @@ -170,21 +170,24 @@ class TestSNMP(base.TestCase): hlapi.ObjectType = obj_type snmp_attr = snmp_attribute(comm=m_comms, mib="test", name="test", idx=0, dtype=server.SNMP_TO_NUMPY_DICT[i], dim_x=server.DIM_LIST[j][0], dim_y=server.DIM_LIST[j][1]) - res_lst = [] + hlapi.ObjectType = mock.MagicMock() - hlapi.ObjectType = loop_test + hlapi.ObjectType.return_value = (None, None, None, server.get_return_val(i, server.DIM_LIST[j])) # call the write function. This function should now call m_ObjectType itself. snmp_attr.write_function(set_val) + # get a value to compare the value we got against + checkval = server.val_check(i, server.DIM_LIST[j]) + + res_lst = [args[1] for args, _ in hlapi.ObjectType.call_args_list if args] if len(res_lst) == 1: res_lst = res_lst[0] + self.assertEqual(checkval, res_lst) - # get a value to compare the value we got against - checkval = server.val_check(i, server.DIM_LIST[j]) + hlapi.ObjectType = obj_type - self.assertEqual(checkval, res_lst, f"During test {j} {i}; Expected: {checkval}, got: {res_lst}") @mock.patch('tangostationcontrol.clients.snmp_client.SNMP_comm.getter') def test_named_value(self, m_next):