Skip to content
Snippets Groups Projects
Commit 776d8a66 authored by Taya Snijder's avatar Taya Snijder
Browse files

added early version of SNMP client/device

parent 99e3d081
No related branches found
No related tags found
1 merge request!26Resolve #2021 "04 16 branched from master snmp device"
# -*- coding: utf-8 -*-
#
# This file is part of the PCC project
#
#
#
# Distributed under the terms of the APACHE license.
# See LICENSE.txt for more info.
""" PCC Device Server for LOFAR2.0
"""
# PyTango imports
from tango.server import run
from tango.server import device_property
# Additional import
from clients.SNMP_client import SNMP_client
from src.attribute_wrapper import *
from src.hardware_device import *
__all__ = ["example_device", "main"]
class example_device(hardware_device):
"""
**Properties:**
- Device Property
SNMP_community
- Type:'DevString'
SNMP_host
- Type:'DevULong'
SNMP_timeout
- Type:'DevDouble'
"""
# -----------------
# Device Properties
# -----------------
# SNMP_community = device_property(
# dtype='DevString',
# mandatory=True
# )
#
# SNMP_host = device_property(
# dtype='DevString',
# mandatory=True
# )
#
# SNMP_timeout = device_property(
# dtype='DevDouble',
# mandatory=True
# )
SNMP_community = b"public"
SNMP_host = "127.0.0.1"
SNMP_timeout = 5.0
# ----------
# Attributes
# ----------
# simple scalar - description
attr1 = attribute_wrapper(comms_annotation={"oids": "1.3.6.1.2.1.1.1.0"}, datatype=numpy.str_, access=AttrWriteType.READ_WRITE)
# simple scalar uptime
attr2 = attribute_wrapper(comms_annotation={"oids": "1.3.6.1.2.1.1.3.0"}, datatype=numpy.int64, access=AttrWriteType.READ_WRITE)
# simple scalar with name
attr3 = attribute_wrapper(comms_annotation={"oids": "1.3.6.1.2.1.1.5.0"}, datatype=numpy.str_, access=AttrWriteType.READ_WRITE)
#spectrum with all elements
attr4 = attribute_wrapper(comms_annotation={"oids": ["1.3.6.1.2.1.2.2.1.1.1", "1.3.6.1.2.1.2.2.1.1.2", "1.3.6.1.2.1.2.2.1.1.3"]}, dims=(3,), datatype=numpy.int64)
#inferred spectrum
attr5 = attribute_wrapper(comms_annotation={"oids": ".1.3.6.1.2.1.2.2.1.1"}, dims=(3,), datatype=numpy.int64)
def always_executed_hook(self):
"""Method always executed before any TANGO command is executed."""
pass
def delete_device(self):
"""Hook to delete resources allocated in init_device.
This method allows for any memory or other resources allocated in the
init_device method to be released. This method is called by the device
destructor and by the device Init command (a Tango built-in).
"""
self.debug_stream("Shutting down...")
self.Off()
self.debug_stream("Shut down. Good bye.")
# --------
# overloaded functions
# --------
def initialise(self):
""" user code here. is called when the state is set to STANDBY """
# set up the SNMP ua client
self.snmp_manager = SNMP_client(self.SNMP_community, self.SNMP_host, self.SNMP_timeout, self.Fault, self)
# map the attributes to the OPC ua comm client
for i in self.attr_list():
i.set_comm_client(self.snmp_manager)
self.snmp_manager.start()
# --------
# Commands
# --------
# ----------
# Run server
# ----------
def main(args=None, **kwargs):
"""Main function of the PCC module."""
return run((example_device,), args=args, **kwargs)
if __name__ == '__main__':
main()
from src.comms_client import *
import snmp
__all__ = ["SNMP_client"]
snmp_to_numpy_dict = {
snmp.types.INTEGER: numpy.int64,
snmp.types.TimeTicks: numpy.int64,
snmp.types.OCTET_STRING: str,
snmp.types.OID: str
}
# numpy_to_snmp_dict = {
# numpy.int64,
# numpy.int64,
# str,
# }
class SNMP_client(CommClient):
"""
messages to keep a check on the connection. On connection failure, reconnects once.
"""
def start(self):
super().start()
def __init__(self, community, host, timeout, fault_func, streams, try_interval=2):
"""
Create the SNMP and connect() to it
"""
super().__init__(fault_func, streams, try_interval)
self.community = community
self.host = host
self.manager = snmp.Manager(community, host, timeout)
# Explicitly connect
if not self.connect():
# hardware or infra is down -- needs fixing first
fault_func()
return
def connect(self):
"""
Try to connect to the client
"""
self.streams.debug_stream("Connecting to server %s %s", self.community, self.host)
self.connected = True
return True
def disconnect(self):
"""
disconnect from the client
"""
self.connected = False # always force a reconnect, regardless of a successful disconnect
def ping(self):
"""
ping the client to make sure the connection with the client is still functional.
"""
pass
def _setup_annotation(self, annotation):
"""
This class's Implementation of the get_mapping function. returns the read and write functions
"""
if isinstance(annotation, dict):
# check if required path inarg is present
if annotation.get('oids') is None:
AssertionError("SNMP get attributes require an oid")
oids = annotation.get("oids") # required
if annotation.get('host') is None:
AssertionError("SNMP get attributes require an host")
host = annotation.get("host") # required
else:
TypeError("SNMP attributes require a dict with oid and adress")
return
return host, oids
def setup_value_conversion(self, attribute):
"""
gives the client access to the attribute_wrapper object in order to access all data it could potentially need.
the OPC ua read/write functions require the dimensionality and the type to be known
"""
dim_x = attribute.dim_x
dim_y = attribute.dim_y
return dim_x, dim_y
def get_oids(self, x, y, in_oid):
if x == 0:
x = 1
if y == 0:
y = 1
nof_oids = x * y
if nof_oids == 1:
# is scalar
if type(in_oid) is not list:
# for ease of handling put single oid in a 1 element list
in_oid = [in_oid]
return in_oid
elif type(in_oid) is list and len(in_oid) == nof_oids:
# already is an array and of the right length
return in_oid
elif type(in_oid) is list and len(in_oid) != nof_oids:
# already is an array but the wrong length. Unable to handle this
raise ValueError("SNMP oids need to either be a single value or an array the size of the attribute dimensions. got: {} expected: {}x{}={}".format(len(in_oid),x,y,x*y))
else:
out_oids = []
for i in range(nof_oids):
out_oids.append(in_oid + ".{}".format(i+1))
print(out_oids)
return out_oids
def setup_attribute(self, annotation, attribute):
"""
MANDATORY function: is used by the attribute wrapper to get read/write functions. must return the read and write functions
"""
# process the annotation
host, oids = self._setup_annotation(annotation)
# get all the necessary data to set up the read/write functions from the attribute_wrapper
dim_x, dim_y = self.setup_value_conversion(attribute)
oids = self.get_oids(dim_x, dim_y, oids)
def _read_function():
vars = self.manager.get(host, *oids)
value = []
for i in vars:
value.append(str(i.value))
# value.append(snmp_to_numpy_dict[type(i.value)](i.value))
return value
def _write_function(value):
self.manager.set(host, oids, value)
# return the read/write functions
return _read_function, _write_function
class snmp_get:
"""
This class provides a small wrapper for the OPC ua read/write functions in order to better organise the code
"""
def __init__(self, host, oid, dim_x, dim_y, snmp_type):
self.host = host
self.oid = oid
self.dim_y = dim_y
self.dim_x = dim_x
self.snmp_type = snmp_type
def read_function(self):
"""
Read_R function
"""
value = numpy.array(self.node.get_value())
if self.dim_y != 0:
value = numpy.array(numpy.split(value, indices_or_sections=self.dim_y))
else:
value = numpy.array(value)
return value
def write_function(self, value):
"""
write_RW function
"""
# set_data_value(opcua.ua.uatypes.Variant(value = value.tolist(), varianttype=opcua.ua.VariantType.Int32))
if self.dim_y != 0:
# spectrum
v = numpy.concatenate(value)
self.node.set_data_value(opcua.ua.uatypes.Variant(value=v.tolist(), varianttype=self.ua_type))
elif self.dim_x != 1:
#scalar
self.node.set_data_value(opcua.ua.uatypes.Variant(value=value.tolist(), varianttype=self.ua_type))
else:
self.node.set_data_value(opcua.ua.uatypes.Variant(value=value, varianttype=self.ua_type))
# -*- coding: utf-8 -*-
#
# This file is part of the PCC project
#
#
#
# Distributed under the terms of the APACHE license.
# See LICENSE.txt for more info.
""" PCC Device Server for LOFAR2.0
"""
# PyTango imports
from tango.server import run
from tango.server import device_property
# Additional import
from clients.SNMP_client import SNMP_client
from src.attribute_wrapper import *
from src.hardware_device import *
__all__ = ["example_device", "main"]
class example_device(hardware_device):
"""
**Properties:**
- Device Property
SNMP_community
- Type:'DevString'
SNMP_host
- Type:'DevULong'
SNMP_timeout
- Type:'DevDouble'
"""
# -----------------
# Device Properties
# -----------------
# SNMP_community = device_property(
# dtype='DevString',
# mandatory=True
# )
#
# SNMP_host = device_property(
# dtype='DevString',
# mandatory=True
# )
#
# SNMP_timeout = device_property(
# dtype='DevDouble',
# mandatory=True
# )
SNMP_community = b"public"
SNMP_host = "127.0.0.1"
SNMP_timeout = 5.0
# ----------
# Attributes
# ----------
# simple scalar - description
attr1 = attribute_wrapper(comms_annotation={"oids": "1.3.6.1.2.1.1.1.0"}, datatype=numpy.str_, access=AttrWriteType.READ_WRITE)
# simple scalar uptime
attr2 = attribute_wrapper(comms_annotation={"oids": "1.3.6.1.2.1.1.3.0"}, datatype=numpy.int64, access=AttrWriteType.READ_WRITE)
# simple scalar with name
attr3 = attribute_wrapper(comms_annotation={"oids": "1.3.6.1.2.1.1.5.0"}, datatype=numpy.str_, access=AttrWriteType.READ_WRITE)
#spectrum with all elements
attr4 = attribute_wrapper(comms_annotation={"oids": ["1.3.6.1.2.1.2.2.1.1.1", "1.3.6.1.2.1.2.2.1.1.2", "1.3.6.1.2.1.2.2.1.1.3"]}, dims=(3,), datatype=numpy.int64)
#inferred spectrum
attr5 = attribute_wrapper(comms_annotation={"oids": ".1.3.6.1.2.1.2.2.1.1"}, dims=(3,), datatype=numpy.int64)
def always_executed_hook(self):
"""Method always executed before any TANGO command is executed."""
pass
def delete_device(self):
"""Hook to delete resources allocated in init_device.
This method allows for any memory or other resources allocated in the
init_device method to be released. This method is called by the device
destructor and by the device Init command (a Tango built-in).
"""
self.debug_stream("Shutting down...")
self.Off()
self.debug_stream("Shut down. Good bye.")
# --------
# overloaded functions
# --------
def initialise(self):
""" user code here. is called when the state is set to STANDBY """
# set up the SNMP ua client
self.snmp_manager = SNMP_client(self.SNMP_community, self.SNMP_host, self.SNMP_timeout, self.Fault, self)
# map the attributes to the OPC ua comm client
for i in self.attr_list():
i.set_comm_client(self.snmp_manager)
self.snmp_manager.start()
# --------
# Commands
# --------
# ----------
# Run server
# ----------
def main(args=None, **kwargs):
"""Main function of the PCC module."""
return run((example_device,), args=args, **kwargs)
if __name__ == '__main__':
main()
from .v1 import SNMPv1
versions = {
1: SNMPv1,
}
def Manager(*args, version=1, **kwargs):
try:
cls = versions[version]
except KeyError as e:
msg = "'version' must be one of {}".format(list(versions.keys()))
raise ValueError(msg) from e
return cls(*args, **kwargs)
# used to indicate that a string cannot be decoded because it violates encoding rules
class EncodingError(Exception):
pass
# used to indicate that a response violates the protocol in some way
class ProtocolError(Exception):
pass
class Timeout(Exception):
pass
__all__ = ['RWLock']
from threading import Lock
# returns a pair of objects, (r, w), which constitute a
# writer-preferred reader/writer lock
def RWLock():
r = Lock()
w = Lock()
return RLock(r, w), WLock(r, w)
class ContextLock:
def __enter__(self):
self.acquire()
def __exit__(self, *args, **kwargs):
self.release()
class RLock(ContextLock):
def __init__(self, r, w):
self.r = r
self.w = w
self.mutex = Lock()
self.queue = Lock()
self.count = 0
def acquire(self):
with self.queue:
with self.r:
with self.mutex:
if not self.count:
self.w.acquire()
self.count += 1
def release(self):
with self.mutex:
self.count -= 1
if not self.count:
self.w.release()
class WLock(ContextLock):
def __init__(self, r, w):
self.r = r
self.w = w
self.mutex = Lock()
self.count = 0
def acquire(self):
with self.mutex:
if not self.count:
self.r.acquire()
self.count += 1
self.w.acquire()
def release(self):
self.w.release()
with self.mutex:
self.count -= 1
if not self.count:
self.r.release()
__all__ = [
'ASN1', 'INTEGER', 'OCTET_STRING', 'NULL', 'OID', 'SEQUENCE', 'UNSIGNED',
'Counter32', 'Gauge32', 'TimeTicks', 'Integer32', 'Counter64', 'IpAddress',
'VarBind', 'VarBindList', 'PDU', 'GetRequestPDU', 'GetNextRequestPDU',
'GetResponsePDU', 'SetRequestPDU', 'Message',
]
from copy import copy
import socket
from .exceptions import EncodingError, ProtocolError
def unpack(obj):
if len(obj) < 2:
raise EncodingError("object encoding is too short")
dtype = obj[0]
l = obj[1]
index = 2
if l & 0x80:
index += l & 0x7f
if len(obj) < index:
raise EncodingError("Long form length field is incomplete")
l = 0
for num in obj[2:index]:
l <<= 8
l += num
if len(obj) < index + l:
raise EncodingError("Invalid length field: object encoding too short")
return dtype, obj[index:index+l], obj[index+l:]
def length(l):
if l < 0x80:
return bytes([l])
bytearr = bytearray()
while l:
bytearr.append(l & 0xff)
l >>= 8
# this works as long as (l < 2^1008), which is super big
bytearr.append(len(bytearr) | 0x80)
return bytes(reversed(bytearr))
class ASN1:
def __init__(self, value=None, encoding=None):
self._encoding = encoding
self._value = value
def __repr__(self):
return "{}({})".format(self.__class__.__name__, self)
@classmethod
def copy(cls, obj):
return cls(encoding=obj.encoding)
@staticmethod
def deserialize(obj, cls=None, leftovers=False):
dtype, encoding, tail = unpack(obj)
if tail and not leftovers:
raise EncodingError("Unexpected trailing bytes")
if cls is None:
try:
cls = types[dtype]
except KeyError as e:
message = "Unknown type: '0x{:02x}'".format(dtype)
raise ProtocolError(message) from e
elif dtype != cls.TYPE:
message = "Expected type '0x{:02x}'; got '0x{:02x}'"
message = message.format(cls.TYPE, dtype)
raise ProtocolError(message)
obj = cls(encoding=encoding)
return (obj, tail) if leftovers else obj
def serialize(self):
return bytes([self.TYPE]) + length(len(self.encoding)) + self.encoding
# The following methods must be overwritten for sequence types
def __bool__(self):
return bool(self.value)
def __eq__(self, other):
return self.value == other
def __ge__(self, other):
return self.value >= other
def __gt__(self, other):
return self.value > other
def __le__(self, other):
return self.value <= other
def __lt__(self, other):
return self.value < other
def __ne__(self, other):
return self.value != other
def __str__(self):
return repr(self.value)
def poke(self):
self.value
### Primitive types ###
class INTEGER(ASN1):
SIGNED = True
@property
def encoding(self):
if self._encoding is None:
encoding = bytearray()
x = self._value
# do - while
while True:
encoding.append(x & 0xff)
x >>= 8
if x in (0, -1):
break
self._encoding = bytes(reversed(encoding))
return self._encoding
@property
def value(self):
if self._value is None:
negative = self.SIGNED and bool(self._encoding[0] & 0x80)
x = 0
for byte in self._encoding:
x <<= 8
x |= byte
if negative:
bits = 8 * len(self._encoding)
self._value = -(~x + (1 << bits) + 1)
else:
self._value = x
return self._value
class OCTET_STRING(ASN1):
@property
def encoding(self):
if self._encoding is None:
self._encoding = self._value
return self._encoding
@property
def value(self):
if self._value is None:
self._value = self._encoding
return self._value
class NULL(ASN1):
def __init__(self, value=None, encoding=None):
if encoding:
raise EncodingError("Non-null encoding for NULL type")
elif value is not None:
raise ValueError("NULL cannot have non-null value")
def __str__(self):
return ""
@property
def encoding(self):
return b''
@property
def value(self):
return None
class OID(ASN1):
@property
def encoding(self):
if self._encoding is None:
if self._value[0] == '.':
self._value = self.value[1:]
segments = [int(segment) for segment in self._value.split('.')]
if len(segments) > 1:
segments[1] += segments[0] * 40
segments = segments[1:]
encoding = bytearray()
for num in segments:
bytearr = bytearray()
while num > 0x7f:
bytearr.append(num & 0x7f)
num >>= 7
bytearr.append(num)
for i in range(1, len(bytearr)):
bytearr[i] |= 0x80
bytearr.reverse()
encoding += bytearr
self._encoding = bytes(encoding)
return self._encoding
@property
def value(self):
if self._value is None:
encoding = self._encoding
first = encoding[0]
oid = [str(num) for num in divmod(first, 40)]
val = 0
for byte in encoding[1:]:
val |= byte & 0x7f
if byte & 0x80:
val <<= 7
else:
oid.append(str(val))
val = 0
if val:
raise EncodingError("OID ended in a byte with bit 7 set")
self._value = '.'.join(oid)
return self._value
class SEQUENCE(ASN1):
EXPECTED = None
def __init__(self, *values, encoding=None):
self.expected = copy(self.EXPECTED)
self._encoding = encoding
self._values = values or None
def __bool__(self):
return bool(self.values)
def __eq__(self, other):
return self.values == other
def __ge__(self, other):
return self.values >= other
def __gt__(self, other):
return self.values > other
def __le__(self, other):
return self.values <= other
def __lt__(self, other):
return self.values < other
def __ne__(self, other):
return self.values != other
def __str__(self):
return repr(self)
def __repr__(self, depth=0):
string = "{}{}:\n".format('\t'*depth, self.__class__.__name__)
depth += 1
for entry in self.values:
if isinstance(entry, SEQUENCE):
string += entry.__repr__(depth=depth)
else:
string += "{}{}: {}\n".format(
'\t'*depth,
entry.__class__.__name__,
entry
)
return string
def poke(self):
for val in self.values:
val.poke()
@property
def encoding(self):
if self._encoding is None:
encodings = [None] * len(self.values)
for i in range(len(self.values)):
encodings[i] = self.values[i].serialize()
self._encoding = b''.join(encodings)
return self._encoding
@property
def values(self):
if self._values is None:
definite = isinstance(self.expected, list)
sequence = []
encoding = self._encoding
while encoding:
if definite:
try:
cls = self.expected[len(sequence)]
except IndexError as e:
message = "{} has too many elements"
message = message.format(self.__class__.__name__)
raise ProtocolError(message) from e
else:
cls = self.expected
obj, encoding = ASN1.deserialize(encoding, cls=cls, leftovers=True)
sequence.append(obj)
if definite and len(sequence) < len(self.expected):
message = "{} has too few elements"
message = message.format(self.__class__.__name__)
raise ProtocolError(message)
self._values = tuple(sequence)
return self._values
### Composed types ###
class UNSIGNED(INTEGER):
SIGNED = False
class Counter32(UNSIGNED):
pass
class Gauge32(UNSIGNED):
pass
class TimeTicks(UNSIGNED):
pass
class Integer32(INTEGER):
pass
class Counter64(UNSIGNED):
pass
class IpAddress(OCTET_STRING):
@property
def encoding(self):
if self._encoding is None:
self._encoding = socket.inet_aton(self._value)
return self._encoding
@property
def value(self):
if self._value is None:
if len(self._encoding) == 4:
self._value = socket.inet_ntoa(self._encoding)
else:
raise ProtocolError("IP Address must be 4 bytes long")
return self._value
class VarBind(SEQUENCE):
EXPECTED = [
OID,
None,
]
def __init__(self, *args, **kwargs):
super(VarBind, self).__init__(*args, **kwargs)
self.error = None
@property
def name(self):
return self.values[0]
@property
def value(self):
return self.values[1]
class VarBindList(SEQUENCE):
EXPECTED = VarBind
def __getitem__(self, index):
return self.values[index]
def __iter__(self):
return iter(self.values)
def __len__(self):
return len(self.values)
class PDU(SEQUENCE):
EXPECTED = [
INTEGER,
INTEGER,
INTEGER,
VarBindList,
]
def __init__(self, request_id=0, error_status=0, error_index=0, vars=None, encoding=None):
values = (
INTEGER.copy(UNSIGNED(request_id)),
INTEGER(error_status),
INTEGER(error_index),
vars,
) if encoding is None else ()
super(PDU, self).__init__(*values, encoding=encoding)
@property
def request_id(self):
return self.values[0]
@property
def error_status(self):
return self.values[1]
@property
def error_index(self):
return self.values[2]
@property
def vars(self):
return self.values[3]
class GetRequestPDU(PDU):
pass
class GetNextRequestPDU(PDU):
pass
class GetResponsePDU(PDU):
pass
class SetRequestPDU(PDU):
pass
class Message(SEQUENCE):
EXPECTED = [
INTEGER,
OCTET_STRING,
GetResponsePDU,
]
def __init__(self, version=0, community=b'public', data=None, encoding=None):
values = (
INTEGER(version),
OCTET_STRING(community),
data,
) if encoding is None else ()
super(Message, self).__init__(*values, encoding=encoding)
@property
def version(self):
return self.values[0]
@property
def community(self):
return self.values[1]
@property
def data(self):
return self.values[2]
types = {
0x02: INTEGER,
0x04: OCTET_STRING,
0x05: NULL,
0x06: OID,
0x30: SEQUENCE,
0x40: IpAddress,
0x41: Counter32,
0x42: Gauge32,
0x43: TimeTicks,
0x44: Integer32,
0x46: Counter64,
0xa0: GetRequestPDU,
0xa1: GetNextRequestPDU,
0xa2: GetResponsePDU,
0xa3: SetRequestPDU,
}
for dtype, cls in types.items():
cls.TYPE = dtype
from binascii import hexlify
from collections import OrderedDict
import logging
import os
import random
import select
import socket
import threading
import time
from ..exceptions import EncodingError, ProtocolError, Timeout
from ..mutex import RWLock
from ..types import *
from .exceptions import TooBig, NoSuchName, BadValue, ReadOnly, GenErr
log = logging.getLogger(__name__)
DUMMY_EVENT = threading.Event()
DUMMY_EVENT.set()
ERRORS = {
1: TooBig,
2: NoSuchName,
3: BadValue,
4: ReadOnly,
5: GenErr,
}
PORT = 161
RECV_SIZE = 65507
MAX_REQUEST_ID = 0xffffffff
VERSION = 0
WINDOWS = (os.name == "nt")
if WINDOWS:
import ctypes
from ctypes.wintypes import HANDLE, DWORD, BOOL
import msvcrt
ERROR_NO_DATA = 232
LPDWORD = ctypes.POINTER(DWORD)
SELECT_TIMEOUT = 1
def errcheck (result, func, args):
if not result:
raise WinError (ctypes.get_last_error())
def setblocking (fd, blocking):
handle = msvcrt.get_osfhandle(fd)
mode = ctypes.byref(DWORD(0 if blocking else 1))
kernel32 = ctypes.WinDLL("kernel32", use_last_error=True)
kernel32.SetNamedPipeHandleState.argtypes = [HANDLE, LPDWORD, LPDWORD, LPDWORD]
kernel32.SetNamedPipeHandleState.restype = BOOL
kernel32.SetNamedPipeHandleState.errcheck = errcheck
return kernel32.SetNamedPipeHandleState(handle, mode, None, None)
class PendTable:
def __init__(self):
self.lock = threading.Lock()
# The two events signal the arrival of the value for the oid itself,
# or the variable returned by a GetNext request (respectively)
# {
# <oid>: [<Event>, <Event>],
# ...
# }
self.oids = {}
# Used by set() to make sure multiple set requests to the same OID
# do not overlap in time
# {
# <oid>: <Event>,
# ...
# }
self.sets = {}
if WINDOWS:
def _close (fd):
os.close(fd)
def _done (_, fd):
try:
os.read(fd, 1)
except WindowsError:
if ctypes.GetLastError() != ERROR_NO_DATA:
raise
else:
return False
else:
return True
def _select (sock, _):
return select.select([sock], [], [], SELECT_TIMEOUT)[0]
def _setup (fd):
setblocking(fd, False)
return fd
else:
def _close (pipe):
pipe.close()
def _done (ready, pipe):
return pipe in ready
def _select (sock, pipe):
return select.select([sock, pipe], [], [])[0]
def _setup (fd):
return os.fdopen(fd)
# background thread to process responses
def _listen_thread(sock, pipe, requests, rlock, data, dlock, port=PORT):
pipe = _setup(pipe)
while True:
# wait for data on sock or pipe
ready = _select(sock, pipe)
if _done(ready, pipe):
# exit from this thread
# don't bother processing any more responses; the calling
# application has all the data they need
break
for s in ready:
# listen for UDP packets from the correct port
packet, (host, p) = s.recvfrom(RECV_SIZE)
if p != port:
continue
try:
# convert bytes to Message object
message = ASN1.deserialize(packet, cls=Message)
# ignore garbage packets
if message.version != VERSION:
continue
# force a full parse; invalid packet will raise an error
message.poke()
except (EncodingError, ProtocolError) as e:
# this should take care of filtering out invalid traffic
log.debug("{}: {}: {}".format(
e.__class__.__name__, e, hexlify(packet).decode()
))
continue
request_id = message.data.request_id.value
try:
with rlock:
request, event = requests[request_id][:2]
except KeyError:
# ignore responses for which there was no request
msg = "Received unexpected response from {}: {}"
log.debug(msg.format(host, hexlify(packet).decode()))
continue
# while we don't explicitly check every possible protocol violation
# this one would cause IndexErrors below, which I'd rather avoid
if len(message.data.vars) != len(request.data.vars):
msg = "VarBindList length mismatch:\n(Request) {}(Response) {}"
log.error(msg.format(request, message))
continue
requests.pop(request_id)
next = isinstance(request.data, GetNextRequestPDU)
error = None
error_status = message.data.error_status.value
if error_status != 0:
log.debug(message.data)
error_index = message.data.error_index.value
try:
cls = ERRORS[error_status]
except KeyError:
msg = "Invalid error status: {}"
error = ProtocolError(msg.format(error_status))
else:
try:
varbind = message.data.vars[error_index-1]
except IndexError:
msg = "Invalid error index: {}"
error = ProtocolError(msg.format(error_index))
else:
error = cls(varbind.name.value)
with dlock:
try:
host_data = data[host]
except KeyError:
host_data = {}
data[host] = host_data
for i, varbind in enumerate(message.data.vars):
# won't make a difference if error is None
varbind.error = error
requested = request.data.vars[i].name.value
oid = varbind.name.value
if next:
try:
host_data[requested][1] = oid
except KeyError:
host_data[requested] = [None, oid]
elif requested != oid:
msg = "OID ({}) does not match requested ({})"
log.warning(msg.format(oid, requested))
# this will cause a ProtocolError to be raised in get()
# However, if this data is never accessed, the error
# will go unnoticed.
# Assuming, however, that the agent is correctly
# implemented and the channel is secure, this should
# never happen
try:
host_data[requested][0] = None
except KeyError:
host_data[requested] = [None, None]
# update data table
try:
host_data[oid][0] = varbind
except KeyError:
host_data[oid] = [varbind, None]
msg = "Done processing response from {} (ID={})"
log.debug(msg.format(host, request_id))
# alert the main thread that the data is ready
event.set()
_close(pipe)
log.debug("Listener thread exiting")
def _monitor_thread(sock, done, requests, rlock, data, dlock, port=PORT, resend=1):
delay = 0
while not done.wait(timeout=delay):
with rlock:
try:
ID = next(iter(requests))
except StopIteration:
delay = resend
else:
timestamp = requests[ID][3]
diff = time.time() - timestamp
if diff >= resend:
delay = 0
message, event, host, _, count = requests.pop(ID)
if count:
timestamp += resend
requests[ID] = (
message, event, host, timestamp, count-1
)
else:
delay = 1-diff
if delay == 0:
if count:
msg = "Resending to {} (ID={})"
log.debug(msg.format(host, message.data.request_id))
sock.sendto(message.serialize(), (host, port))
else:
with dlock:
msg = "Request to {} timed out (ID={})"
log.debug(msg.format(host, message.data.request_id))
for varbind in message.data.vars:
varbind.error = Timeout(varbind.name.value)
oid = varbind.name.value
try:
host_data = data[host]
except KeyError:
host_data = {}
data[host] = host_data
try:
_, next_oid = host_data[oid]
except KeyError:
next_oid = None
# causes GETNEXT requests to register the timeout
if isinstance(message.data, GetNextRequestPDU):
next_oid = oid
host_data[oid] = [varbind, next_oid]
event.set()
log.debug("Monitor thread exiting")
class SNMPv1:
def __init__(self, community, rwcommunity=None, port=PORT, resend=1):
self.rocommunity = community
self.rwcommunity = rwcommunity or community
self.port = port
self._sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
self._sock.setblocking(False)
self._sock.bind(('', 0))
# used to shut down background threads
r, w = os.pipe()
self._write_pipe = os.fdopen(w, 'w')
self._closed = threading.Event()
# counting by an odd number will hit every
# request id once before repeating
self._count_by = random.randint(0, MAX_REQUEST_ID//2) * 2 + 1
self._next_id = self._count_by
# This is an OrderedDict so the monitoring thread can iterate through
# them in order
# <timestamp> is the timestamp of the most recent transmission
# <count> is the number of remaining re-transmits before Timeout
# {
# <request_id>: (<Message>, <Event>, <host>, <timestamp>, <count>),
# ...
# }
self._requests = OrderedDict()
self._rlock = threading.Lock()
# table of pending requests (prevents re-sending packets unnecessarily)
# {
# <host_ip>: <PendTable>,
# ...
# }
self._pending = {}
self._plock = threading.Lock()
# table of responses
# {
# <host_ip>: {
# <oid>: [
# <VarBind>,
# <next_oid>,
# ],
# ...
# },
# ...
# }
self._data = {}
self._drlock, self._dwlock = RWLock()
self._listener = threading.Thread(
target=_listen_thread,
args=(
self._sock,
r,
self._requests,
self._rlock,
self._data,
self._dwlock,
),
kwargs={"port":port},
)
self._listener.start()
self._monitor = threading.Thread(
target=_monitor_thread,
args=(
self._sock,
self._closed,
self._requests,
self._rlock,
self._data,
self._dwlock,
),
kwargs={
"port": port,
"resend": resend,
},
)
self._monitor.start()
def close(self):
log.debug("Sending shutdown signal to helper threads")
self._closed.set()
self._write_pipe.write('\0')
self._write_pipe.flush()
self._listener.join()
self._monitor.join()
log.debug("All helper threads done")
self._write_pipe.close()
self._sock.close()
self._write_pipe = None
self._sock = None
def __enter__(self):
return self
def __exit__(self, *args):
if not self._closed.is_set():
self.close()
def _request_id(self):
request_id = self._next_id
self._next_id = (request_id + self._count_by) & MAX_REQUEST_ID
return request_id
def get(self, host, *oids, community=None, block=True, timeout=10,
refresh=False, next=False):
# this event will be stored in the pending table under this request ID
# the _listener_thread will signal when the data is ready
main_event = threading.Event()
# store the first error found on a cached VarBind and raise it only
# after the request has been sent for any other oids there may be
error = None
# used for blocking calls to wait for all responses
events = set()
# set of oids that are neither in self._data nor self._pending
send = set()
# return value (values[i] corresponds to oids[i])
values = [None] * len(oids)
with self._plock:
try:
host_pending = self._pending[host]
except KeyError:
host_pending = PendTable()
self._pending[host] = host_pending
with self._drlock:
try:
host_data = self._data[host]
except KeyError:
host_data = {}
# acquiring the lock all the way out here should minimize the number of
# packets sent, even if this object is being shared by multiple threads
with host_pending.lock:
for i, oid in enumerate(oids):
if not refresh:
try:
event = host_pending.oids[oid][int(next)]
except KeyError:
pass
else:
# request has been sent already and is pending
if event and not event.is_set():
events.add(event)
# don't fetch cached value, don't re-send request
continue
try:
# TODO: make a separate lock for each host's data
with self._drlock:
value, next_oid = host_data[oid]
if next:
value, _ = host_data[next_oid]
except KeyError:
pass
else:
# cached value found
if value is not None:
# raise any errors after the request is sent
error = error or value.error
# set return value
values[i] = value
continue
# add this item to the pending table
try:
host_pending.oids[oid][int(next)] = main_event
except KeyError:
if next:
host_pending.oids[oid] = [None, main_event]
else:
host_pending.oids[oid] = [main_event, None]
# make a note to include this OID in the request
send.add(oid)
# send any requests that are not found to be pending
if send:
events.add(main_event)
pdu_type = GetNextRequestPDU if next else GetRequestPDU
pdu = pdu_type(
request_id=self._request_id(),
vars=VarBindList(
*[VarBind(OID(oid), NULL()) for oid in send]
),
)
# assign request_id variable this way rather than directly from
# self._request_id() because self._request_id() returns unsigned
# values, whereas this method returns signed values, and the key
# here has to match what is used in the _listener_thread()
request_id = pdu.request_id.value
message = Message(
version = VERSION,
data = pdu,
community = community or self.rocommunity,
)
with self._rlock:
self._requests[request_id] = (
message, main_event, host, time.time(), timeout-1
)
self._sock.sendto(message.serialize(), (host, self.port))
log.debug("Sent request to {} (ID={})".format(host, request_id))
if error is not None:
raise error
if not block:
return values
# wait for all requested oids to receive a response
for event in events:
event.wait()
# the data table should now be all up to date
with self._drlock:
values = []
try:
host_data = self._data[host]
except KeyError:
# shouldn't get here, ProtocolError will be triggered below
host_data = {}
for oid in oids:
try:
value, next_oid = host_data[oid]
if next:
value, _ = host_data[next_oid]
except KeyError:
value = None
if value is None:
raise ProtocolError("Missing variable: {}".format(oid))
elif value.error is not None:
raise value.error
values.append(value)
return values
def get_next(self, *args, **kwargs):
kwargs['next'] = True
return self.get(*args, **kwargs)
def set(self, host, oid, value, community=None, block=True, timeout=10):
# wrap the value in an ASN1 type
if isinstance(value, int):
value = INTEGER(value)
elif value is None:
value = NULL()
elif isinstance(value, ASN1):
pass
else:
if isinstance(value, str):
value = value.encode()
value = OCTET_STRING(value)
# create PDU
pdu = SetRequestPDU(
request_id=self._request_id(),
vars=VarBindList(VarBind(OID(oid), value)),
)
request_id = pdu.request_id.value
message = Message(
version = VERSION,
data = pdu,
community = community or self.rwcommunity,
)
# get PendTable for this host
with self._plock:
try:
host_pending = self._pending[host]
except KeyError:
host_pending = PendTable()
self._pending[host] = host_pending
# used to wait for the previous set request to complete
event = DUMMY_EVENT
# signaled when _listen_thread processes the response
main_event = threading.Event()
# wait for any pending requests to complete before sending
pend_event = None
# only allow one outstanding set request at a time
while pend_event is None:
# loop until we can put main_event in host_pending.oids
event.wait()
with host_pending.lock:
try:
event = host_pending.sets[oid]
except KeyError:
event = DUMMY_EVENT
# event will not be set if another thread's set request
# acquires the lock first and stores its main_event to .sets
if event.is_set():
try:
pend_event, next_oid = host_pending.oids[oid]
except KeyError:
pend_event, next_oid = DUMMY_EVENT, None
host_pending.oids[oid] = [main_event, next_oid]
host_pending.sets[oid] = main_event
# wait for pending requests to be serviced
pend_event.wait()
with self._rlock:
self._requests[request_id] = (
message, main_event, host, time.time(), timeout-1
)
self._sock.sendto(message.serialize(), (host, self.port))
msg = "SET request sent to {} (ID={}):\n{}"
log.debug(msg.format(host, request_id, pdu))
if not block:
return
# no need to duplicate code; just call self.get()
return self.get(host, oid, block=True)
def walk(self, host, oid, **kwargs):
start = oid
while True:
try:
var, = self.get_next(host, oid, block=True, **kwargs)
except NoSuchName:
break
oid = var.name.value
if not oid.startswith(start):
break
# send now to speed access on the next iteration
self.get_next(host, oid, block=False, **kwargs)
yield [var]
class StatusError(Exception):
pass
class TooBig(StatusError):
pass
class NoSuchName(StatusError):
pass
class BadValue(StatusError):
pass
class ReadOnly(StatusError):
pass
class GenErr(StatusError):
pass
get:
create main event
get pend table for host or create one
get data table for host or create dummy
with pend table lock
for each oid in the request
if not refresh
if pending
continue
if found in cache
grab cached result
continue
add main event to the pend table
add oid to request
if there are oids to be sent
construct message
add message to requests table
send message
if an oid had an error
raise it now
if not waiting for response
return the values you do have
wait for all events
with data table read lock
for each oid
grab the value
- make sure it is present
- make sure there are no errors
return the values
listen thread:
check port number
decode message
pull request id
find the corresponding request
make sure it has the right number of varbinds
remove request from table
check error status
with data table write lock
get/create entry for host
for each varbind
give it the error found in the request error field
find the oid requested
make sure it matches the request
save varbind to data table
set the request event
monitor thread:
wait for next stale request, done, or 1 second (whichever comes first)
grab next request
if it is stale
if it has not timed out
resend it
else
set varbind error to timeout
signal event
......@@ -14,79 +14,83 @@
# PyTango imports
from tango.server import run
from tango.server import device_property
from tango import DevState
# Additional import
from clients.test_client import example_client
from clients.SNMP_client import SNMP_client
from src.attribute_wrapper import *
from src.hardware_device import *
__all__ = ["test_device", "main"]
__all__ = ["example_device", "main"]
class example_device(hardware_device):
class test_device(hardware_device):
# -----------------
# Device Properties
# -----------------
OPC_Server_Name = device_property(
dtype='DevString',
)
OPC_Server_Port = device_property(
dtype='DevULong',
)
OPC_Time_Out = device_property(
dtype='DevDouble',
)
SNMP_community = b"public"
SNMP_host = "127.0.0.1"
SNMP_timeout = 5.0
# ----------
# Attributes
# ----------
bool_scalar_R = attribute_wrapper(comms_annotation="numpy.bool_ type read scalar", datatype=numpy.bool_)
bool_scalar_RW = attribute_wrapper(comms_annotation="numpy.bool_ type read/write scalar", datatype=numpy.bool_, access=AttrWriteType.READ_WRITE)
int32_spectrum_R = attribute_wrapper(comms_annotation="numpy.int32 type read spectrum (len = 8)", datatype=numpy.int32, dims=(8,))
int32_spectrum_RW = attribute_wrapper(comms_annotation="numpy.int32 type read spectrum (len = 8)", datatype=numpy.int32, dims=(8,),
access=AttrWriteType.READ_WRITE)
# simple scalar
attr1 = attribute_wrapper(comms_annotation={"oids": "1.3.6.1.2.1.1.6.0"}, datatype=numpy.bool_, access=AttrWriteType.READ_WRITE)
# simple scalar with host
attr2 = attribute_wrapper(comms_annotation={"oids": "1.3.6.1.2.1.1.5.0"}, datatype=numpy.bool_, access=AttrWriteType.READ_WRITE)
#spectrum with all elements
attr3 = attribute_wrapper(comms_annotation={"oids": ["1.3.6.1.2.1.1.5.1", "1.3.6.1.2.1.1.5.2", "1.3.6.1.2.1.1.5.3"]}, dims=(3,), datatype=numpy.bool_)
#inferred spectrum
attr4 = attribute_wrapper(comms_annotation={"oids": ["1.3.6.1.2.1.1.5.0"]}, dims=(3,), datatype=numpy.bool_)
double_image_R = attribute_wrapper(comms_annotation="numpy.double type read image (dims = 2x8)", datatype=numpy.double, dims=(2, 8))
double_image_RW = attribute_wrapper(comms_annotation="numpy.double type read/write image (dims = 8x2)", datatype=numpy.double, dims=(8, 2),
access=AttrWriteType.READ_WRITE)
int32_scalar_R = attribute_wrapper(comms_annotation="numpy.int32 type read scalar", datatype=numpy.int32)
uint16_spectrum_RW = attribute_wrapper(comms_annotation="numpy.uint16 type read/write spectrum (len = 8)", datatype=numpy.uint16, dims=(8,),
access=AttrWriteType.READ_WRITE)
float32_image_R = attribute_wrapper(comms_annotation="numpy.float32 type read image (dims = 8x2)", datatype=numpy.float32, dims=(8, 2))
uint8_image_RW = attribute_wrapper(comms_annotation="numpy.uint8 type read/write image (dims = 2x8)", datatype=numpy.uint8, dims=(2, 8),
access=AttrWriteType.READ_WRITE)
def always_executed_hook(self):
"""Method always executed before any TANGO command is executed."""
pass
def delete_device(self):
"""Hook to delete resources allocated in init_device.
This method allows for any memory or other resources allocated in the
init_device method to be released. This method is called by the device
destructor and by the device Init command (a Tango built-in).
"""
self.debug_stream("Shutting down...")
self.Off()
self.debug_stream("Shut down. Good bye.")
# --------
# overloaded functions
# --------
def initialise(self):
""" user code here. is called when the sate is set to INIT """
"""Initialises the attributes and properties of the PCC."""
""" user code here. is called when the state is set to STANDBY """
self.set_state(DevState.INIT)
# set up the SNMP ua client
self.snmp_manager = SNMP_client(self.SNMP_community, self.SNMP_host, self.SNMP_timeout, self.Fault, self)
# set up the OPC ua client
self.example_client = example_client(self.Fault, self)
# map an access helper class
# map the attributes to the OPC ua comm client
for i in self.attr_list():
i.set_comm_client(self.example_client)
i.set_comm_client(self.snmp_manager)
self.snmp_manager.start()
self.example_client.start()
# --------
# Commands
# --------
# ----------
# Run server
# ----------
def main(args=None, **kwargs):
"""Main function of the example module."""
return run((test_device,), args=args, **kwargs)
"""Main function of the PCC module."""
return run((example_device,), args=args, **kwargs)
if __name__ == '__main__':
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment