diff --git a/src/ska_tango_base/base/component_manager.py b/src/ska_tango_base/base/component_manager.py index ce3eacd0f98bbe332440c6ffd95ae0ea7716e32d..ea7a6a32e2eab1879dddb00296cf1c9862cd2c11 100644 --- a/src/ska_tango_base/base/component_manager.py +++ b/src/ska_tango_base/base/component_manager.py @@ -23,7 +23,10 @@ The basic model is: the component to change behaviour and/or state; and it *monitors* its component by keeping track of its state. """ -from typing import Any, Optional, Tuple +import logging +from typing import Any, Callable, Optional, Tuple +from ska_tango_base.base.op_state_model import OpStateModel + from ska_tango_base.commands import BaseCommand, ResultCode from ska_tango_base.control_model import PowerMode @@ -45,13 +48,7 @@ class BaseComponentManager: or on """ - def __init__( - self, - op_state_model, - *args, - queue_manager: Optional[QueueManager] = None, - **kwargs - ): + def __init__(self, op_state_model, *args, **kwargs): """ Initialise a new ComponentManager instance. @@ -61,7 +58,7 @@ class BaseComponentManager: In this case any tasks enqueued to it will block. """ self.op_state_model = op_state_model - self.queue_manager = queue_manager if queue_manager else QueueManager() + self.queue_manager = self.create_queue_manager() def start_communicating(self): """ @@ -216,6 +213,17 @@ class BaseComponentManager: """ self.op_state_model.perform_action("component_fault") + def create_queue_manager(self) -> QueueManager: + """Create a QueueManager. + + By default the QueueManager will not have a queue or workers. Thus + tasks enqueued will block. + + :return: The queue manager. + :rtype: QueueManager + """ + return QueueManager(max_queue_size=0, num_workers=0) + def enqueue( self, task: BaseCommand, @@ -225,7 +233,57 @@ class BaseComponentManager: :param task: The task to execute in the thread :type task: BaseCommand - :return: The unique ID of the queued command - :rtype: str + :param argin: The parameter for the command + :type argin: Any + :return: The unique ID of the queued command and the ResultCode + :rtype: tuple """ return self.queue_manager.enqueue_task(task, argin=argin) + + +class QueueWorkerComponentManager(BaseComponentManager): + """A component manager that configres the queue manager.""" + + def __init__( + self, + op_state_model: Optional[OpStateModel], + logger: logging.Logger, + max_queue_size: int, + num_workers: int, + push_change_event: Optional[Callable], + *args, + **kwargs + ): + """Component manager that configures the queue. + + :param op_state_model: The ops state model + :type op_state_model: OpStateModel + :param logger: Logger to use + :type logger: logging.Logger + :param max_queue_size: The size of the queue + :type max_queue_size: int + :param num_workers: The number of workers + :type num_workers: int + :param push_change_event: A method that will be called when attributes are updated + :type push_change_event: Callable + """ + self.logger = logger + self.max_queue_size = max_queue_size + self.num_workers = num_workers + self.push_change_event = push_change_event + super().__init__(op_state_model, *args, **kwargs) + + def create_queue_manager(self) -> QueueManager: + """Create a QueueManager. + + Create the QueueManager with the queue configured as needed. + + :return: The queue manager + :rtype: QueueManager + """ + return QueueManager( + max_queue_size=self.max_queue_size, + num_workers=self.num_workers, + logger=self.logger, + push_change_event=self.push_change_event, + ) diff --git a/src/ska_tango_base/base/task_queue_manager.py b/src/ska_tango_base/base/task_queue_manager.py index f72b8730bd94bc3c54378ad9d0a62889f327fbb0..e7efd6208b5760fa342bd20ffa9e05e65590ba13 100644 --- a/src/ska_tango_base/base/task_queue_manager.py +++ b/src/ska_tango_base/base/task_queue_manager.py @@ -345,6 +345,8 @@ class QueueManager: :param task: Task to execute :type task: BaseCommand + :param argin: The argument for the command + :type argin: Any :param unique_id: The task unique ID :type unique_id: str :return: The result of the task @@ -494,7 +496,7 @@ class QueueManager: @property def task_progress( self, - ) -> Tuple[str,]: # noqa: E231 + ) -> Tuple[Optional[str],]: # noqa: E231 """Return the task progress. :return: The task progress pairs (id, progress) @@ -514,6 +516,8 @@ class QueueManager: :param task: The task to execute in a thread :type task: BaseCommand + :param argin: The parameter for the command + :type argin: Any :return: The unique ID of the command :rtype: string """ diff --git a/src/ska_tango_base/commands.py b/src/ska_tango_base/commands.py index 7ebf3f582072ff211eafd94514c64cdf13ca515a..ff14c111acd3c128799c5bf62a68edad170cb015 100644 --- a/src/ska_tango_base/commands.py +++ b/src/ska_tango_base/commands.py @@ -381,7 +381,7 @@ class ResponseCommand(BaseCommand): f"Exiting command {self.name} with return_code " f"{return_code!s}, message: '{message}'.", ) - return (return_code, message) + return return_code, message class CompletionCommand(StateModelCommand): diff --git a/tests/long_running_tasks/reference_base_device.py b/tests/long_running_tasks/reference_base_device.py index 97c843f0eccc5fb438ae6de690ef5d0ae5ff06d7..cf8882e3d0094637afb61661ffbbb2b4b125c12c 100644 --- a/tests/long_running_tasks/reference_base_device.py +++ b/tests/long_running_tasks/reference_base_device.py @@ -13,9 +13,11 @@ import time from tango.server import command from tango import DebugIt -from ska_tango_base.base.component_manager import BaseComponentManager +from ska_tango_base.base.component_manager import ( + QueueWorkerComponentManager, +) from ska_tango_base.base.base_device import SKABaseDevice -from ska_tango_base.base.task_queue_manager import QueueManager, ResultCode +from ska_tango_base.base.task_queue_manager import ResultCode from ska_tango_base.commands import ResponseCommand @@ -91,7 +93,7 @@ class LongRunningCommandBaseTestDevice(SKABaseDevice): "In NonAbortingTask repeating %s", retries, ) - return (ResultCode.OK, "Done") + return ResultCode.OK, "Done" @command( dtype_in=float, @@ -188,13 +190,10 @@ class AsyncBaseDevice(LongRunningCommandBaseTestDevice): def create_component_manager(self: SKABaseDevice): """Create the component manager with a queue manager that has workers.""" - queue_manager = QueueManager( - max_queue_size=10, - num_workers=3, + return QueueWorkerComponentManager( + op_state_model=self.op_state_model, logger=self.logger, + max_queue_size=20, + num_workers=3, push_change_event=self.push_change_event, ) - return BaseComponentManager( - op_state_model=self.op_state_model, - queue_manager=queue_manager, - ) diff --git a/tests/long_running_tasks/test_reference_base_device.py b/tests/long_running_tasks/test_reference_base_device.py index 238a8666590d2941e20311a4c733f3feb638d629..8e31122266b7671d2d6e238546b24cb685a84c49 100644 --- a/tests/long_running_tasks/test_reference_base_device.py +++ b/tests/long_running_tasks/test_reference_base_device.py @@ -157,8 +157,8 @@ def test_callbacks(): assert len(attribute_values[10]) == 3 tr = TaskResult.from_task_result(attribute_values[10]) assert tr.get_task_unique_id().id_task_name == "TestProgressCommand" - tr.result_code == ResultCode.OK - tr.task_result == "None" + assert tr.result_code == ResultCode.OK + assert tr.task_result == "OK" @pytest.mark.forked diff --git a/tests/long_running_tasks/test_task_queue_manager.py b/tests/long_running_tasks/test_task_queue_manager.py index a9d0f2788ad32ad5edeffd575a0263d395d653e5..b1c94d7f3ab06a62e035883063d458d39e622a16 100644 --- a/tests/long_running_tasks/test_task_queue_manager.py +++ b/tests/long_running_tasks/test_task_queue_manager.py @@ -10,24 +10,12 @@ from ska_tango_base.base.task_queue_manager import ( TaskResult, TaskState, ) -from ska_tango_base.base.component_manager import BaseComponentManager +from ska_tango_base.base.component_manager import QueueWorkerComponentManager from ska_tango_base.commands import BaseCommand logger = logging.getLogger(__name__) -def check_matching_pattern(list_to_check=()): - """Check that lengths go 1,2,3,2,1 for example.""" - list_to_check = list(list_to_check) - if not list_to_check[-1]: - list_to_check.pop() - assert len(list_to_check) > 2 - while len(list_to_check) > 2: - last_e = list_to_check.pop() - first_e = list_to_check.pop(0) - assert len(last_e) == len(first_e) - - @pytest.fixture def progress_task(): """Fixture for a test that throws an exception.""" @@ -310,8 +298,8 @@ class TestQueueManagerTasks: ] task_result_ids = [res[0] for res in task_result] - check_matching_pattern(tuple(tasks_in_queue)) - check_matching_pattern(tuple(task_ids_in_queue)) + assert len(tasks_in_queue) == 8 + assert len(task_ids_in_queue) == 8 # Since there's 3 workers there should at least once be 3 in progress for status in task_status: @@ -362,26 +350,27 @@ class TestQueueManagerExit: @pytest.mark.timeout(15) def test_exit_abort(self, abort_task, slow_task): """Test aborting exit.""" - qm = QueueManager( + cm = QueueWorkerComponentManager( + op_state_model=None, + logger=logger, max_queue_size=10, num_workers=2, - logger=logger, + push_change_event=None, ) - cm = BaseComponentManager(op_state_model=None, queue_manager=qm, logger=None) cm.enqueue(abort_task(), 0.1) # Wait for the command to start - while not qm.task_status: + while not cm.task_status: pass # Start aborting cm.queue_manager.abort_tasks() # Wait for the exit - while not qm.task_result: + while not cm.task_result: pass # aborting state should be cleaned up since the queue is empty and # nothing is in progress - while qm.is_aborting: + while cm.queue_manager.is_aborting: pass # When aborting this should be rejected @@ -397,54 +386,56 @@ class TestQueueManagerExit: unique_id, _ = cm.enqueue(slow_task()) while True: - tr = TaskResult.from_task_result(qm.task_result) + tr = TaskResult.from_task_result(cm.task_result) if tr.unique_id == unique_id and tr.result_code == ResultCode.ABORTED: break time.sleep(0.1) # Resume the commands - qm.resume_tasks() - assert not qm.is_aborting + cm.queue_manager.resume_tasks() + assert not cm.queue_manager.is_aborting # Wait for my slow command to finish unique_id, _ = cm.enqueue(slow_task()) while True: - tr = TaskResult.from_task_result(qm.task_result) + tr = TaskResult.from_task_result(cm.task_result) if tr.unique_id == unique_id: break @pytest.mark.timeout(20) def test_exit_stop(self, stop_task): """Test stopping exit.""" - qm = QueueManager( + cm = QueueWorkerComponentManager( + op_state_model=None, + logger=logger, max_queue_size=5, num_workers=2, - logger=logger, + push_change_event=None, ) - cm = BaseComponentManager(op_state_model=None, queue_manager=qm, logger=None) cm.enqueue(stop_task()) # Wait for the command to start - while not qm.task_status: + while not cm.task_status: pass # Stop all threads cm.queue_manager.stop_tasks() # Wait for the exit - while not qm.task_result: + while not cm.task_result: pass # Wait for all the workers to stop - while not any([worker.is_alive() for worker in qm._threads]): + while not any([worker.is_alive() for worker in cm.queue_manager._threads]): pass @pytest.mark.timeout(5) def test_delete_queue(self, slow_task, stop_task, abort_task): """Test deleting the queue.""" - qm = QueueManager( + cm = QueueWorkerComponentManager( + op_state_model=None, + logger=logger, max_queue_size=8, num_workers=2, - logger=logger, + push_change_event=None, ) - cm = BaseComponentManager(op_state_model=None, queue_manager=qm, logger=None) cm.enqueue(slow_task()) cm.enqueue(stop_task()) cm.enqueue(abort_task()) @@ -465,18 +456,23 @@ class TestComponentManager: def test_init(self): """Test that we can init the component manager.""" - qm = QueueManager(max_queue_size=0, num_workers=1, logger=logger) - cm = BaseComponentManager(op_state_model=None, queue_manager=qm, logger=logger) - assert cm.queue_manager.task_ids_in_queue == () + cm = QueueWorkerComponentManager( + op_state_model=None, + logger=logger, + max_queue_size=0, + num_workers=1, + push_change_event=None, + ) + assert cm.task_ids_in_queue == () @pytest.mark.forked class TestStress: - """Stress test the queue mananger.""" + """Stress test the queue manager.""" @pytest.mark.timeout(20) def test_stress(self, slow_task): - """Stress test the queue mananger.""" + """Stress test the queue manager.""" qm = QueueManager(max_queue_size=100, num_workers=50, logger=logger) assert len(qm._threads) == 50 for worker in qm._threads: