From f08b2a623df064973a9aa7a0cfd241dddd8d8adc Mon Sep 17 00:00:00 2001
From: Jorrit Schaap <schaap@astron.nl>
Date: Fri, 3 Sep 2021 13:41:49 +0200
Subject: [PATCH] TMSS-917: ensure that each parent task_blueprint has exactly
 one primary child subtask; added and adapted tests

---
 .../src/tmss/tmssapp/models/scheduling.py     | 12 ++++--
 .../tmss/tmssapp/serializers/scheduling.py    |  2 +-
 SAS/TMSS/backend/src/tmss/tmssapp/subtasks.py |  2 +-
 SAS/TMSS/backend/test/t_scheduling.py         | 17 +++-----
 .../test/t_tmssapp_scheduling_django_API.py   | 43 +++++++++++++++++++
 .../test/tmss_test_data_django_models.py      |  3 +-
 SAS/TMSS/backend/test/tmss_test_data_rest.py  |  5 ++-
 7 files changed, 64 insertions(+), 20 deletions(-)

diff --git a/SAS/TMSS/backend/src/tmss/tmssapp/models/scheduling.py b/SAS/TMSS/backend/src/tmss/tmssapp/models/scheduling.py
index 60806b9b48e..085dd455eb0 100644
--- a/SAS/TMSS/backend/src/tmss/tmssapp/models/scheduling.py
+++ b/SAS/TMSS/backend/src/tmss/tmssapp/models/scheduling.py
@@ -148,7 +148,7 @@ class Subtask(BasicCommon, ProjectPropertyMixin, TemplateSchemaMixin):
     start_time = DateTimeField(null=True, help_text='Start this subtask at the specified time (NULLable).')
     stop_time = DateTimeField(null=True, help_text='Stop this subtask at the specified time (NULLable).')
     state = ForeignKey('SubtaskState', null=False, on_delete=PROTECT, related_name='task_states', help_text='Subtask state (see Subtask State Machine).')
-    primary = BooleanField(default=False, help_text='TRUE if this is the one-and-only primary subtask in a parent TaskBlueprint.')
+    primary = BooleanField(default=False, db_index=True, help_text='TRUE if this is the one-and-only primary subtask in a parent TaskBlueprint.')
     specifications_doc = JSONField(help_text='Final specifications, as input for the controller.')
     task_blueprint = ForeignKey('TaskBlueprint', null=True, on_delete=PROTECT, related_name='subtasks', help_text='The parent TaskBlueprint.') #TODO: be more strict with null=False
     specifications_template = ForeignKey('SubtaskTemplate', null=False, on_delete=PROTECT, help_text='Schema used for specifications_doc.')
@@ -159,9 +159,6 @@ class Subtask(BasicCommon, ProjectPropertyMixin, TemplateSchemaMixin):
     global_identifier = OneToOneField('SIPidentifier', null=False, editable=False, on_delete=PROTECT, help_text='The global unique identifier for LTA SIP.')
     path_to_project = 'task_blueprint__scheduling_unit_blueprint__draft__scheduling_set__project'
 
-    # class Meta(BasicCommon.Meta):
-    #     constraints = [UniqueConstraint(fields=['primary', 'task_blueprint'], name='subtask_unique_primary_subtask_within_parent_task')]
-
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
 
@@ -307,6 +304,13 @@ class Subtask(BasicCommon, ProjectPropertyMixin, TemplateSchemaMixin):
             if self.start_time is None:
                 raise SubtaskSchedulingException("Cannot schedule subtask id=%s when start time is 'None'." % (self.pk, ))
 
+        # ensure there is and will be exactly one primary subtask per parent task_blueprint
+        # quite a complex check, luckily we have a test for that.
+        nr_of_primary_siblings = 0 if self.task_blueprint is None else self.task_blueprint.subtasks.filter(primary=True).exclude(id=self.id).count()
+        if (creating and ((self.primary and nr_of_primary_siblings!=0) or (not self.primary and nr_of_primary_siblings==0))) or \
+           (not creating and ((self.primary and nr_of_primary_siblings!=0) or (not self.primary and nr_of_primary_siblings==0))):
+            raise ValidationError("There should be exactly one primary subtask per parent task_blueprint")
+
         try:
             super().save(force_insert, force_update, using, update_fields)
         except InternalError as db_error:
