Skip to content
Snippets Groups Projects
Unverified Commit a725df0a authored by SKAJohanVenter's avatar SKAJohanVenter
Browse files

SAR-276 Added is_allowed tests

parent 13b3181f
Branches
No related tags found
No related merge requests found
...@@ -16,6 +16,40 @@ from ska_tango_base.commands import BaseCommand ...@@ -16,6 +16,40 @@ from ska_tango_base.commands import BaseCommand
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@pytest.fixture
def not_allowed_task():
"""Fixture for a test that throws an exception."""
def get_task():
class NotAllowedTask(BaseCommand):
def do(self):
pass
def is_allowed(self):
return False
return NotAllowedTask(target=None)
return get_task
@pytest.fixture
def not_allowed_exc_task():
"""Fixture for a test that throws an exception."""
def get_task():
class NotAllowedErrorTask(BaseCommand):
def do(self):
pass
def is_allowed(self, raise_if_disallowed=True):
raise Exception("Not allowed")
return NotAllowedErrorTask(target=None)
return get_task
@pytest.fixture @pytest.fixture
def progress_task(): def progress_task():
"""Fixture for a test that throws an exception.""" """Fixture for a test that throws an exception."""
...@@ -507,3 +541,57 @@ class TestStress: ...@@ -507,3 +541,57 @@ class TestStress:
while qm._work_queue.qsize(): while qm._work_queue.qsize():
time.sleep(0.1) time.sleep(0.1)
del qm del qm
class TestNotAllowed:
"""Tests for `is_allowed`."""
@pytest.mark.timeout(5)
def test_not_allowed(self, not_allowed_task):
"""Check is_allowed."""
results = []
def catch_updates(name, result):
if name == "longRunningCommandResult":
tr = TaskResult.from_task_result(result)
results.append(tr.result_code)
qm = QueueManager(
max_queue_size=2,
num_workers=2,
logger=logger,
push_change_event=catch_updates,
)
qm.enqueue_task(not_allowed_task())
while ResultCode.NOT_ALLOWED not in results:
time.sleep(0.5)
@pytest.mark.timeout(5)
def test_not_allowed_exc(self, not_allowed_exc_task):
"""Check is_allowed error."""
results = []
def catch_updates(name, result):
if name == "longRunningCommandResult":
tr = TaskResult.from_task_result(result)
results.append(
(
tr.result_code,
tr.task_result,
)
)
qm = QueueManager(
max_queue_size=2,
num_workers=2,
logger=logger,
push_change_event=catch_updates,
)
qm.enqueue_task(not_allowed_exc_task())
while not results:
time.sleep(0.5)
assert ResultCode.FAILED == results[0][0]
assert "Error: Not allowed Traceback (most recent call last)" in results[0][1]
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment