From 3bab5676f71094e541f7e9039641296352e630a3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?J=C3=B6rn=20K=C3=BCnsem=C3=B6ller?= <jkuensem@physik.uni-bielefeld.de> Date: Thu, 14 Dec 2023 13:58:57 +0100 Subject: [PATCH] TMSS-2637: add test coverage --- SAS/TMSS/backend/src/tmss/tmssapp/tasks.py | 13 +++-- SAS/TMSS/backend/test/t_scheduling_units.py | 63 ++++++++++++++++++++- 2 files changed, 68 insertions(+), 8 deletions(-) diff --git a/SAS/TMSS/backend/src/tmss/tmssapp/tasks.py b/SAS/TMSS/backend/src/tmss/tmssapp/tasks.py index 24855516295..810d42c38df 100644 --- a/SAS/TMSS/backend/src/tmss/tmssapp/tasks.py +++ b/SAS/TMSS/backend/src/tmss/tmssapp/tasks.py @@ -215,7 +215,7 @@ def copy_task_draft(task_draft: models.TaskDraft, remove_lofar1_stations=False, with transaction.atomic(): specifications_doc = copy.deepcopy(task_draft.specifications_doc) if remove_lofar1_stations or remove_lofar2_stations: - remove_stations_from_task_specifications_doc(specifications_doc, remove_lofar1_stations, remove_lofar2_stations) + _remove_stations_from_task_specifications_doc(specifications_doc, remove_lofar1_stations, remove_lofar2_stations) task_draft_copy = models.TaskDraft.objects.create(name="%s (Copy)" % (task_draft.name,), description="%s (Copy from task_draft id=%s)" % (task_draft.description, task_draft.id), short_description=task_draft.short_description, @@ -248,7 +248,7 @@ def copy_task_blueprint_to_task_draft(task_blueprint: models.TaskBlueprint, remo with transaction.atomic(): specifications_doc = copy.deepcopy(task_blueprint.specifications_doc) if remove_lofar1_stations or remove_lofar2_stations: - remove_stations_from_task_specifications_doc(specifications_doc, remove_lofar1_stations, remove_lofar2_stations) + _remove_stations_from_task_specifications_doc(specifications_doc, remove_lofar1_stations, remove_lofar2_stations) task_draft_copy = models.TaskDraft.objects.create(name="%s (Copy)" % (task_blueprint.name,), description="%s (Copy from task_blueprint id=%s)" % (task_blueprint.description, task_blueprint.id), short_description=task_blueprint.short_description, @@ -290,6 +290,11 @@ def update_task_graph_from_specifications_doc(scheduling_unit_draft: models.Sche """ logger.debug("update_task_graph_from_specifications_doc(scheduling_unit_draft.id=%s, name='%s') ...", scheduling_unit_draft.pk, scheduling_unit_draft.name) + # remove stations before validation + if remove_lofar1_stations or remove_lofar2_stations: + for task_definition in specifications_doc.get("tasks", {}).values(): + _remove_stations_from_task_specifications_doc(task_definition.get("specifications_doc", {}), remove_lofar1_stations, remove_lofar2_stations) + # make sure the given specifications_doc validates scheduling_unit_draft.specifications_template.validate_document(specifications_doc) @@ -363,8 +368,6 @@ def update_task_graph_from_specifications_doc(scheduling_unit_draft: models.Sche task_template = models.TaskTemplate.get_version_or_latest(name=task_template_name, version=task_template_version) task_specifications_doc = task_definition.get("specifications_doc", {}) - if remove_lofar1_stations or remove_lofar2_stations: - remove_stations_from_task_specifications_doc(task_specifications_doc, remove_lofar1_stations, remove_lofar2_stations) task_specifications_doc = task_template.add_defaults_to_json_object_for_schema(task_specifications_doc) logger.debug("creating/updating task draft... task_name='%s', task_template_name='%s', task_template_version=%s", task_name, task_template_name, task_template_version) @@ -1123,7 +1126,7 @@ def enough_stations_available_for_scheduling_unit(scheduling_unit: SchedulingUni return all(enough_stations_available_for_task(obs_task, unavailable_stations) for obs_task in scheduling_unit.observation_tasks.all()) -def remove_stations_from_task_specifications_doc(specifications_doc, remove_lofar1_stations=False, remove_lofar2_stations=False): +def _remove_stations_from_task_specifications_doc(specifications_doc, remove_lofar1_stations=False, remove_lofar2_stations=False): """ This function can be used to sanitize task specs from LOFAR 1 or LOFAR 2 stations, e.g. when copying older specs that contains a set of stations that is now a mix of LOFAR 1 and 2, and the copy should only contain one type. diff --git a/SAS/TMSS/backend/test/t_scheduling_units.py b/SAS/TMSS/backend/test/t_scheduling_units.py index 1f998c43e6a..14b1507d597 100644 --- a/SAS/TMSS/backend/test/t_scheduling_units.py +++ b/SAS/TMSS/backend/test/t_scheduling_units.py @@ -55,14 +55,15 @@ from lofar.sas.tmss.test.tmss_test_data_django_models import * rest_data_creator = tmss_test_env.create_test_data_creator() from lofar.sas.tmss.tmss.tmssapp import models -from lofar.sas.tmss.tmss.exceptions import SchemaValidationException, ObsoleteTemplateException, ValidationException, ObsoleteValidationException +from lofar.sas.tmss.tmss.exceptions import SchemaValidationException, ObsoleteTemplateException, ValidationException, ObsoleteValidationException, RuleValidationException import requests -from lofar.sas.tmss.tmss.tmssapp.tasks import create_scheduling_unit_blueprint_and_tasks_and_subtasks_from_scheduling_unit_draft, create_scheduling_unit_blueprint_from_scheduling_unit_draft, update_task_blueprint_graph_from_draft, update_task_graph_from_specifications_doc, create_scheduling_unit_draft_from_observing_strategy_template, mark_task_blueprint_as_obsolete, cancel_task_blueprint, schedule_independent_subtasks_in_task_blueprint +from lofar.sas.tmss.tmss.tmssapp.tasks import create_scheduling_unit_blueprint_and_tasks_and_subtasks_from_scheduling_unit_draft, create_scheduling_unit_blueprint_from_scheduling_unit_draft, update_task_blueprint_graph_from_draft, update_task_graph_from_specifications_doc, create_scheduling_unit_draft_from_observing_strategy_template, mark_task_blueprint_as_obsolete, cancel_task_blueprint, schedule_independent_subtasks_in_task_blueprint, create_scheduling_unit_draft_from_scheduling_unit_blueprint from lofar.sas.tmss.tmss.tmssapp.subtasks import schedule_subtask, cancel_subtask, wait_for_subtask_status from lofar.sas.tmss.test.test_utils import set_subtask_state_following_allowed_transitions from lofar.messaging.messagebus import BusListenerJanitor +from lofar.sas.tmss.tmss.tmssapp.conversions import get_lofar2_stations logger = logging.getLogger(__name__) logging.basicConfig(format='%(asctime)s %(levelname)s %(message)s', level=logging.INFO) @@ -495,11 +496,67 @@ class SchedulingUnitBlueprintStateTest(unittest.TestCase): draft = create_scheduling_unit_draft_from_observing_strategy_template(strategy_template, scheduling_set=models.SchedulingSet.objects.create(**SchedulingSet_test_data()), specifications_doc_overrides=override_doc) # the template contains international stations and a custom group by default, both including DE605. - self.assertNotIn("DE605", draft.specifications_doc) + self.assertNotIn("DE605", str(draft.specifications_doc)) # assert the draft specs contain exactly the specified station groups: self.assertEqual(draft.specifications_doc["tasks"]["Target Observation"]["specifications_doc"]["station_configuration"]["station_groups"], station_groups) + def test_remove_lofar2_station_while_create_scheduling_unit_draft_from_scheduling_unit_blueprint(self): # TMSS-2637 + # create scheduling unit with several LOFAR1 stations + strategy_template = models.SchedulingUnitObservingStrategyTemplate.get_latest(name="IM HBA - 1 Beam") + station_groups = [{"stations": ["CS002","CS003","CS004","CS005","DE605"], "max_nr_missing": 1}] + override_doc = {"tasks": {"Target Observation": {"specifications_doc": {"station_configuration": {"station_groups": station_groups}}}}} + draft = create_scheduling_unit_draft_from_observing_strategy_template(strategy_template, scheduling_set=models.SchedulingSet.objects.create(**SchedulingSet_test_data()), specifications_doc_overrides=override_doc) + blueprint = create_scheduling_unit_blueprint_and_tasks_and_subtasks_from_scheduling_unit_draft(draft) + + # check that the blueprint contains a particular station and all its tasks validate + self.assertIn("DE605", str(blueprint.specifications_doc)) + for task in blueprint.task_blueprints.all(): + task.validate_specifications_doc() + + lofar2_stations = get_lofar2_stations() + self.assertNotIn('DE605', lofar2_stations) + + # mark that station LOFAR2 + station_schema_template = models.CommonSchemaTemplate.get_latest(name="stations") + groups = station_schema_template.schema['definitions']['station_group']['anyOf'] + for group in groups: + if group['title'].upper() == 'LOFAR2': + group['properties']['stations']['enum'][0].append('DE605') + station_schema_template.save() + get_lofar2_stations.cache_clear() + lofar2_stations = get_lofar2_stations() + self.assertIn('DE605', lofar2_stations) + + # make sure that the previously blueprinted tasks don't validate any more + with self.assertRaises(RuleValidationException): + for task in blueprint.task_blueprints.all(): + task.validate_specifications_doc() + + # start a subtask so that the unit is advance enough to skip rule-based validation on specs read access + set_subtask_state_following_allowed_transitions(blueprint.subtasks.first(), 'started') + blueprint.refresh_from_db() + self.assertEqual(blueprint.status.value, models.SchedulingUnitStatus.Choices.OBSERVING.value) + + # try to create a draft copy, which should fail due to the invalid specs + with self.assertRaises(ValidationException): + create_scheduling_unit_draft_from_scheduling_unit_blueprint(blueprint) + + # try again, but this time remove lofar2 stations when creating the draft so that it can be blueprinted + draft_copy = create_scheduling_unit_draft_from_scheduling_unit_blueprint(blueprint, remove_lofar2_stations=True) + blueprint_from_copy = create_scheduling_unit_blueprint_and_tasks_and_subtasks_from_scheduling_unit_draft(draft_copy) + + # check that the new blueprint does not contain the LOFAR2 station and all its tasks validate + self.assertNotIn("DE605", str(blueprint_from_copy.specifications_doc)) + for task in blueprint_from_copy.task_blueprints.all(): + task.validate_specifications_doc() + + # revert the list of LOFAR2 stations, so other tests still work + for group in groups: + if group['title'].upper() == 'LOFAR2': + group['properties']['stations']['enum'][0].remove('DE605') + station_schema_template.save() + get_lofar2_stations.cache_clear() class TestFlatStations(unittest.TestCase): """ -- GitLab