diff --git a/SAS/TMSS/backend/src/tmss/tmssapp/serializers/scheduling.py b/SAS/TMSS/backend/src/tmss/tmssapp/serializers/scheduling.py
index 9aea1d21fd1..04a943b9313 100644
--- a/SAS/TMSS/backend/src/tmss/tmssapp/serializers/scheduling.py
+++ b/SAS/TMSS/backend/src/tmss/tmssapp/serializers/scheduling.py
@@ -83,7 +83,7 @@ class SubtaskSerializer(DynamicRelationalHyperlinkedModelSerializer):
     subtask_type = serializers.StringRelatedField(source='specifications_template.type', label='subtask_type', read_only=True, help_text='The subtask type as defined in the specifications template, provided here to safe an addition lookup.')
     specifications_doc = JSONEditorField(schema_source='specifications_template.schema')
     duration = FloatDurationField(read_only=True)
-    primary = serializers.BooleanField(read_only=True) #primary field is and should only be set by tmss django server upon subtask creation
+    primary = serializers.BooleanField()
     input_dataproducts = serializers.HyperlinkedRelatedField(many=True, read_only=True, view_name='dataproduct-detail')
     output_dataproducts = serializers.HyperlinkedRelatedField(many=True, read_only=True, view_name='dataproduct-detail')
 
diff --git a/SAS/TMSS/backend/src/tmss/tmssapp/subtasks.py b/SAS/TMSS/backend/src/tmss/tmssapp/subtasks.py
index 3def90606c9..ab02a13dbce 100644
--- a/SAS/TMSS/backend/src/tmss/tmssapp/subtasks.py
+++ b/SAS/TMSS/backend/src/tmss/tmssapp/subtasks.py
@@ -93,7 +93,7 @@ def create_subtasks_from_task_blueprint(task_blueprint: TaskBlueprint) -> [Subta
                         subtasks.append(subtask)
                 except Exception as e:
                     logger.exception(e)
-                    raise SubtaskCreationException('Cannot create subtasks for task id=%s for its schema name=\'%s\' in generator %s' % (task_blueprint.pk, template_name, generator)) from e
+                    raise SubtaskCreationException('Cannot create subtasks for task id=%s for its schema name=\'%s\' in generator \'%s\'' % (task_blueprint.pk, template_name, generator.__name__)) from e
             return subtasks
         else:
             logger.error('Cannot create subtasks for task id=%s because no generator exists for its schema name=%s' % (task_blueprint.pk, template_name))
diff --git a/SAS/TMSS/backend/test/t_scheduling.py b/SAS/TMSS/backend/test/t_scheduling.py
index 23e8b120326..f26129da576 100755
--- a/SAS/TMSS/backend/test/t_scheduling.py
+++ b/SAS/TMSS/backend/test/t_scheduling.py
@@ -162,8 +162,7 @@ class SchedulingTest(unittest.TestCase):
                                                      start_time=datetime.utcnow()+timedelta(minutes=5),
                                                      task_blueprint_url=task_blueprint['url'])
             subtask = test_data_creator.post_data_and_get_response_as_json_object(subtask_data, '/subtask/')
-            test_data_creator.post_data_and_get_url(test_data_creator.SubtaskOutput(subtask_url=subtask['url'],
-                                                                                    task_blueprint_url=task_blueprint['url']), '/subtask_output/')
+            test_data_creator.post_data_and_get_url(test_data_creator.SubtaskOutput(subtask_url=subtask['url']), '/subtask_output/')
 
             client.set_subtask_status(subtask['id'], 'defined')
             return subtask
@@ -384,8 +383,7 @@ class SchedulingTest(unittest.TestCase):
                                                          cluster_url=cluster_url,
                                                          task_blueprint_url=obs_task_blueprint['url'])
             obs_subtask = test_data_creator.post_data_and_get_response_as_json_object(obs_subtask_data, '/subtask/')
-            obs_subtask_output_url = test_data_creator.post_data_and_get_url(test_data_creator.SubtaskOutput(subtask_url=obs_subtask['url'],
-                                                                                                             task_blueprint_url=obs_task_blueprint['url']), '/subtask_output/')
+            obs_subtask_output_url = test_data_creator.post_data_and_get_url(test_data_creator.SubtaskOutput(subtask_url=obs_subtask['url']), '/subtask_output/')
             test_data_creator.post_data_and_get_url(test_data_creator.Dataproduct(**dataproduct_properties, subtask_output_url=obs_subtask_output_url), '/dataproduct/')
 
             # now create the pipeline...
