diff --git a/src/ska_tango_base/base/task_queue_manager.py b/src/ska_tango_base/base/task_queue_manager.py index 1c56b9657e37e55bebab0b8b7309440850aba689..dc383f98733c27fa3ab51bb0f145c1de43fc4273 100644 --- a/src/ska_tango_base/base/task_queue_manager.py +++ b/src/ska_tango_base/base/task_queue_manager.py @@ -292,7 +292,7 @@ class QueueManager: def __init__( self: QueueManager.Worker, queue: Queue, - logger: logging.Logger, + log_message: Callable, stopping_event: Event, aborting_event: Event, result_callback: Callable, @@ -317,7 +317,7 @@ class QueueManager: """ super().__init__() self._work_queue = queue - self._logger = logger + self._log_message = log_message self.stopping_event = stopping_event self.aborting_event = aborting_event self._result_callback = result_callback @@ -346,7 +346,9 @@ class QueueManager: while not self._work_queue.empty(): unique_id, _, _ = self._work_queue.get() self.current_task_id = unique_id - self._logger.warning("Aborting task ID [%s]", unique_id) + self._log_message( + f"Aborting task ID [{unique_id}]", level="WARNING" + ) result = TaskResult( ResultCode.ABORTED, f"{unique_id} Aborted", unique_id ) @@ -474,7 +476,7 @@ class QueueManager: self._threads = [ self.Worker( self._work_queue, - self._logger, + self._log_message, self.stopping_event, self.aborting_event, self.result_callback, @@ -701,6 +703,18 @@ class QueueManager: return TaskState.NOT_FOUND + def _log_message(self, message: str, level: str = "INFO"): + """Log a message. + + Called from worker threads as well. + + :param message: Message to log + :type message: str + :param level: A valid logging level, defaults to "INFO" + :type level: str, optional + """ + self._logger.log(getattr(logging, level), message) + def __len__(self) -> int: """Approximate length of the queue. diff --git a/tests/long_running_tasks/test_task_queue_manager.py b/tests/long_running_tasks/test_task_queue_manager.py index bd26f59d5bb430b23ab38e0ec95754746c60a100..cce3440f1afeae24336207a9ab28c8a41056c047 100644 --- a/tests/long_running_tasks/test_task_queue_manager.py +++ b/tests/long_running_tasks/test_task_queue_manager.py @@ -12,6 +12,7 @@ from ska_tango_base.base.task_queue_manager import ( ) from ska_tango_base.base.reference_component_manager import QueueWorkerComponentManager from ska_tango_base.commands import BaseCommand +from tests.test_utils import LRCAttributesStore logger = logging.getLogger(__name__) @@ -382,40 +383,29 @@ class TestQueueManagerExit: @pytest.mark.forked @pytest.mark.timeout(5) - def test_exit_abort(self, abort_task, slow_task): + def test_exit_abort(self, abort_task, slow_task, caplog): """Test aborting exit.""" - results = [] - - def catch_updates(name, result): - if name == "longRunningCommandResult": - tr = TaskResult.from_task_result(result) - results.append( - ( - tr.unique_id, - tr.result_code, - ) - ) + attribute_store = LRCAttributesStore() + caplog.set_level(logging.INFO) cm = QueueWorkerComponentManager( op_state_model=None, logger=logger, max_queue_size=10, num_workers=2, - push_change_event=catch_updates, + push_change_event=attribute_store.store_push_event, child_devices=[], ) cm.enqueue(abort_task(), 0.1) # Wait for the command to start - while not cm.task_status: - time.sleep(0.1) + attribute_store.get_attribute_value("longRunningCommandStatus") # Start aborting cm._queue_manager.abort_tasks() # Wait for the exit - while not cm.task_result: - time.sleep(0.1) + attribute_store.get_attribute_value("longRunningCommandResult") # aborting state should be cleaned up since the queue is empty and # nothing is in progress while cm._queue_manager.is_aborting: @@ -433,14 +423,16 @@ class TestQueueManagerExit: assert cm._queue_manager.is_aborting # Load up some tasks that should be aborted - cm.enqueue(slow_task()) + aborted_task_id, _ = cm.enqueue(slow_task()) cm.enqueue(slow_task()) unique_id, _ = cm.enqueue(slow_task()) while True: - if (unique_id, ResultCode.ABORTED) in results: + result_id, result_code, _ = attribute_store.get_attribute_value( + "longRunningCommandResult" + ) + if (unique_id, ResultCode.ABORTED) == (result_id, int(result_code)): break - time.sleep(0.1) # Resume the commands cm._queue_manager.resume_tasks() @@ -450,9 +442,14 @@ class TestQueueManagerExit: unique_id, _ = cm.enqueue(slow_task()) while True: - if (unique_id, ResultCode.OK) in results: + result_id, result_code, _ = attribute_store.get_attribute_value( + "longRunningCommandResult" + ) + if (unique_id, ResultCode.OK) == (result_id, int(result_code)): break - time.sleep(0.1) + + log_messages = [rec.msg for rec in caplog.records] + assert f"Aborting task ID [{aborted_task_id}]" in log_messages @pytest.mark.forked @pytest.mark.timeout(5) diff --git a/tests/test_utils.py b/tests/test_utils.py index a27e644172833edef179dd820e8cf541fa61ba3d..f75b7cfbabdee5bbe4a957de3191645f4b3eb04a 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -1,6 +1,8 @@ """Tests for skabase.utils.""" from contextlib import nullcontext +from queue import Queue import json +from typing import Any import pytest from ska_tango_base.utils import ( @@ -262,3 +264,42 @@ def test_for_testing_only_decorator(): with pytest.warns(None) as warning_record: assert bah() == "bah" assert len(warning_record) == 0 # no warning was raised because we are testing + + +class LRCAttributesStore: + """Utility class to keep track of long running command attribute changes.""" + + def __init__(self) -> None: + """Create the queues.""" + self.queues = {} + for attribute in [ + "longRunningCommandsInQueue", + "longRunningCommandStatus", + "longRunningCommandProgress", + "longRunningCommandIDsInQueue", + "longRunningCommandResult", + ]: + self.queues[attribute] = Queue() + + def store_push_event(self, attribute_name: str, value: Any): + """Store attribute changes as they change. + + :param attribute_name: a valid LCR attribute + :type attribute_name: str + :param value: The value of the attribute + :type value: Any + """ + assert attribute_name in self.queues + self.queues[attribute_name].put_nowait(value) + + def get_attribute_value(self, attribute_name: str, fetch_timeout: float = 2.0): + """Read a value from the queue. + + :param attribute_name: a valid LCR attribute + :type attribute_name: str + :param fetch_timeout: How long to wait for a event, defaults to 2.0 + :type fetch_timeout: float, optional + :return: An attribute value fromthe queue + :rtype: Any + """ + return self.queues[attribute_name].get(timeout=fetch_timeout)