diff --git a/tests/long_running_tasks/test_task_queue_manager.py b/tests/long_running_tasks/test_task_queue_manager.py index 32ae62ffff5af2bce987c9cf8cf87ed9b9950fd3..1907102350167e0425272fac10460b5cc972ca25 100644 --- a/tests/long_running_tasks/test_task_queue_manager.py +++ b/tests/long_running_tasks/test_task_queue_manager.py @@ -16,6 +16,40 @@ from ska_tango_base.commands import BaseCommand logger = logging.getLogger(__name__) +@pytest.fixture +def not_allowed_task(): + """Fixture for a test that throws an exception.""" + + def get_task(): + class NotAllowedTask(BaseCommand): + def do(self): + pass + + def is_allowed(self): + return False + + return NotAllowedTask(target=None) + + return get_task + + +@pytest.fixture +def not_allowed_exc_task(): + """Fixture for a test that throws an exception.""" + + def get_task(): + class NotAllowedErrorTask(BaseCommand): + def do(self): + pass + + def is_allowed(self, raise_if_disallowed=True): + raise Exception("Not allowed") + + return NotAllowedErrorTask(target=None) + + return get_task + + @pytest.fixture def progress_task(): """Fixture for a test that throws an exception.""" @@ -507,3 +541,57 @@ class TestStress: while qm._work_queue.qsize(): time.sleep(0.1) del qm + + +class TestNotAllowed: + """Tests for `is_allowed`.""" + + @pytest.mark.timeout(5) + def test_not_allowed(self, not_allowed_task): + """Check is_allowed.""" + results = [] + + def catch_updates(name, result): + if name == "longRunningCommandResult": + tr = TaskResult.from_task_result(result) + results.append(tr.result_code) + + qm = QueueManager( + max_queue_size=2, + num_workers=2, + logger=logger, + push_change_event=catch_updates, + ) + qm.enqueue_task(not_allowed_task()) + + while ResultCode.NOT_ALLOWED not in results: + time.sleep(0.5) + + @pytest.mark.timeout(5) + def test_not_allowed_exc(self, not_allowed_exc_task): + """Check is_allowed error.""" + results = [] + + def catch_updates(name, result): + if name == "longRunningCommandResult": + tr = TaskResult.from_task_result(result) + results.append( + ( + tr.result_code, + tr.task_result, + ) + ) + + qm = QueueManager( + max_queue_size=2, + num_workers=2, + logger=logger, + push_change_event=catch_updates, + ) + qm.enqueue_task(not_allowed_exc_task()) + + while not results: + time.sleep(0.5) + + assert ResultCode.FAILED == results[0][0] + assert "Error: Not allowed Traceback (most recent call last)" in results[0][1]