@@ -403,8 +401,7 @@ class SchedulingTest(unittest.TestCase):
 
             # ...and connect it to the observation
             test_data_creator.post_data_and_get_url(test_data_creator.SubtaskInput(subtask_url=pipe_subtask['url'], subtask_output_url=obs_subtask_output_url), '/subtask_input/')
-            test_data_creator.post_data_and_get_url(test_data_creator.SubtaskOutput(subtask_url=pipe_subtask['url'],
-                                                                                    task_blueprint_url=pipe_task_blueprint['url']), '/subtask_output/')
+            test_data_creator.post_data_and_get_url(test_data_creator.SubtaskOutput(subtask_url=pipe_subtask['url']), '/subtask_output/')
 
             for predecessor in client.get_subtask_predecessors(pipe_subtask['id']):
                 for state in ('defined', 'scheduling', 'scheduled', 'starting', 'started', 'finishing', 'finished'):
@@ -495,13 +492,13 @@ class SchedulingTest(unittest.TestCase):
             ingest_subtask_data = test_data_creator.Subtask(specifications_template_url=ingest_subtask_template['url'],
                                                           specifications_doc=ingest_spec,
                                                           task_blueprint_url=obs_subtask['task_blueprint'],
+                                                          primary=False,
                                                           cluster_url=cluster_url)
             ingest_subtask = test_data_creator.post_data_and_get_response_as_json_object(ingest_subtask_data, '/subtask/')
 
             # ...and connect it to the observation
             test_data_creator.post_data_and_get_url(test_data_creator.SubtaskInput(subtask_url=ingest_subtask['url'], subtask_output_url=obs_subtask_output_url), '/subtask_input/')
-            test_data_creator.post_data_and_get_url(test_data_creator.SubtaskOutput(subtask_url=ingest_subtask['url'],
-                                                                                    task_blueprint_url=obs_subtask['task_blueprint']), '/subtask_output/')  # our subtask here has only one known related task
+            test_data_creator.post_data_and_get_url(test_data_creator.SubtaskOutput(subtask_url=ingest_subtask['url']), '/subtask_output/')  # our subtask here has only one known related task
 
             for predecessor in client.get_subtask_predecessors(ingest_subtask['id']):
                 for state in ('defined', 'scheduling', 'scheduled', 'starting', 'started', 'finishing', 'finished'):
@@ -683,9 +680,7 @@ class SAPTest(unittest.TestCase):
                                                      stop_time=datetime.utcnow() + timedelta(minutes=15))
             subtask = test_data_creator.post_data_and_get_response_as_json_object(subtask_data, '/subtask/')
             subtask_id = subtask['id']
-            test_data_creator.post_data_and_get_url(test_data_creator.SubtaskOutput(subtask_url=subtask['url'],
-                                                                                    task_blueprint_url=task_blueprint['url']),
-                                                    '/subtask_output/')
+            test_data_creator.post_data_and_get_url(test_data_creator.SubtaskOutput(subtask_url=subtask['url']), '/subtask_output/')
 
             subtask_model = models.Subtask.objects.get(id=subtask_id)
             self.assertEqual(0, subtask_model.output_dataproducts.values('sap').count())
diff --git a/SAS/TMSS/backend/test/t_tmssapp_scheduling_django_API.py b/SAS/TMSS/backend/test/t_tmssapp_scheduling_django_API.py
index 8208d564f31..90268af04eb 100755
--- a/SAS/TMSS/backend/test/t_tmssapp_scheduling_django_API.py
+++ b/SAS/TMSS/backend/test/t_tmssapp_scheduling_django_API.py
@@ -334,6 +334,49 @@ class SubtaskTest(unittest.TestCase):
             models.Subtask = models.Subtask.objects.create(**test_data)
         self.assertIn('popular_name', str(context.exception))
 
