diff --git a/src/ska_tango_base/base/task_queue_manager.py b/src/ska_tango_base/base/task_queue_manager.py index 1c56b9657e37e55bebab0b8b7309440850aba689..16ad2278e854bcf89eaa8846286453bc76bac0e4 100644 --- a/src/ska_tango_base/base/task_queue_manager.py +++ b/src/ska_tango_base/base/task_queue_manager.py @@ -295,6 +295,7 @@ class QueueManager: logger: logging.Logger, stopping_event: Event, aborting_event: Event, + suspend_event: Event, result_callback: Callable, update_command_state_callback: Callable, update_progress_callback: Callable, @@ -312,6 +313,8 @@ class QueueManager: :type stopping_event: Event :param aborting_event: Indicates whether the queue is being aborted :type aborting_event: Event + :param suspend_event: Indicates whether to suspend task retrieval + :type suspend_event: Event :param update_command_state_callback: Callback to update command state :type update_command_state_callback: Callable """ @@ -320,6 +323,7 @@ class QueueManager: self._logger = logger self.stopping_event = stopping_event self.aborting_event = aborting_event + self.suspend_event = suspend_event self._result_callback = result_callback self._update_command_state_callback = update_command_state_callback self._update_progress_callback = update_progress_callback @@ -341,6 +345,10 @@ class QueueManager: self.current_task_id = None self.current_task_progress = "" + # Don't pull new tasks off of the queue until unsuspended. + while self.suspend_event.is_set(): + time.sleep(self._queue_fetch_timeout) + if self.aborting_event.is_set(): # Drain the Queue since self.aborting_event is set while not self._work_queue.empty(): @@ -459,6 +467,7 @@ class QueueManager: self._push_change_event = push_change_event self.stopping_event = threading.Event() self.aborting_event = threading.Event() + self.suspend_event = threading.Event() self._property_update_lock = threading.Lock() self._logger = logger if logger else logging.getLogger(__name__) @@ -467,6 +476,17 @@ class QueueManager: self._task_status: Dict[str, str] = {} # unique_id, status self._threads = [] + self._long_running_properties = [ + "longRunningCommandsInQueue", + "longRunningCommandStatus", + "longRunningCommandProgress", + "longRunningCommandIDsInQueue", + "longRunningCommandResult", + ] + self._property_change_callbacks = {} + for prop in self._long_running_properties: + self._property_change_callbacks[prop] = [] + # If there's no queue, don't start threads if not self._max_queue_size: return @@ -477,6 +497,7 @@ class QueueManager: self._logger, self.stopping_event, self.aborting_event, + self.suspend_event, self.result_callback, self.update_task_state_callback, self.update_progress_callback, @@ -496,7 +517,7 @@ class QueueManager: return self._work_queue.full() @property - def task_result(self) -> Union[Tuple[str, str, str], Tuple[()]]: + def task_result(self) -> Tuple[str, str, str]: """Return the last task result. :return: Last task result @@ -649,6 +670,8 @@ class QueueManager: :param property_name: The property value :type property_name: Any """ + for callback in self._property_change_callbacks[property_name]: + callback(property_value) if self._push_change_event: self._push_change_event(property_name, property_value) @@ -664,6 +687,25 @@ class QueueManager: """Set stopping_event on each thread so it exists out. Killing the thread.""" self.stopping_event.set() + def suspend_task_dequeue(self): + """Stop pulling new tasks off the queue to execute. + + - Tasks enqueued after this method will not be dequeued. + - Existing tasks may be dequeued during this method. + """ + self.suspend_event.set() + # Wait a little longer than fetch timeout otherwise one of the worker + # threads that is waiting in `queue.get` + # will pick up a task (enqueued immediately after this method) before + # the suspend takes effect. + time.sleep(self._queue_fetch_timeout + 0.1) + self._logger.info("Queue task execution suspended") + + def unsuspend_task_dequeue(self): + """Undo suspend_task_dequeue.""" + self.suspend_event.clear() + self._logger.info("Queue task execution unsuspended") + @property def is_aborting(self) -> bool: """Return whether we are in aborting state.""" @@ -701,6 +743,29 @@ class QueueManager: return TaskState.NOT_FOUND + def add_property_change_callback(self, attribute: str, update_callback: Callable): + """Add a callback that will be executed when the attribute changes. + + :param attribute: The attribute name + :type attribute: str + :param update_callback: The function to execute + :type update_callback: Callable + """ + assert ( + attribute in self._long_running_properties + ), f"[{attribute}] is not supported, should be one of [{self._long_running_properties}]" + self._property_change_callbacks[attribute].append(update_callback) + + def remove_property_change_callback(self, update_callback: Callable): + """Remove the callback. + + :param update_callback: The function to execute + :type update_callback: Callable + """ + for callbacks in self._property_change_callbacks.values(): + if update_callback in callbacks: + callbacks.remove(update_callback) + def __len__(self) -> int: """Approximate length of the queue. diff --git a/src/ska_tango_base/utils.py b/src/ska_tango_base/utils.py index c020787c1a8730b820a54b0a65de216107c57f49..84a7b1bcfa728642065b74fee181dd47121efeff 100644 --- a/src/ska_tango_base/utils.py +++ b/src/ska_tango_base/utils.py @@ -6,6 +6,7 @@ import inspect import json import logging import pydoc +from queue import Empty, Queue import traceback import sys import uuid @@ -29,8 +30,9 @@ from tango import ( ) from tango import DevState from contextlib import contextmanager +from ska_tango_base.commands import BaseCommand from ska_tango_base.faults import GroupDefinitionsError, SKABaseError -from ska_tango_base.base.task_queue_manager import TaskResult +from ska_tango_base.base.task_queue_manager import QueueManager, TaskResult int_types = { tango._tango.CmdArgType.DevUShort, @@ -631,7 +633,7 @@ class LongRunningDeviceInterface: - Clean up """ if ev.err: - self._logger.error("Event system DevError(s) occured: %s", str(ev.errors)) + self._logger.error("Event system DevError(s) occurred: %s", str(ev.errors)) return if ev.attr_value and ev.attr_value.name == "longrunningcommandresult": @@ -666,7 +668,7 @@ class LongRunningDeviceInterface: command_name = stored_command_group[0].command_name # Trigger the callback, send command_name and command_ids - # as paramater + # as parameter self._stored_callbacks[key](command_name, command_ids) # Remove callback as the group completed @@ -684,10 +686,10 @@ class LongRunningDeviceInterface: ): """Execute the long running command with an argument if any. - Once the commmand completes, then the `on_completion_callback` + Once the command completes, then the `on_completion_callback` will be executed with the EventData as parameter. This class keeps track of the command ID and events - used to determine when this commmand has completed. + used to determine when this command has completed. :param command_name: A long running command that exists on the target Tango device. @@ -714,3 +716,70 @@ class LongRunningDeviceInterface: False, ) ) + + +class EnqueueSuspend: + """Context manager that will enqueue a command and then suspend new tasks from being taken off the queue.""" + + def __init__( + self, + queue_manager: QueueManager, + command: BaseCommand, + args: Any = None, + retries: int = 5, + timeout: float = 0.5, + ) -> None: + """Context manager to enqueue a task and suspend new tasks. + + :param queue_manager: The queue manager + :type queue_manager: QueueManager + :param command: Command to execute + :type command: BaseCommand + :param args: Argument for the command, defaults to None + :type args: Any, optional + :param retries: Number of times to retry waiting for the command to start, defaults to 5 + :type retries: int, optional + :param timeout: Time to wait for a status change, defaults to 0.5 + :type timeout: float, optional + """ + self.queue_manager = queue_manager + self.command = command + self.args = args + self.retries = retries + self.timeout = timeout + self.unique_id = "" + self.event_queue = Queue() + + def _on_status_change_callback(self, new_value): + """Add status change to queue.""" + self.event_queue.put(new_value) + + def wait_for_id(self, unique_id) -> bool: + """Wait for the enqueued task to start.""" + while True: + try: + while self.retries > 0: + value = self.event_queue.get(timeout=self.timeout) + if unique_id in value: + return True + self.retries -= 1 + except Empty: + return False + + def __enter__(self): + """Add callback enqueue task and suspend.""" + self.queue_manager.add_property_change_callback( + "longRunningCommandStatus", self._on_status_change_callback + ) + self.unique_id, _ = self.queue_manager.enqueue_task(self.command, self.args) + if not self.wait_for_id(self.unique_id): + raise Exception("Command not started") + self.queue_manager.suspend_task_dequeue() + return self.unique_id + + def __exit__(self, _exc_type, _exc_value, _exc_traceback): + """Clear callback and unsuspend.""" + self.queue_manager.remove_property_change_callback( + self._on_status_change_callback + ) + self.queue_manager.unsuspend_task_dequeue() diff --git a/tests/long_running_tasks/test_task_queue_manager.py b/tests/long_running_tasks/test_task_queue_manager.py index bd26f59d5bb430b23ab38e0ec95754746c60a100..e4ec769ef892ac1ff3c6a8c6e56741d58c41ff02 100644 --- a/tests/long_running_tasks/test_task_queue_manager.py +++ b/tests/long_running_tasks/test_task_queue_manager.py @@ -12,6 +12,7 @@ from ska_tango_base.base.task_queue_manager import ( ) from ska_tango_base.base.reference_component_manager import QueueWorkerComponentManager from ska_tango_base.commands import BaseCommand +from ska_tango_base.utils import EnqueueSuspend logger = logging.getLogger(__name__) @@ -108,6 +109,20 @@ def simple_task(): return get_task +@pytest.fixture +def noop_task(): + """Fixture for a very simple task.""" + + def get_task(): + class NoopTask(BaseCommand): + def do(self): + return True + + return NoopTask(target=None) + + return get_task + + @pytest.fixture def abort_task(): """Fixture for a task that aborts.""" @@ -524,6 +539,37 @@ class TestComponentManager: ) assert cm.task_ids_in_queue == () + def test_enqueue_suspend_util(self, simple_task, noop_task, caplog): + """Test that EnqueueSuspend enqueues and suspends.""" + caplog.set_level(logging.INFO) + cm = QueueWorkerComponentManager( + op_state_model=None, + logger=logger, + max_queue_size=3, + num_workers=3, + push_change_event=None, + child_devices=[], + ) + second_task_id = "" + with EnqueueSuspend(cm._queue_manager, simple_task(), args=1) as unique_id: + # Make sure the command exists + assert unique_id + # Make sure pulling new tasks off the queue is paused + assert cm._queue_manager.suspend_event.is_set() + # Make sure any new work is not pulled off queue + second_task_id, _ = cm.enqueue(noop_task()) + for wait_time in [0.1, 0.2, 0.3]: + assert ( + second_task_id in cm.task_ids_in_queue + ), "Task should not be taken off the queue while suspend is in effect" + time.sleep(wait_time) + + # Make sure work is unsuspended + assert not cm._queue_manager.suspend_event.is_set() + # Make sure the enqueued task during suspension is taken off the queue + time.sleep(0.5) + assert second_task_id not in cm.task_ids_in_queue + @pytest.mark.forked class TestStress: