From f6ae80e5cae0252df1b63e91b24f5516344fcc89 Mon Sep 17 00:00:00 2001 From: Auke Klazema <klazema@astron.nl> Date: Fri, 7 Jun 2019 08:30:31 +0000 Subject: [PATCH] SW-705: Add mixin to reduce code duplication in rpc client classes --- LCS/Messaging/python/messaging/rpc_service.py | 23 ++++- MAC/Services/src/observation_control_rpc.py | 29 ++---- .../MoMQueryServiceClient/momqueryrpc.py | 93 ++++++++----------- .../test/t_momqueryservice.py | 48 ++++++---- 4 files changed, 94 insertions(+), 99 deletions(-) diff --git a/LCS/Messaging/python/messaging/rpc_service.py b/LCS/Messaging/python/messaging/rpc_service.py index e1e55473338..229907a78a2 100644 --- a/LCS/Messaging/python/messaging/rpc_service.py +++ b/LCS/Messaging/python/messaging/rpc_service.py @@ -321,7 +321,26 @@ class RPC(): raise Exception(answer.error_message) -__all__ = ['ServiceMessageHandler', 'Service', 'RPC'] + +class RPCContextManagerMixin: + def __init__(self): + self._rpc = None + + def open(self): + self._rpc.open() + + def close(self): + self._rpc.close() + + def __enter__(self): + self.open() + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + self.close() + + +__all__ = ['ServiceMessageHandler', 'Service', 'RPC', 'RPCContextManagerMixin'] if __name__ == "__main__": logging.basicConfig(format='%(levelname)s %(threadName)s %(message)s', level=logging.DEBUG) @@ -329,4 +348,4 @@ if __name__ == "__main__": # run the doctests in this module import doctest - doctest.testmod(verbose=True, report=True) \ No newline at end of file + doctest.testmod(verbose=True, report=True) diff --git a/MAC/Services/src/observation_control_rpc.py b/MAC/Services/src/observation_control_rpc.py index 8935cc3ac00..48f3fecd75c 100644 --- a/MAC/Services/src/observation_control_rpc.py +++ b/MAC/Services/src/observation_control_rpc.py @@ -20,8 +20,7 @@ import logging -from lofar.messaging.rpc_service import RPC -from lofar.messaging import DEFAULT_BROKER, DEFAULT_BUSNAME +from lofar.messaging.rpc_service import RPC, RPCContextManagerMixin from lofar.mac.config import DEFAULT_OBSERVATION_CONTROL_SERVICE_NAME ''' Simple RPC client for Service ObservationControl2 @@ -30,26 +29,10 @@ from lofar.mac.config import DEFAULT_OBSERVATION_CONTROL_SERVICE_NAME logger = logging.getLogger(__name__) -class ObservationControlRPCClient(): - def __init__(self, - exchange=DEFAULT_BUSNAME, - servicename=DEFAULT_OBSERVATION_CONTROL_SERVICE_NAME, - broker=DEFAULT_BROKER, - timeout=120): - self.rpc = RPC(service_name=servicename, exchange=exchange, broker=broker, timeout=timeout) - - def open(self): - self.rpc.open() - - def close(self): - self.rpc.close() - - def __enter__(self): - self.open() - return self - - def __exit__(self, exc_type, exc_val, exc_tb): - self.close() +class ObservationControlRPCClient(RPCContextManagerMixin): + def __init__(self, rpc : RPC = RPC(service_name=DEFAULT_OBSERVATION_CONTROL_SERVICE_NAME)): + super().__init__() + self._rpc = rpc def abort_observation(self, sas_id): - return self.rpc.execute('AbortObservation', sas_id=sas_id) + return self._rpc.execute('AbortObservation', sas_id=sas_id) diff --git a/SAS/MoM/MoMQueryService/MoMQueryServiceClient/momqueryrpc.py b/SAS/MoM/MoMQueryService/MoMQueryServiceClient/momqueryrpc.py index d8c8e5564a4..90699485e67 100644 --- a/SAS/MoM/MoMQueryService/MoMQueryServiceClient/momqueryrpc.py +++ b/SAS/MoM/MoMQueryService/MoMQueryServiceClient/momqueryrpc.py @@ -21,7 +21,7 @@ import sys import logging import pprint from optparse import OptionParser -from lofar.messaging import RPC, DEFAULT_BROKER, DEFAULT_BUSNAME +from lofar.messaging import RPC, DEFAULT_BROKER, DEFAULT_BUSNAME, RPCContextManagerMixin from lofar.mom.momqueryservice.config import DEFAULT_MOMQUERY_SERVICENAME ''' Simple RPC client for Service momqueryservice @@ -30,32 +30,17 @@ from lofar.mom.momqueryservice.config import DEFAULT_MOMQUERY_SERVICENAME logger = logging.getLogger(__name__) -class MoMQueryRPC: - def __init__(self, exchange=DEFAULT_BUSNAME, - broker=DEFAULT_BROKER, - timeout=120): - self.rpc = RPC(exchange=exchange, service_name=DEFAULT_MOMQUERY_SERVICENAME, - broker=broker, timeout=timeout) - - def __enter__(self): - self.open() - return self - - def __exit__(self, exc_type, exc_val, exc_tb): - self.close() - - def open(self): - self.rpc.open() - - def close(self): - self.rpc.close() +class MoMQueryRPC(RPCContextManagerMixin): + def __init__(self, rpc : RPC = None): + super().__init__() + self._rpc = rpc or RPC(service_name=DEFAULT_MOMQUERY_SERVICENAME) def add_trigger(self, user_name, host_name, project_name, meta_data): logger.info("Requestion add_trigger for user_name: %s, host_name: %s, project_name: %s and " "meta_data: %s", user_name, host_name, project_name, meta_data) - row_id = self.rpc.execute('add_trigger', user_name=user_name, host_name=host_name, - project_name=project_name, meta_data=meta_data) + row_id = self._rpc.execute('add_trigger', user_name=user_name, host_name=host_name, + project_name=project_name, meta_data=meta_data) logger.info("Received add_trigger for user_name (%s), host_name(%s), project_name(%s) and " "meta_data(%s): %s", @@ -66,7 +51,7 @@ class MoMQueryRPC: def get_project_priority(self, project_name): logger.info("Requestion get_project_priority for project_name: %s", project_name) - priority = self.rpc.execute('get_project_priority', project_name=project_name) + priority = self._rpc.execute('get_project_priority', project_name=project_name) logger.info("Received get_project_priority for project_name (%s): %s", project_name, priority) @@ -80,7 +65,7 @@ class MoMQueryRPC: """ logger.info("Requesting allows_triggers for project_name: %s", project_name) - result = self.rpc.execute('allows_triggers', project_name=project_name) + result = self._rpc.execute('allows_triggers', project_name=project_name) logger.info("Received allows_triggers for project_name (%s): %s", project_name, result) @@ -96,8 +81,8 @@ class MoMQueryRPC: """ logger.info("Requesting authorized_add_with_status for user_name: %s project_name: %s " "job_type: %s status: %s", user_name, project_name, job_type, status) - result = self.rpc.execute('authorized_add_with_status', user_name=user_name, - project_name=project_name, job_type=job_type, status=status) + result = self._rpc.execute('authorized_add_with_status', user_name=user_name, + project_name=project_name, job_type=job_type, status=status) logger.info("Received authorized_add_with_status for user_name: %s project_name: %s " "job_type: %s status: %s result: %s", user_name, project_name, job_type, status, result) @@ -109,7 +94,7 @@ class MoMQueryRPC: :return: Boolean """ logger.info("Requesting folder: %s exists", folder) - result = self.rpc.execute('folder_exists', folder=folder) + result = self._rpc.execute('folder_exists', folder=folder) logger.info("Received folder exists: %s", result) return result @@ -119,7 +104,7 @@ class MoMQueryRPC: :return: Boolean """ logger.info("Requesting if project: %s is active", project_name) - result = self.rpc.execute('is_project_active', project_name=project_name) + result = self._rpc.execute('is_project_active', project_name=project_name) logger.info("Received Project is active: %s", result) return result @@ -129,7 +114,7 @@ class MoMQueryRPC: :return: Boolean """ logger.info("Requesting if user %s is an operator", user_name) - result = self.rpc.execute('IsUserOperator', user_name=user_name) + result = self._rpc.execute('IsUserOperator', user_name=user_name) logger.info("User %s is %san operator", user_name, 'not ' if result['is_operator'] is False else '') return result @@ -141,7 +126,7 @@ class MoMQueryRPC: :param user_name: string that contains the user's login name or None. :rtype dict with all triggers""" logger.info("Requesting triggers for user %s", user_name) - triggers = self.rpc.execute('get_triggers', user_name=user_name) + triggers = self._rpc.execute('get_triggers', user_name=user_name) logger.info("Received %d triggers for user %s", len(triggers), user_name) @@ -156,8 +141,8 @@ class MoMQueryRPC: :rtype dict with all triggers""" logger.info("Requesting trigger spec for user %s and trigger id " "%s", user_name, trigger_id) - trigger_spec = self.rpc.execute('get_trigger_spec', user_name=user_name, - trigger_id=trigger_id) + trigger_spec = self._rpc.execute('get_trigger_spec', user_name=user_name, + trigger_id=trigger_id) logger.info("Received a trigger spec with size %d for trigger id " "%s of user %s", len(trigger_spec['trigger_spec']), trigger_id, @@ -170,7 +155,7 @@ class MoMQueryRPC: :return: Integer or None """ logger.info("Requesting get_trigger_id for mom_id: %s", mom_id) - result = self.rpc.execute('get_trigger_id', mom_id=mom_id) + result = self._rpc.execute('get_trigger_id', mom_id=mom_id) logger.info("Received get_trigger_id: %s", result) return result @@ -180,7 +165,7 @@ class MoMQueryRPC: :return: (Integer, Integer) """ logger.info("Requesting get_trigger_quota for project: %s", project_name) - result = self.rpc.execute('get_trigger_quota', project_name=project_name) + result = self._rpc.execute('get_trigger_quota', project_name=project_name) logger.info("Received trigger quota: %s", result) return result @@ -192,7 +177,7 @@ class MoMQueryRPC: :return: (Integer, Integer) """ logger.info("Requesting update_trigger_quota for project: %s", project_name) - result = self.rpc.execute('update_trigger_quota', project_name=project_name) + result = self._rpc.execute('update_trigger_quota', project_name=project_name) logger.info("Received updated trigger quota: %s", result) return result @@ -203,7 +188,7 @@ class MoMQueryRPC: :return (Integer, Integer) """ logger.info("Requesting cancel_trigger for trigger id: %s | reason: %s", trigger_id, reason) - result = self.rpc.execute('cancel_trigger', trigger_id=trigger_id, reason=reason) + result = self._rpc.execute('cancel_trigger', trigger_id=trigger_id, reason=reason) logger.info("Requesting cancel_trigger for trigger id %s returned updated project trigger " "quota: %s", trigger_id, result) return result @@ -213,7 +198,7 @@ class MoMQueryRPC: :param mom_id :rtype dict with pi and contact author email addresses""" logger.info("Requesting get_project_details for mom_id: %s", mom_id) - result = self.rpc.execute('get_project_details', mom_id=mom_id) + result = self._rpc.execute('get_project_details', mom_id=mom_id) logger.info("Received get_project_details: %s", result) return result @@ -227,7 +212,7 @@ class MoMQueryRPC: ids_string = ', '.join(mom_ids) logger.info("Requesting project priorities for mom objects: %s", (str(ids_string))) - result = self.rpc.execute('get_project_priorities_for_objects', mom_ids=ids_string) + result = self._rpc.execute('get_project_priorities_for_objects', mom_ids=ids_string) logger.info("Received project priorities for %s mom objects" % (len(result))) return result @@ -244,7 +229,7 @@ class MoMQueryRPC: ids_string = ', '.join(ids) logger.info("Requesting details for %s mom objects. mom_ids: %s", len(ids), ids_string) - result = self.rpc.execute('getObjectDetails', mom_ids=ids_string) + result = self._rpc.execute('getObjectDetails', mom_ids=ids_string) logger.info("Received details for %s mom objects. mom_ids: %s", len(result), ids_string) return result @@ -252,51 +237,51 @@ class MoMQueryRPC: """get all projects :rtype dict with all projects""" logger.info("Requesting all projects") - projects = self.rpc.execute('getProjects') + projects = self._rpc.execute('getProjects') logger.info("Received %s projects", (len(projects))) return projects def getProject(self, project_mom2id): """get projects by mo2_id""" logger.info("getProject(%s)", project_mom2id) - project = self.rpc.execute('getProject', project_mom2id=project_mom2id) + project = self._rpc.execute('getProject', project_mom2id=project_mom2id) return project def getProjectTaskIds(self, project_mom2id): """get all task mom2id's for the given project :rtype dict with all projects""" logger.info("getProjectTaskIds") - task_ids = self.rpc.execute('getProjectTaskIds', project_mom2id=project_mom2id) + task_ids = self._rpc.execute('getProjectTaskIds', project_mom2id=project_mom2id) return task_ids def getPredecessorIds(self, ids): logger.debug("getSuccessorIds(%s)", ids) - result = self.rpc.execute('getPredecessorIds', mom_ids=ids) + result = self._rpc.execute('getPredecessorIds', mom_ids=ids) logger.info("getPredecessorIds(%s): %s", ids, result) return result def getSuccessorIds(self, ids): logger.debug("getSuccessorIds(%s)", ids) - result = self.rpc.execute('getSuccessorIds', mom_ids=ids) + result = self._rpc.execute('getSuccessorIds', mom_ids=ids) logger.info("getSuccessorIds(%s): %s", ids, result) return result def getTaskIdsInGroup(self, mom_group_ids): logger.debug("getTaskIdsInGroup(%s)", mom_group_ids) - result = self.rpc.execute('getTaskIdsInGroup', mom_group_ids=mom_group_ids) + result = self._rpc.execute('getTaskIdsInGroup', mom_group_ids=mom_group_ids) logger.info("getTaskIdsInGroup(%s): %s", mom_group_ids, result) return result def getTaskIdsInParentGroup(self, mom_parent_group_ids): logger.debug("getTaskIdsInParentGroup(%s)", mom_parent_group_ids) - result = self.rpc.execute('getTaskIdsInParentGroup', - mom_parent_group_ids=mom_parent_group_ids) + result = self._rpc.execute('getTaskIdsInParentGroup', + mom_parent_group_ids=mom_parent_group_ids) logger.info("getTaskIdsInParentGroup(%s): %s", mom_parent_group_ids, result) return result def getDataProducts(self, ids): logger.debug("getDataProducts(%s)", ids) - result = self.rpc.execute('getDataProducts', mom_ids=ids) + result = self._rpc.execute('getDataProducts', mom_ids=ids) logger.info('Found # dataproducts per mom2id: %s', ', '.join( '%s:%s' % (dp_id, len(dps)) for dp_id, dps in list(result.items()))) return result @@ -307,7 +292,7 @@ class MoMQueryRPC: if isinstance(otdb_ids, int) or isinstance(otdb_ids, str): otdb_ids = [otdb_ids] logger.debug("getMoMIdsForOTDBIds(%s)", otdb_ids) - result = self.rpc.execute('getMoMIdsForOTDBIds', otdb_ids=otdb_ids) + result = self._rpc.execute('getMoMIdsForOTDBIds', otdb_ids=otdb_ids) return result def getOTDBIdsForMoMIds(self, mom_ids): @@ -316,7 +301,7 @@ class MoMQueryRPC: if isinstance(mom_ids, int) or isinstance(mom_ids, str): mom_ids = [mom_ids] logger.debug("getOTDBIdsForMoMIds(%s)", mom_ids) - result = self.rpc.execute('getOTDBIdsForMoMIds', mom_ids=mom_ids) + result = self._rpc.execute('getOTDBIdsForMoMIds', mom_ids=mom_ids) return result def getTaskIdsGraph(self, mom2id): @@ -324,7 +309,7 @@ class MoMQueryRPC: returns: dict with mom2id:node as key value pairs, where each node is a dict with items node_mom2id, predecessor_ids, successor_ids""" logger.debug("getTaskIdsGraph(%s)", mom2id) - result = self.rpc.execute('getTaskIdsGraph', mom2id=mom2id) + result = self._rpc.execute('getTaskIdsGraph', mom2id=mom2id) return result def get_station_selection(self, mom_id): @@ -335,7 +320,7 @@ class MoMQueryRPC: :return: list of dict """ logger.info("Calling getStationSelection for mom id "+str(mom_id)) - station_selection = self.rpc.execute('getStationSelection', mom_id=mom_id) + station_selection = self._rpc.execute('getStationSelection', mom_id=mom_id) return station_selection def get_trigger_time_restrictions(self, mom_id): @@ -345,7 +330,7 @@ class MoMQueryRPC: :return: dict """ logger.info("Calling getTimeRestrictions for mom id "+str(mom_id)) - time_restrictions = self.rpc.execute('getTriggerTimeRestrictions', mom_id=mom_id) + time_restrictions = self._rpc.execute('getTriggerTimeRestrictions', mom_id=mom_id) return time_restrictions def get_storagemanager(self, mom_id): @@ -355,7 +340,7 @@ class MoMQueryRPC: :return: string """ logger.info("Calling GetStoragemanager for mom id "+str(mom_id)) - storagemanager = self.rpc.execute('getStoragemanager', mom_id=mom_id) + storagemanager = self._rpc.execute('getStoragemanager', mom_id=mom_id) return storagemanager diff --git a/SAS/MoM/MoMQueryService/MoMQueryServiceServer/test/t_momqueryservice.py b/SAS/MoM/MoMQueryService/MoMQueryServiceServer/test/t_momqueryservice.py index 601dadc9685..59c8cb54dae 100755 --- a/SAS/MoM/MoMQueryService/MoMQueryServiceServer/test/t_momqueryservice.py +++ b/SAS/MoM/MoMQueryService/MoMQueryServiceServer/test/t_momqueryservice.py @@ -404,7 +404,7 @@ def populate_db(mysqld): connection.close() -Mysqld = testing.mysqld.MysqldFactory(cache_initialized_db = True, on_initialized = populate_db) +Mysqld = None # testing.mysqld.MysqldFactory(cache_initialized_db = True, on_initialized = populate_db) def tearDownModule(): # clear cached database at end of tests @@ -706,21 +706,19 @@ class TestMomQueryRPC(unittest.TestCase): allocated_triggers = 10 def setUp(self): - rpc_patcher = mock.patch('lofar.mom.momqueryservice.momqueryrpc.RPC') - self.addCleanup(rpc_patcher.stop) - self.rpc_mock = rpc_patcher.start() - logger_patcher = mock.patch('lofar.mom.momqueryservice.momqueryrpc.logger') self.addCleanup(logger_patcher.stop) self.logger_mock = logger_patcher.start() - self.momrpc = MoMQueryRPC('busname') + self.rpc_mock = mock.MagicMock() + + self.momrpc = MoMQueryRPC(self.rpc_mock) def test_object_details_query(self): test_id = 1234 self.momrpc.getObjectDetails(test_id) - self.rpc_mock().execute.assert_called_with('getObjectDetails', mom_ids=str(test_id)) + self.rpc_mock.execute.assert_called_with('getObjectDetails', mom_ids=str(test_id)) def test_is_user_operator_logs_before_query(self): self.momrpc.isUserOperator(self.user_name) @@ -733,7 +731,7 @@ class TestMomQueryRPC(unittest.TestCase): self.logger_mock.info.assert_any_call("User %s is %san operator", self.user_name, '') def test_is_user_operator_logs_after_query_2(self): - self.rpc_mock().execute.return_value = {'is_operator': False} + self.rpc_mock.execute.return_value = {'is_operator': False} self.momrpc.isUserOperator(self.user_name) @@ -823,7 +821,7 @@ class TestMomQueryRPC(unittest.TestCase): def test_get_project_priority_query(self): self.momrpc.get_project_priority(self.project_name) - self.rpc_mock().execute.assert_called_with('get_project_priority', + self.rpc_mock.execute.assert_called_with('get_project_priority', project_name=self.project_name) def test_add_trigger_logs_before_query(self): @@ -843,7 +841,17 @@ class TestMomQueryRPC(unittest.TestCase): def test_add_trigger_query(self): self.momrpc.add_trigger(self.user_name, self.host_name, self.project_name, self.meta_data) - self.rpc_mock().execute.assert_called() + self.rpc_mock.execute.assert_called_with('add_trigger', user_name=self.user_name, + host_name=self.host_name, + project_name=self.project_name, + meta_data=self.meta_data) + + def test_add_trigger_query_returns_value_from_rpc(self): + self.rpc_mock.execute.return_value = 42 + + res = self.momrpc.add_trigger(self.user_name, self.host_name, self.project_name, self.meta_data) + + self.assertEqual(42, res) def test_get_triggers_logs_before_query(self): self.momrpc.get_triggers(self.user_name) @@ -852,7 +860,7 @@ class TestMomQueryRPC(unittest.TestCase): "user %s", self.user_name) def test_get_triggers_logs_after_query(self): - self.rpc_mock().execute.return_value = [{"trigger_id": 1}] + self.rpc_mock.execute.return_value = [{"trigger_id": 1}] self.momrpc.get_triggers(self.user_name) @@ -861,7 +869,7 @@ class TestMomQueryRPC(unittest.TestCase): def test_get_triggers_query(self): self.momrpc.get_triggers(self.user_name) - self.rpc_mock().execute.assert_called_with('get_triggers', user_name = self.user_name) + self.rpc_mock.execute.assert_called_with('get_triggers', user_name = self.user_name) def test_get_trigger_spec_logs_before_query(self): self.momrpc.get_trigger_spec(self.user_name, self.trigger_id) @@ -876,7 +884,7 @@ class TestMomQueryRPC(unittest.TestCase): def test_get_trigger_spec(self): self.momrpc.get_trigger_spec(self.user_name, self.trigger_id) - self.rpc_mock().execute.assert_called_with('get_trigger_spec', user_name = self.user_name, + self.rpc_mock.execute.assert_called_with('get_trigger_spec', user_name = self.user_name, trigger_id = self.trigger_id) # @mock.patch('lofar.messaging.messagebus.proton.utils.BlockingConnection') @@ -941,41 +949,41 @@ class TestMomQueryRPC(unittest.TestCase): self.momrpc.get_project_details(mom_id) - self.rpc_mock().execute.assert_called_with('get_project_details', mom_id=mom_id) + self.rpc_mock.execute.assert_called_with('get_project_details', mom_id=mom_id) def test_get_project_priorities_for_objects_query(self): self.momrpc.get_project_priorities_for_objects(self.test_id) - self.rpc_mock().execute.assert_called_with('get_project_priorities_for_objects', + self.rpc_mock.execute.assert_called_with('get_project_priorities_for_objects', mom_ids=str(self.test_id)) def test_get_time_restrictions_query(self): self.momrpc.get_trigger_time_restrictions(self.test_id) - self.rpc_mock().execute.assert_called_with('getTriggerTimeRestrictions', + self.rpc_mock.execute.assert_called_with('getTriggerTimeRestrictions', mom_id=self.test_id) def test_get_station_selection_query(self): self.momrpc.get_station_selection(self.test_id) - self.rpc_mock().execute.assert_called_with('getStationSelection', mom_id=self.test_id) + self.rpc_mock.execute.assert_called_with('getStationSelection', mom_id=self.test_id) def test_get_trigger_quota_query(self): result = self.momrpc.get_trigger_quota(self.project_name) - self.rpc_mock().execute.assert_called_with('get_trigger_quota', + self.rpc_mock.execute.assert_called_with('get_trigger_quota', project_name=self.project_name) def test_update_trigger_quota(self): self.momrpc.update_trigger_quota(self.project_name) - self.rpc_mock().execute.assert_called_with('update_trigger_quota', project_name=self.project_name) + self.rpc_mock.execute.assert_called_with('update_trigger_quota', project_name=self.project_name) def test_cancel_trigger(self): reason = 'Because I say so' self.momrpc.cancel_trigger(self.test_id, reason) - self.rpc_mock().execute.assert_called_with('cancel_trigger', trigger_id=self.test_id, + self.rpc_mock.execute.assert_called_with('cancel_trigger', trigger_id=self.test_id, reason=reason) class TestMoMDatabaseWrapper(unittest.TestCase): -- GitLab