+    def test_unique_primary_in_parent_task(self):
+        # create a parent task_blueprint
+        task_blueprint = models.TaskBlueprint.objects.create(**TaskBlueprint_test_data())
+
+        data = Subtask_test_data(task_blueprint=task_blueprint)
+
+        # creating a single child SubTask with primary=False should raise
+        with self.assertRaises(ValidationError):
+            data['primary'] = False
+            models.Subtask.objects.create(**data)
+        task_blueprint.refresh_from_db()
+        self.assertEqual(0, task_blueprint.subtasks.count())
+
+        # creating a single child SubTask with primary=True should work
+        data['primary'] = True
+        subtask1 = models.Subtask.objects.create(**data)
+        task_blueprint.refresh_from_db()
+        self.assertEqual(1, task_blueprint.subtasks.count())
+
+        # adding a second child SubTask with primary=False should work (there is still one unique primary subtask)
+        data['primary'] = False
+        subtask2 = models.Subtask.objects.create(**data)
+        task_blueprint.refresh_from_db()
+        self.assertEqual(2, task_blueprint.subtasks.count())
+
+        # adding a third child SubTask with primary=True should fail
+        with self.assertRaises(ValidationError):
+            data['primary'] = True
+            models.Subtask.objects.create(**data)
+        task_blueprint.refresh_from_db()
+        self.assertEqual(2, task_blueprint.subtasks.count())
+
+        # updating the first child SubTask to primary=False should fail
+        with self.assertRaises(ValidationError):
+            subtask1.primary = False
+            subtask1.save()
+
+        # updating the second child SubTask to primary=True should fail
+        with self.assertRaises(ValidationError):
+            subtask2.primary = True
+            subtask2.save()
+
+
 
 class DataproductTest(unittest.TestCase):
     def test_Dataproduct_gets_created_with_correct_creation_timestamp(self):
diff --git a/SAS/TMSS/backend/test/tmss_test_data_django_models.py b/SAS/TMSS/backend/test/tmss_test_data_django_models.py
index a2942421cd8..d731f84b945 100644
--- a/SAS/TMSS/backend/test/tmss_test_data_django_models.py
+++ b/SAS/TMSS/backend/test/tmss_test_data_django_models.py
@@ -403,7 +403,7 @@ def SubtaskInput_test_data(subtask: models.Subtask=None, producer: models.Subtas
 
 def Subtask_test_data(subtask_template: models.SubtaskTemplate=None,
                       specifications_doc: dict=None, start_time=None, stop_time=None, cluster=None, state=None,
-                      raw_feedback=None, task_blueprint: models.TaskBlueprint=None) -> dict:
+                      raw_feedback=None, task_blueprint: models.TaskBlueprint=None, primary:bool=True) -> dict:
 
     if subtask_template is None:
         subtask_template = models.SubtaskTemplate.objects.create(**SubtaskTemplate_test_data())
@@ -433,6 +433,7 @@ def Subtask_test_data(subtask_template: models.SubtaskTemplate=None,
              "tags": ["TMSS", "TESTING"],
              "cluster": cluster,
              "raw_feedback": raw_feedback,
+             "primary": primary,
              "global_identifier": models.SIPidentifier.objects.create(source="TMSS")}
 
 def Dataproduct_test_data(producer: models.SubtaskOutput=None,
diff --git a/SAS/TMSS/backend/test/tmss_test_data_rest.py b/SAS/TMSS/backend/test/tmss_test_data_rest.py
index 6d550ac2d05..4f74cf21029 100644
--- a/SAS/TMSS/backend/test/tmss_test_data_rest.py
+++ b/SAS/TMSS/backend/test/tmss_test_data_rest.py
@@ -645,7 +645,7 @@ class TMSSRESTTestDataCreator():
             return self._cluster_url
 
 
-    def Subtask(self, cluster_url=None, task_blueprint_url=None, specifications_template_url=None, specifications_doc=None, state:str="defining", start_time: datetime=None, stop_time: datetime=None, raw_feedback:str =None):
+    def Subtask(self, cluster_url=None, task_blueprint_url=None, specifications_template_url=None, specifications_doc=None, state:str="defining", start_time: datetime=None, stop_time: datetime=None, raw_feedback:str =None, primary: bool=True):
         if cluster_url is None:
             cluster_url = self.cached_cluster_url
     
@@ -678,7 +678,8 @@ class TMSSRESTTestDataCreator():
                 "specifications_template": specifications_template_url,
                 "tags": ["TMSS", "TESTING"],
                 "cluster": cluster_url,
-                "raw_feedback": raw_feedback}
+                "raw_feedback": raw_feedback,
+                "primary": primary}
 
     @property
     def cached_subtask_url(self):
-- 
GitLab