From 3bab5676f71094e541f7e9039641296352e630a3 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?J=C3=B6rn=20K=C3=BCnsem=C3=B6ller?=
 <jkuensem@physik.uni-bielefeld.de>
Date: Thu, 14 Dec 2023 13:58:57 +0100
Subject: [PATCH] TMSS-2637: add test coverage

---
 SAS/TMSS/backend/src/tmss/tmssapp/tasks.py  | 13 +++--
 SAS/TMSS/backend/test/t_scheduling_units.py | 63 ++++++++++++++++++++-
 2 files changed, 68 insertions(+), 8 deletions(-)

diff --git a/SAS/TMSS/backend/src/tmss/tmssapp/tasks.py b/SAS/TMSS/backend/src/tmss/tmssapp/tasks.py
index 24855516295..810d42c38df 100644
--- a/SAS/TMSS/backend/src/tmss/tmssapp/tasks.py
+++ b/SAS/TMSS/backend/src/tmss/tmssapp/tasks.py
@@ -215,7 +215,7 @@ def copy_task_draft(task_draft: models.TaskDraft, remove_lofar1_stations=False,
     with transaction.atomic():
         specifications_doc = copy.deepcopy(task_draft.specifications_doc)
         if remove_lofar1_stations or remove_lofar2_stations:
-            remove_stations_from_task_specifications_doc(specifications_doc, remove_lofar1_stations, remove_lofar2_stations)
+            _remove_stations_from_task_specifications_doc(specifications_doc, remove_lofar1_stations, remove_lofar2_stations)
         task_draft_copy = models.TaskDraft.objects.create(name="%s (Copy)" % (task_draft.name,),
                                                           description="%s (Copy from task_draft id=%s)" % (task_draft.description, task_draft.id),
                                                           short_description=task_draft.short_description,
@@ -248,7 +248,7 @@ def copy_task_blueprint_to_task_draft(task_blueprint: models.TaskBlueprint, remo
     with transaction.atomic():
         specifications_doc = copy.deepcopy(task_blueprint.specifications_doc)
         if remove_lofar1_stations or remove_lofar2_stations:
-            remove_stations_from_task_specifications_doc(specifications_doc, remove_lofar1_stations, remove_lofar2_stations)
+            _remove_stations_from_task_specifications_doc(specifications_doc, remove_lofar1_stations, remove_lofar2_stations)
         task_draft_copy = models.TaskDraft.objects.create(name="%s (Copy)" % (task_blueprint.name,),
                                                           description="%s (Copy from task_blueprint id=%s)" % (task_blueprint.description, task_blueprint.id),
                                                           short_description=task_blueprint.short_description,
@@ -290,6 +290,11 @@ def update_task_graph_from_specifications_doc(scheduling_unit_draft: models.Sche
     """
     logger.debug("update_task_graph_from_specifications_doc(scheduling_unit_draft.id=%s, name='%s') ...", scheduling_unit_draft.pk, scheduling_unit_draft.name)
 
+    # remove stations before validation
+    if remove_lofar1_stations or remove_lofar2_stations:
+        for task_definition in specifications_doc.get("tasks", {}).values():
+            _remove_stations_from_task_specifications_doc(task_definition.get("specifications_doc", {}), remove_lofar1_stations, remove_lofar2_stations)
+
     # make sure the given specifications_doc validates
     scheduling_unit_draft.specifications_template.validate_document(specifications_doc)
 
@@ -363,8 +368,6 @@ def update_task_graph_from_specifications_doc(scheduling_unit_draft: models.Sche
             task_template = models.TaskTemplate.get_version_or_latest(name=task_template_name, version=task_template_version)
 
             task_specifications_doc = task_definition.get("specifications_doc", {})
-            if remove_lofar1_stations or remove_lofar2_stations:
-                remove_stations_from_task_specifications_doc(task_specifications_doc, remove_lofar1_stations, remove_lofar2_stations)
             task_specifications_doc = task_template.add_defaults_to_json_object_for_schema(task_specifications_doc)
 
             logger.debug("creating/updating task draft... task_name='%s', task_template_name='%s', task_template_version=%s", task_name, task_template_name, task_template_version)
@@ -1123,7 +1126,7 @@ def enough_stations_available_for_scheduling_unit(scheduling_unit: SchedulingUni
     return all(enough_stations_available_for_task(obs_task, unavailable_stations) for obs_task in scheduling_unit.observation_tasks.all())
 
 
-def remove_stations_from_task_specifications_doc(specifications_doc, remove_lofar1_stations=False, remove_lofar2_stations=False):
+def _remove_stations_from_task_specifications_doc(specifications_doc, remove_lofar1_stations=False, remove_lofar2_stations=False):
     """
     This function can be used to sanitize task specs from LOFAR 1 or LOFAR 2 stations, e.g. when copying older specs
     that contains a set of stations that is now a mix of LOFAR 1 and 2, and the copy should only contain one type.
diff --git a/SAS/TMSS/backend/test/t_scheduling_units.py b/SAS/TMSS/backend/test/t_scheduling_units.py
index 1f998c43e6a..14b1507d597 100644
--- a/SAS/TMSS/backend/test/t_scheduling_units.py
+++ b/SAS/TMSS/backend/test/t_scheduling_units.py
@@ -55,14 +55,15 @@ from lofar.sas.tmss.test.tmss_test_data_django_models import *
 rest_data_creator = tmss_test_env.create_test_data_creator()
 
 from lofar.sas.tmss.tmss.tmssapp import models
-from lofar.sas.tmss.tmss.exceptions import SchemaValidationException, ObsoleteTemplateException, ValidationException, ObsoleteValidationException
+from lofar.sas.tmss.tmss.exceptions import SchemaValidationException, ObsoleteTemplateException, ValidationException, ObsoleteValidationException, RuleValidationException
 
 import requests
 
-from lofar.sas.tmss.tmss.tmssapp.tasks import create_scheduling_unit_blueprint_and_tasks_and_subtasks_from_scheduling_unit_draft, create_scheduling_unit_blueprint_from_scheduling_unit_draft, update_task_blueprint_graph_from_draft, update_task_graph_from_specifications_doc, create_scheduling_unit_draft_from_observing_strategy_template, mark_task_blueprint_as_obsolete, cancel_task_blueprint, schedule_independent_subtasks_in_task_blueprint
+from lofar.sas.tmss.tmss.tmssapp.tasks import create_scheduling_unit_blueprint_and_tasks_and_subtasks_from_scheduling_unit_draft, create_scheduling_unit_blueprint_from_scheduling_unit_draft, update_task_blueprint_graph_from_draft, update_task_graph_from_specifications_doc, create_scheduling_unit_draft_from_observing_strategy_template, mark_task_blueprint_as_obsolete, cancel_task_blueprint, schedule_independent_subtasks_in_task_blueprint, create_scheduling_unit_draft_from_scheduling_unit_blueprint
 from lofar.sas.tmss.tmss.tmssapp.subtasks import schedule_subtask, cancel_subtask, wait_for_subtask_status
 from lofar.sas.tmss.test.test_utils import set_subtask_state_following_allowed_transitions
 from lofar.messaging.messagebus import BusListenerJanitor
+from lofar.sas.tmss.tmss.tmssapp.conversions import get_lofar2_stations
 
 logger = logging.getLogger(__name__)
 logging.basicConfig(format='%(asctime)s %(levelname)s %(message)s', level=logging.INFO)
@@ -495,11 +496,67 @@ class SchedulingUnitBlueprintStateTest(unittest.TestCase):
         draft = create_scheduling_unit_draft_from_observing_strategy_template(strategy_template, scheduling_set=models.SchedulingSet.objects.create(**SchedulingSet_test_data()), specifications_doc_overrides=override_doc)
 
         # the template contains international stations and a custom group by default, both including DE605.
-        self.assertNotIn("DE605", draft.specifications_doc)
+        self.assertNotIn("DE605", str(draft.specifications_doc))
 
         # assert the draft specs contain exactly the specified station groups:
         self.assertEqual(draft.specifications_doc["tasks"]["Target Observation"]["specifications_doc"]["station_configuration"]["station_groups"], station_groups)
 
+    def test_remove_lofar2_station_while_create_scheduling_unit_draft_from_scheduling_unit_blueprint(self):   # TMSS-2637
+        # create scheduling unit with several LOFAR1 stations
+        strategy_template = models.SchedulingUnitObservingStrategyTemplate.get_latest(name="IM HBA - 1 Beam")
+        station_groups = [{"stations": ["CS002","CS003","CS004","CS005","DE605"], "max_nr_missing": 1}]
+        override_doc = {"tasks": {"Target Observation": {"specifications_doc": {"station_configuration": {"station_groups": station_groups}}}}}
+        draft = create_scheduling_unit_draft_from_observing_strategy_template(strategy_template, scheduling_set=models.SchedulingSet.objects.create(**SchedulingSet_test_data()), specifications_doc_overrides=override_doc)
+        blueprint = create_scheduling_unit_blueprint_and_tasks_and_subtasks_from_scheduling_unit_draft(draft)
+
+        # check that the blueprint contains a particular station and all its tasks validate
+        self.assertIn("DE605", str(blueprint.specifications_doc))
+        for task in blueprint.task_blueprints.all():
+            task.validate_specifications_doc()
+
+        lofar2_stations = get_lofar2_stations()
+        self.assertNotIn('DE605', lofar2_stations)
+
+        # mark that station LOFAR2
+        station_schema_template = models.CommonSchemaTemplate.get_latest(name="stations")
+        groups = station_schema_template.schema['definitions']['station_group']['anyOf']
+        for group in groups:
+            if group['title'].upper() == 'LOFAR2':
+                group['properties']['stations']['enum'][0].append('DE605')
+        station_schema_template.save()
+        get_lofar2_stations.cache_clear()
+        lofar2_stations = get_lofar2_stations()
+        self.assertIn('DE605', lofar2_stations)
+
+        # make sure that the previously blueprinted tasks don't validate any more
+        with self.assertRaises(RuleValidationException):
+            for task in blueprint.task_blueprints.all():
+                task.validate_specifications_doc()
+
+        # start a subtask so that the unit is advance enough to skip rule-based validation on specs read access
+        set_subtask_state_following_allowed_transitions(blueprint.subtasks.first(), 'started')
+        blueprint.refresh_from_db()
+        self.assertEqual(blueprint.status.value, models.SchedulingUnitStatus.Choices.OBSERVING.value)
+
+        # try to create a draft copy, which should fail due to the invalid specs
+        with self.assertRaises(ValidationException):
+            create_scheduling_unit_draft_from_scheduling_unit_blueprint(blueprint)
+
+        # try again, but this time remove lofar2 stations when creating the draft so that it can be blueprinted
+        draft_copy = create_scheduling_unit_draft_from_scheduling_unit_blueprint(blueprint, remove_lofar2_stations=True)
+        blueprint_from_copy = create_scheduling_unit_blueprint_and_tasks_and_subtasks_from_scheduling_unit_draft(draft_copy)
+
+        # check that the new blueprint does not contain the LOFAR2 station and all its tasks validate
+        self.assertNotIn("DE605", str(blueprint_from_copy.specifications_doc))
+        for task in blueprint_from_copy.task_blueprints.all():
+            task.validate_specifications_doc()
+
+        # revert the list of LOFAR2 stations, so other tests still work
+        for group in groups:
+            if group['title'].upper() == 'LOFAR2':
+                group['properties']['stations']['enum'][0].remove('DE605')
+        station_schema_template.save()
+        get_lofar2_stations.cache_clear()
 
 class TestFlatStations(unittest.TestCase):
     """
-- 
GitLab