Skip to content
Snippets Groups Projects
Unverified Commit b2bc42f7 authored by SKAJohanVenter's avatar SKAJohanVenter
Browse files

SAR-276 Now enqueing Command rather than QueueTask

parent 83d947db
Branches
No related tags found
No related merge requests found
......@@ -36,7 +36,6 @@ import debugpy
import ska_ser_logging
from ska_tango_base import release
from ska_tango_base.base import AdminModeModel, OpStateModel, BaseComponentManager
from ska_tango_base.base.task_queue_manager import QueueManager, QueueTask
from ska_tango_base.commands import (
BaseCommand,
CompletionCommand,
......@@ -836,8 +835,7 @@ class SKABaseDevice(Device):
def create_component_manager(self):
"""Create and return a component manager for this device."""
queue_manager = QueueManager()
return BaseComponentManager(self.op_state_model, queue_manager)
return BaseComponentManager(self.op_state_model)
def register_command_object(self, command_name, command_object):
"""
......@@ -1230,17 +1228,11 @@ class SKABaseDevice(Device):
information purpose only.
:rtype: (ResultCode, str)
"""
class ResetTask(QueueTask):
def do(self):
self.args[0].reset()
self.target.reset()
message = "Reset command completed OK"
self.logger.info(message)
return (ResultCode.OK, message)
unique_id = self.target.enqueue(ResetTask(self.target, logger=self.logger))
return ResultCode.OK, unique_id
def is_Reset_allowed(self):
"""
Whether the ``Reset()`` command is allowed to be run in the current state.
......@@ -1270,8 +1262,9 @@ class SKABaseDevice(Device):
:rtype: (ResultCode, str)
"""
command = self.get_command_object("Reset")
(return_code, message) = command()
return [[return_code], [message]]
unique_id, return_code = self.component_manager.enqueue(command)
return [[return_code], [unique_id]]
class StandbyCommand(StateModelCommand, ResponseCommand):
"""A class for the SKABaseDevice's Standby() command."""
......@@ -1303,19 +1296,11 @@ class SKABaseDevice(Device):
information purpose only.
:rtype: (ResultCode, str)
"""
class StandByTask(QueueTask):
def do(self):
self.args[0].standby()
self.target.standby()
message = "Standby command completed OK"
self.logger.info(message)
return (ResultCode.OK, message)
unique_id = self.target.enqueue(
StandByTask(self.target, logger=self.logger)
)
return ResultCode.OK, unique_id
def is_Standby_allowed(self):
"""
Check if command Standby is allowed in the current device state.
......@@ -1346,8 +1331,9 @@ class SKABaseDevice(Device):
:rtype: (ResultCode, str)
"""
command = self.get_command_object("Standby")
(return_code, message) = command()
return [[return_code], [message]]
unique_id, return_code = self.component_manager.enqueue(command)
return [[return_code], [unique_id]]
class OffCommand(StateModelCommand, ResponseCommand):
"""A class for the SKABaseDevice's Off() command."""
......@@ -1379,17 +1365,11 @@ class SKABaseDevice(Device):
information purpose only.
:rtype: (ResultCode, str)
"""
class OffTask(QueueTask):
def do(self):
self.args[0].off()
self.target.off()
message = "Off command completed OK"
self.logger.info(message)
return (ResultCode.OK, message)
unique_id = self.target.enqueue(OffTask(self.target, logger=self.logger))
return ResultCode.OK, unique_id
def is_Off_allowed(self):
"""
Check if command `Off` is allowed in the current device state.
......@@ -1420,8 +1400,9 @@ class SKABaseDevice(Device):
:rtype: (ResultCode, str)
"""
command = self.get_command_object("Off")
(return_code, message) = command()
return [[return_code], [message]]
unique_id, return_code = self.component_manager.enqueue(command)
return [[return_code], [unique_id]]
class OnCommand(StateModelCommand, ResponseCommand):
"""A class for the SKABaseDevice's On() command."""
......@@ -1453,17 +1434,11 @@ class SKABaseDevice(Device):
information purpose only.
:rtype: (ResultCode, str)
"""
class OnTask(QueueTask):
def do(self):
self.args[0].on()
self.target.on()
message = "On command completed OK"
self.logger.info(message)
return (ResultCode.OK, message)
unique_id = self.target.enqueue(OnTask(self.target, logger=self.logger))
return ResultCode.OK, unique_id
def is_On_allowed(self):
"""
Check if command `On` is allowed in the current device state.
......@@ -1495,8 +1470,9 @@ class SKABaseDevice(Device):
:rtype: (ResultCode, str)
"""
command = self.get_command_object("On")
(return_code, message) = command()
return [[return_code], [message]]
unique_id, return_code = self.component_manager.enqueue(command)
return [[return_code], [unique_id]]
class AbortCommandsCommand(ResponseCommand):
"""The command class for the AbortCommand command."""
......
......@@ -23,10 +23,11 @@ The basic model is:
the component to change behaviour and/or state; and it *monitors* its
component by keeping track of its state.
"""
from typing import Optional
from typing import Any, Optional, Tuple
from ska_tango_base.commands import BaseCommand, ResultCode
from ska_tango_base.control_model import PowerMode
from ska_tango_base.base.task_queue_manager import QueueManager, QueueTask
from ska_tango_base.base.task_queue_manager import QueueManager
class BaseComponentManager:
......@@ -170,13 +171,14 @@ class BaseComponentManager:
def enqueue(
self,
task: QueueTask,
) -> str:
task: BaseCommand,
argin: Optional[Any] = None,
) -> Tuple[str, ResultCode]:
"""Put `task` on the queue. The unique ID for it is returned.
:param task: The task to execute in the thread
:type task: QueueTask
:type task: BaseCommand
:return: The unique ID of the queued command
:rtype: str
"""
return self.queue_manager.enqueue_task(task)
return self.queue_manager.enqueue_task(task, argin=argin)
......@@ -41,70 +41,6 @@ be made available as a Tango device attribute named `command_result`. It will be
tr.to_task_result()
('UniqueID', '0', 'The task result')
*********
QueueTask
*********
This class should be subclassed and the `do` method implemented with the required functionality.
The `do` method will be executed by the background worker in a thread.
`get_task_name` can be overridden if you want to change the name of the task as it would appear in
the `tasks_in_queue` property.
Simple example:
.. code-block:: py
class SimpleTask(QueueTask):
def do(self):
num_one = self.args[0]
num_two = self.kwargs.get("num_two")
return num_one + num_two
return SimpleTask(2, num_two=3)
3 items are added dynamically by the worker thread and is available for use in the class instance.
* **aborting_event**: can be check periodically to determine whether
the queue tasks have been aborted to gracefully complete the task in progress.
The thread will stay active and once `aborting_event` has been unset,
new tasks will be fetched from the queue for execution.
.. code-block:: py
class AbortTask(QueueTask):
def do(self):
sleep_time = self.args[0]
while not self.aborting_event.is_set():
time.sleep(sleep_time)
return AbortTask(0.2)
* **stopping_event**: can be check periodically to determine whether
the queue tasks have been stopped. In this case the thread will complete.
.. code-block:: py
class StopTask(QueueTask):
def do(self):
assert not self.stopping_event.is_set()
while not self.stopping_event.is_set():
pass
return StopTask()
* **update_progress**: a callback that can be called wth the current progress
of the task in progress
.. code-block:: py
class ProgressTask(QueueTask):
def do(self):
for i in range(100):
self.update_progress(str(i))
time.sleep(0.5)
return ProgressTask()
************
QueueManager
......@@ -185,7 +121,7 @@ from dataclasses import dataclass
import tango
from ska_tango_base.commands import ResultCode
from ska_tango_base.commands import BaseCommand, ResultCode
class TaskState(enum.IntEnum):
......@@ -306,63 +242,6 @@ class TaskResult:
return TaskUniqueId.from_unique_id(self.unique_id)
class QueueTask:
"""A task that can be put on the queue."""
def __init__(self: QueueTask, *args, logger=Optional[None], **kwargs) -> None:
"""Create the task. args and kwargs are stored and should be referenced in the `do` method."""
self.logger = logger if logger else logging.getLogger(__name__)
self.args = args
self.kwargs = kwargs
self._update_progress_callback = None
@property
def aborting_event(self) -> threading.Event:
"""Worker adds aborting_event threading event.
Indicates whether task execution have been aborted.
:return: The aborting_event event.
:rtype: threading.Event
"""
return self.kwargs.get("aborting_event")
@property
def stopping_event(self) -> threading.Event:
"""Worker adds stopping_event threading event.
Indicates whether task execution have been stopped.
:return: The stopping_event.
:rtype: threading.Event
"""
return self.kwargs.get("stopping_event")
def update_progress(self, progress: str):
"""Call the callback to update the progress.
:param progress: String that to indicate progress of task
:type progress: str
"""
self._update_progress_callback = self.kwargs.get(
"update_task_progress_callback"
)
if self._update_progress_callback:
self._update_progress_callback(progress)
def get_task_name(self) -> str:
"""Return a custom task name.
:return: The name of the task
:rtype: str
"""
return self.__class__.__name__
def do(self: QueueTask) -> Any:
"""Implement this method with your functionality."""
raise NotImplementedError
class QueueManager:
"""Manages the worker threads. Updates the properties as the tasks are completed."""
......@@ -424,7 +303,7 @@ class QueueManager:
if self.aborting_event.is_set():
# Drain the Queue since self.aborting_event is set
while not self._work_queue.empty():
unique_id, _ = self._work_queue.get()
unique_id, _, _ = self._work_queue.get()
self.current_task_id = unique_id
self._logger.warning("Aborting task ID [%s]", unique_id)
result = TaskResult(
......@@ -435,16 +314,14 @@ class QueueManager:
time.sleep(self._queue_fetch_timeout)
continue # Don't try and get work off the queue below, continue next loop
try:
(unique_id, task) = self._work_queue.get(
(unique_id, task, argin) = self._work_queue.get(
block=True, timeout=self._queue_fetch_timeout
)
self._update_command_state_callback(unique_id, "IN_PROGRESS")
self.current_task_id = unique_id
task.kwargs[
"update_task_progress_callback"
] = self._update_task_progress
result = self.execute_task(task, unique_id)
setattr(task, "update_progress", self._update_task_progress)
result = self.execute_task(task, argin, unique_id)
self._result_callback(result)
self._work_queue.task_done()
except Empty:
......@@ -461,17 +338,27 @@ class QueueManager:
self._update_progress_callback()
@classmethod
def execute_task(cls, task: QueueTask, unique_id: str) -> TaskResult:
def execute_task(
cls, task: BaseCommand, argin: Any, unique_id: str
) -> TaskResult:
"""Execute a task, return results in a standardised format.
:param task: Task to execute
:type task: QueueTask
:type task: BaseCommand
:param unique_id: The task unique ID
:type unique_id: str
:return: The result of the task
:rtype: TaskResult
"""
try:
if hasattr(task, "is_allowed"):
if not task.is_allowed():
return TaskResult(
ResultCode.NOT_ALLOWED, "Command not allowed", unique_id
)
if argin:
result = task.do(argin)
else:
result = task.do()
# If the response is (ResultCode, Any)
if (
......@@ -620,39 +507,45 @@ class QueueManager:
progress.append(worker.current_task_progress)
return tuple(progress)
def enqueue_task(self, task: QueueTask) -> str:
def enqueue_task(
self, task: BaseCommand, argin: Optional[Any] = None
) -> Tuple[str, ResultCode]:
"""Add the task to be done onto the queue.
:param task: The task to execute in a thread
:type task: QueueTask
:type task: BaseCommand
:return: The unique ID of the command
:rtype: string
"""
unique_id = self.generate_unique_id(task.get_task_name())
unique_id = self.generate_unique_id(task.__class__.__name__)
# Inject the events into the task
task.kwargs["aborting_event"] = self.aborting_event
task.kwargs["stopping_event"] = self.stopping_event
setattr(task, "aborting_event", self.aborting_event)
setattr(task, "stopping_event", self.stopping_event)
# If there is no queue, just execute the command and return
if self._max_queue_size == 0:
self.update_task_state_callback(unique_id, "IN_PROGRESS")
result = self.Worker.execute_task(task, unique_id)
# This task blocks, so no need to update progress
setattr(task, "update_progress", lambda x: None)
result = self.Worker.execute_task(task, argin, unique_id)
self.result_callback(result)
return unique_id
return unique_id, result.result_code
if self.queue_full:
self.result_callback(
TaskResult(ResultCode.REJECTED, "Queue is full", unique_id)
)
return unique_id
return unique_id, ResultCode.REJECTED
self._work_queue.put([unique_id, task])
self._work_queue.put([unique_id, task, argin])
with self._property_update_lock:
self._tasks_in_queue[unique_id] = task.get_task_name()
self._tasks_in_queue[unique_id] = task.__class__.__name__
self._on_property_change("longRunningCommandsInQueue", self.tasks_in_queue)
self._on_property_change("longRunningCommandIDsInQueue", self.task_ids_in_queue)
return unique_id
return unique_id, ResultCode.QUEUED
def result_callback(self, task_result: TaskResult):
"""Run when the task, taken from the queue, have completed to update the appropriate attributes.
......
......@@ -15,7 +15,7 @@ from tango import DebugIt
from ska_tango_base.base.component_manager import BaseComponentManager
from ska_tango_base.base.base_device import SKABaseDevice
from ska_tango_base.base.task_queue_manager import QueueManager, ResultCode, QueueTask
from ska_tango_base.base.task_queue_manager import QueueManager, ResultCode
from ska_tango_base.commands import ResponseCommand
......@@ -57,16 +57,9 @@ class LongRunningCommandBaseTestDevice(SKABaseDevice):
def do(self, argin):
"""Do command."""
class SimpleTask(QueueTask):
def do(self):
num_one = self.args[0]
return num_one + 2
self.logger.info("In ShortCommand")
unique_id = self.target.enqueue(SimpleTask(2))
return ResultCode.OK, unique_id
result = argin + 2
return ResultCode.OK, result
@command(
dtype_in=int,
......@@ -76,7 +69,7 @@ class LongRunningCommandBaseTestDevice(SKABaseDevice):
def Short(self, argin):
"""Short command."""
handler = self.get_command_object("Short")
(return_code, message) = handler(argin)
(return_code, message) = self.component_manager.enqueue(handler, argin=argin)
return f"{return_code}", f"{message}"
class NonAbortingLongRunningCommand(ResponseCommand):
......@@ -90,25 +83,15 @@ class LongRunningCommandBaseTestDevice(SKABaseDevice):
See the implementation of AnotherLongRunningCommand.
"""
class NonAbortingTask(QueueTask):
"""NonAbortingTask."""
def do(self):
"""NonAborting."""
retries = 45
while retries > 0:
retries -= 1
time.sleep(self.args[0]) # This command takes long
time.sleep(argin) # This command takes long
self.logger.info(
"In NonAbortingTask repeating %s",
retries,
)
self.logger.info("In NonAbortingTask")
unique_id = self.target.enqueue(NonAbortingTask(argin, logger=self.logger))
return ResultCode.OK, unique_id
return (ResultCode.OK, "Done")
@command(
dtype_in=float,
......@@ -118,7 +101,7 @@ class LongRunningCommandBaseTestDevice(SKABaseDevice):
def NonAbortingLongRunning(self, argin):
"""Non AbortingLongRunning command."""
handler = self.get_command_object("NonAbortingLongRunning")
(return_code, message) = handler(argin)
(return_code, message) = self.component_manager.enqueue(handler, argin)
return f"{return_code}", f"{message}"
class AbortingLongRunningCommand(ResponseCommand):
......@@ -126,12 +109,6 @@ class LongRunningCommandBaseTestDevice(SKABaseDevice):
def do(self, argin):
"""Abort."""
class AbortingTask(QueueTask):
"""Abort."""
def do(self):
"""Abort."""
retries = 45
while (not self.aborting_event.is_set()) and retries > 0:
retries -= 1
......@@ -149,11 +126,6 @@ class LongRunningCommandBaseTestDevice(SKABaseDevice):
f"NonAbortingTask Aborted {argin}",
)
self.logger.info("In AbortingLongRunningCommand")
unique_id = self.target.enqueue(AbortingTask(argin, logger=self.logger))
return ResultCode.OK, unique_id
@command(
dtype_in=float,
dtype_out="DevVarStringArray",
......@@ -162,26 +134,16 @@ class LongRunningCommandBaseTestDevice(SKABaseDevice):
def AbortingLongRunning(self, argin):
"""AbortingLongRunning."""
handler = self.get_command_object("AbortingLongRunning")
(return_code, message) = handler(argin)
(return_code, message) = self.component_manager.enqueue(handler, argin)
return f"{return_code}", f"{message}"
class LongRunningExceptionCommand(ResponseCommand):
"""The command class for the LongRunningException command."""
def do(self):
"""Throw an exception."""
class ExcTask(QueueTask):
"""Throw an exception."""
def do(self):
"""Throw an exception."""
raise Exception("An error occurred")
unique_id = self.target.enqueue(ExcTask())
return ResultCode.OK, unique_id
@command(
dtype_in=None,
dtype_out="DevVarStringArray",
......@@ -190,7 +152,7 @@ class LongRunningCommandBaseTestDevice(SKABaseDevice):
def LongRunningException(self):
"""Command that queues a task that raises an exception."""
handler = self.get_command_object("LongRunningException")
(return_code, message) = handler()
(return_code, message) = self.component_manager.enqueue(handler)
return f"{return_code}", f"{message}"
class TestProgressCommand(ResponseCommand):
......@@ -198,18 +160,10 @@ class LongRunningCommandBaseTestDevice(SKABaseDevice):
def do(self, argin):
"""Do the task."""
class ProgressTask(QueueTask):
"""A task that updates its progress."""
def do(self):
"""Update progress."""
for progress in [1, 25, 50, 74, 100]:
self.update_progress(f"{progress}")
time.sleep(self.args[0])
unique_id = self.target.enqueue(ProgressTask(argin))
return ResultCode.OK, unique_id
time.sleep(argin)
return ResultCode.OK, "OK"
@command(
dtype_in=float,
......@@ -219,7 +173,7 @@ class LongRunningCommandBaseTestDevice(SKABaseDevice):
def TestProgress(self, argin):
"""Command to test the progress indicator."""
handler = self.get_command_object("TestProgress")
(return_code, message) = handler(argin)
(return_code, message) = self.component_manager.enqueue(handler, argin)
return f"{return_code}", f"{message}"
......
......@@ -15,7 +15,6 @@ from reference_base_device import (
)
from ska_tango_base.base.task_queue_manager import TaskResult
from ska_tango_base.commands import ResultCode
from ska_tango_base.control_model import AdminMode
class TestCommands:
......@@ -30,11 +29,7 @@ class TestCommands:
def test_short_command(self):
"""Test a simple command."""
for class_name in [BlockingBaseDevice, AsyncBaseDevice]:
with DeviceTestContext(
class_name,
process=True,
memorized={"adminMode": str(AdminMode.ONLINE.value)},
) as proxy:
with DeviceTestContext(class_name, process=True) as proxy:
proxy.Short(1)
# Wait for a result, if the task does not abort, we'll time out here
while not proxy.longRunningCommandResult:
......@@ -42,7 +37,7 @@ class TestCommands:
result = TaskResult.from_task_result(proxy.longRunningCommandResult)
assert result.result_code == ResultCode.OK
assert result.get_task_unique_id().id_task_name == "SimpleTask"
assert result.get_task_unique_id().id_task_name == "ShortCommand"
@pytest.mark.forked
@pytest.mark.timeout(5)
......@@ -56,7 +51,10 @@ class TestCommands:
pass
result = TaskResult.from_task_result(proxy.longRunningCommandResult)
assert result.result_code == ResultCode.OK
assert result.get_task_unique_id().id_task_name == "NonAbortingTask"
assert (
result.get_task_unique_id().id_task_name
== "NonAbortingLongRunningCommand"
)
@pytest.mark.forked
@pytest.mark.timeout(5)
......@@ -67,7 +65,7 @@ class TestCommands:
AbortCommands after that makes no sense.
"""
with DeviceTestContext(AsyncBaseDevice, process=True) as proxy:
_, unique_id = proxy.AbortingLongRunning(0.5)
unique_id, _ = proxy.AbortingLongRunning(0.5)
# Wait for the task to be in progress
while not proxy.longRunningCommandStatus:
pass
......@@ -87,7 +85,7 @@ class TestCommands:
"""Test the task that throws an error."""
for class_name in [BlockingBaseDevice, AsyncBaseDevice]:
with DeviceTestContext(class_name, process=True) as proxy:
_, unique_id = proxy.LongRunningException()
unique_id, _ = proxy.LongRunningException()
while not proxy.longRunningCommandResult:
pass
result = TaskResult.from_task_result(proxy.longRunningCommandResult)
......@@ -132,11 +130,11 @@ def test_callbacks():
# longRunningCommandsInQueue
attribute_values = [arg[1] for arg in called_args]
assert len(attribute_values[0]) == 1
assert attribute_values[0] == ("ProgressTask",)
assert attribute_values[0] == ("TestProgressCommand",)
# longRunningCommandIDsInQueue
assert len(attribute_values[1]) == 1
assert attribute_values[1][0].endswith("ProgressTask")
assert attribute_values[1][0].endswith("TestProgressCommand")
# longRunningCommandsInQueue
assert not attribute_values[2]
......@@ -146,19 +144,19 @@ def test_callbacks():
# longRunningCommandStatus
assert len(attribute_values[4]) == 2
assert attribute_values[4][0].endswith("ProgressTask")
assert attribute_values[4][0].endswith("TestProgressCommand")
assert attribute_values[4][1] == "IN_PROGRESS"
# longRunningCommandProgress
for (index, progress) in zip(range(5, 9), ["1", "25", "50", "74", "100"]):
assert len(attribute_values[index]) == 2
assert attribute_values[index][0].endswith("ProgressTask")
assert attribute_values[index][0].endswith("TestProgressCommand")
assert attribute_values[index][1] == progress
# longRunningCommandResult
assert len(attribute_values[10]) == 3
tr = TaskResult.from_task_result(attribute_values[10])
assert tr.get_task_unique_id().id_task_name == "ProgressTask"
assert tr.get_task_unique_id().id_task_name == "TestProgressCommand"
tr.result_code == ResultCode.OK
tr.task_result == "None"
......@@ -168,7 +166,7 @@ def test_callbacks():
def test_events():
"""Testing the events.
NOTE: Adding more than 2 event subscriptions leads to inconsistent results.
NOTE: Adding more than 1 event subscriptions leads to inconsistent results.
Sometimes misses events.
Full callback tests (where the push events are triggered) are covered
......@@ -176,30 +174,23 @@ def test_events():
"""
with DeviceTestContext(AsyncBaseDevice, process=True) as proxy:
progress_events = EventCallback(fd=StringIO())
ids_in_queue_events = EventCallback(fd=StringIO())
progress_id = proxy.subscribe_event(
proxy.subscribe_event(
"longRunningCommandProgress",
EventType.CHANGE_EVENT,
progress_events,
wait=True,
)
ids_id = proxy.subscribe_event(
"longRunningCommandIDsInQueue",
EventType.CHANGE_EVENT,
ids_in_queue_events,
wait=True,
)
proxy.TestProgress(0.5)
proxy.TestProgress(0.2)
# Wait for task to finish
while not proxy.longRunningCommandResult:
time.sleep(0.1)
# Wait for events
# Wait for progress events
while not progress_events.get_events():
time.sleep(0.1)
time.sleep(0.5)
progress_event_values = [
event.attr_value.value
......@@ -208,13 +199,3 @@ def test_events():
]
for index, progress in enumerate(["1", "25", "50", "74", "100"]):
assert progress_event_values[index][1] == progress
ids_in_queue_events_values = [
event.attr_value.value
for event in ids_in_queue_events.get_events()
if event.attr_value and event.attr_value.value
]
assert len(ids_in_queue_events_values) == 1
assert ids_in_queue_events_values[0][0].endswith("ProgressTask")
proxy.unsubscribe_event(progress_id)
proxy.unsubscribe_event(ids_id)
......@@ -8,10 +8,10 @@ from ska_tango_base.commands import ResultCode
from ska_tango_base.base.task_queue_manager import (
QueueManager,
TaskResult,
QueueTask,
TaskState,
)
from ska_tango_base.base.component_manager import BaseComponentManager
from ska_tango_base.commands import BaseCommand
logger = logging.getLogger(__name__)
......@@ -33,13 +33,13 @@ def progress_task():
"""Fixture for a test that throws an exception."""
def get_task():
class ProgressTask(QueueTask):
class ProgressTask(BaseCommand):
def do(self):
for i in range(100):
self.update_progress(str(i))
time.sleep(0.5)
return ProgressTask()
return ProgressTask(target=None)
return get_task
......@@ -49,11 +49,11 @@ def exc_task():
"""Fixture for a test that throws an exception."""
def get_task():
class ExcTask(QueueTask):
class ExcTask(BaseCommand):
def do(self):
raise Exception("An error occurred")
return ExcTask()
return ExcTask(target=None)
return get_task
......@@ -63,11 +63,11 @@ def slow_task():
"""Fixture for a test that takes long."""
def get_task():
class SlowTask(QueueTask):
class SlowTask(BaseCommand):
def do(self):
time.sleep(2)
return SlowTask()
return SlowTask(target=None)
return get_task
......@@ -77,13 +77,11 @@ def simple_task():
"""Fixture for a very simple task."""
def get_task():
class SimpleTask(QueueTask):
def do(self):
num_one = self.args[0]
num_two = self.kwargs.get("num_two")
return num_one + num_two
class SimpleTask(BaseCommand):
def do(self, argin):
return argin + 2
return SimpleTask(2, num_two=3)
return SimpleTask(2)
return get_task
......@@ -93,13 +91,13 @@ def abort_task():
"""Fixture for a task that aborts."""
def get_task():
class AbortTask(QueueTask):
def do(self):
sleep_time = self.args[0]
class AbortTask(BaseCommand):
def do(self, argin):
sleep_time = argin
while not self.aborting_event.is_set():
time.sleep(sleep_time)
return AbortTask(0.2)
return AbortTask(target=None)
return get_task
......@@ -109,30 +107,18 @@ def stop_task():
"""Fixture for a task that stops."""
def get_task():
class StopTask(QueueTask):
class StopTask(BaseCommand):
def do(self):
assert not self.stopping_event.is_set()
while not self.stopping_event.is_set():
pass
return StopTask()
return StopTask(target=None)
return get_task
class TestQueueTask:
"""Test QueueTask."""
def test_simple(self, simple_task):
"""Test simple task."""
assert simple_task().do() == 5
def test_exception(self, exc_task):
"""Test that exception is thrown."""
with pytest.raises(Exception):
exc_task().do()
@pytest.mark.forked
class TestQueueManager:
"""General QueueManager checks."""
......@@ -148,6 +134,7 @@ class TestQueueManager:
worker.stopping_event.set()
@pytest.mark.forked
class TestQueueManagerTasks:
"""QueueManager checks for tasks executed."""
......@@ -155,18 +142,19 @@ class TestQueueManagerTasks:
def test_task_ids(self, simple_task):
"""Check ids."""
qm = QueueManager(max_queue_size=5, num_workers=2, logger=logger)
unique_id_one = qm.enqueue_task(simple_task())
unique_id_two = qm.enqueue_task(simple_task())
unique_id_one, result_code = qm.enqueue_task(simple_task(), 2)
unique_id_two, _ = qm.enqueue_task(simple_task(), 2)
assert unique_id_one.endswith("SimpleTask")
assert unique_id_one != unique_id_two
assert result_code == ResultCode.QUEUED
@pytest.mark.timeout(5)
def test_task_is_executed(self, simple_task):
"""Check that tasks are executed."""
with patch.object(QueueManager, "result_callback") as my_cb:
qm = QueueManager(max_queue_size=5, num_workers=2, logger=logger)
unique_id_one = qm.enqueue_task(simple_task())
unique_id_two = qm.enqueue_task(simple_task())
unique_id_one, _ = qm.enqueue_task(simple_task(), 3)
unique_id_two, _ = qm.enqueue_task(simple_task(), 3)
while my_cb.call_count != 2:
time.sleep(0.5)
......@@ -191,7 +179,7 @@ class TestQueueManagerTasks:
add_task_one = simple_task()
exc_task = exc_task()
qm.enqueue_task(add_task_one)
qm.enqueue_task(add_task_one, 3)
while not qm.task_result:
time.sleep(0.5)
task_result = TaskResult.from_task_result(qm.task_result)
......@@ -252,7 +240,7 @@ class TestQueueManagerTasks:
# No Queue
qm = QueueManager(max_queue_size=0, num_workers=1, logger=logger)
assert len(qm._threads) == 0
res = qm.enqueue_task(simple_task())
res, _ = qm.enqueue_task(simple_task(), 3)
assert res.endswith(expected_name)
assert qm.task_result[0].endswith(expected_name)
assert int(qm.task_result[1]) == expected_result_code
......@@ -260,7 +248,7 @@ class TestQueueManagerTasks:
# Queue
qm = QueueManager(max_queue_size=2, num_workers=1, logger=logger)
res = qm.enqueue_task(simple_task())
res, _ = qm.enqueue_task(simple_task(), 3)
assert res.endswith(expected_name)
# Wait for the task to be picked up
......@@ -284,21 +272,20 @@ class TestQueueManagerTasks:
)
unique_ids = []
for _ in range(4):
unique_id = qm.enqueue_task(slow_task())
unique_id, _ = qm.enqueue_task(slow_task())
unique_ids.append(unique_id)
# Wait for a item on the queue
while not qm.task_ids_in_queue:
pass
while not qm.task_result:
# Wait for the queue to empty
while not qm.task_status:
pass
# Wait for last task to finish
while (
unique_ids[-1] != TaskResult.from_task_result(qm.task_result).unique_id
):
pass
# Wait for all the callbacks to fire
while len(call_back_func.call_args_list) < 24:
time.sleep(0.1)
all_passed_params = [a_call[0] for a_call in call_back_func.call_args_list]
tasks_in_queue = [
......@@ -339,7 +326,7 @@ class TestQueueManagerTasks:
def test_task_get_state_completed(self, simple_task):
"""Test the QueueTask get state is completed."""
qm = QueueManager(max_queue_size=8, num_workers=2, logger=logger)
unique_id_one = qm.enqueue_task(simple_task())
unique_id_one, _ = qm.enqueue_task(simple_task(), 3)
while not qm.task_result:
pass
assert qm.get_task_state(unique_id=unique_id_one) == TaskState.COMPLETED
......@@ -347,16 +334,16 @@ class TestQueueManagerTasks:
def test_task_get_state_in_queued(self, slow_task):
"""Test the QueueTask get state is queued."""
qm = QueueManager(max_queue_size=8, num_workers=1, logger=logger)
qm.enqueue_task(slow_task())
qm.enqueue_task(slow_task())
unique_id_last = qm.enqueue_task(slow_task())
qm.enqueue_task(slow_task(), 2)
qm.enqueue_task(slow_task(), 2)
unique_id_last, _ = qm.enqueue_task(slow_task())
assert qm.get_task_state(unique_id=unique_id_last) == TaskState.QUEUED
def test_task_get_state_in_progress(self, progress_task):
"""Test the QueueTask get state is in progress."""
qm = QueueManager(max_queue_size=8, num_workers=2, logger=logger)
unique_id_one = qm.enqueue_task(progress_task())
unique_id_one, _ = qm.enqueue_task(progress_task())
while not qm.task_progress:
pass
......@@ -368,6 +355,7 @@ class TestQueueManagerTasks:
assert qm.get_task_state(unique_id="non_existing_id") == TaskState.NOT_FOUND
@pytest.mark.forked
class TestQueueManagerExit:
"""Test the stopping and aborting."""
......@@ -381,7 +369,7 @@ class TestQueueManagerExit:
)
cm = BaseComponentManager(op_state_model=None, queue_manager=qm, logger=None)
cm.enqueue(abort_task())
cm.enqueue(abort_task(), 0.1)
# Wait for the command to start
while not qm.task_status:
......@@ -406,7 +394,7 @@ class TestQueueManagerExit:
# Load up some tasks that should be aborted
cm.enqueue(slow_task())
cm.enqueue(slow_task())
unique_id = cm.enqueue(slow_task())
unique_id, _ = cm.enqueue(slow_task())
while True:
tr = TaskResult.from_task_result(qm.task_result)
......@@ -419,7 +407,7 @@ class TestQueueManagerExit:
assert not qm.is_aborting
# Wait for my slow command to finish
unique_id = cm.enqueue(slow_task())
unique_id, _ = cm.enqueue(slow_task())
while True:
tr = TaskResult.from_task_result(qm.task_result)
if tr.unique_id == unique_id:
......@@ -471,6 +459,7 @@ class TestQueueManagerExit:
del cm
@pytest.mark.forked
class TestComponentManager:
"""Tests for the component manager."""
......@@ -481,6 +470,7 @@ class TestComponentManager:
assert cm.queue_manager.task_ids_in_queue == ()
@pytest.mark.forked
class TestStress:
"""Stress test the queue mananger."""
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment