From 00d3177e983597f7ac54ea4215a48e73fdba00a3 Mon Sep 17 00:00:00 2001
From: Jorrit Schaap <schaap@astron.nl>
Date: Fri, 29 May 2020 10:02:45 +0200
Subject: [PATCH] TMSS-206: just return a QuerySet for the properties
 successors and predecessors.

---
 .../src/tmss/tmssapp/models/scheduling.py     | 24 ++++++--------
 .../src/tmss/tmssapp/viewsets/scheduling.py   |  4 +--
 .../test/t_tmssapp_scheduling_django_API.py   | 32 +++++++++----------
 3 files changed, 27 insertions(+), 33 deletions(-)

diff --git a/SAS/TMSS/src/tmss/tmssapp/models/scheduling.py b/SAS/TMSS/src/tmss/tmssapp/models/scheduling.py
index 5dba109317e..48d9ca29ff7 100644
--- a/SAS/TMSS/src/tmss/tmssapp/models/scheduling.py
+++ b/SAS/TMSS/src/tmss/tmssapp/models/scheduling.py
@@ -7,7 +7,7 @@ import logging
 logger = logging.getLogger(__name__)
 
 from django.db.models import ForeignKey, CharField, DateTimeField, BooleanField, IntegerField, BigIntegerField, \
-    ManyToManyField, CASCADE, SET_NULL, PROTECT, UniqueConstraint
+    ManyToManyField, CASCADE, SET_NULL, PROTECT, UniqueConstraint, QuerySet
 from django.contrib.postgres.fields import ArrayField, JSONField
 from django.contrib.auth.models import User
 from .specification import AbstractChoice, BasicCommon, Template, NamedCommon # , <TaskBlueprint
@@ -174,8 +174,10 @@ class Subtask(BasicCommon):
             tobus.send(msg)
 
     @property
-    def successors_queryset(self):
-        '''return the connect successor subtask(s) as queryset (over which you can perform extended queries, or return via the serializers/viewsets)'''
+    def successors(self) -> QuerySet:
+        '''return the connect successor subtask(s) as queryset (over which you can perform extended queries, or return via the serializers/viewsets)
+           If you want the result, add .all() like so: my_subtask.successors.all()
+        '''
         # JS, 20200528: I couldn't make django do a "self-reference" query from the subtask table to the subtask table (via input, output), so I used plain SQL.
         return Subtask.objects.filter(id__in=RawSQL("SELECT successor_st.id FROM tmssapp_subtask as successor_st\n"
                                                     "INNER JOIN tmssapp_subtaskinput as st_input on st_input.subtask_id = successor_st.id\n"
@@ -183,24 +185,16 @@ class Subtask(BasicCommon):
                                                     "WHERE st_output.subtask_id = %s", params=[self.id]))
 
     @property
-    def successors(self):
-        '''return the connect successor subtask(s) as result of successors_queryset'''
-        return self.successors_queryset.all()
-
-    @property
-    def predecessors_queryset(self):
-        '''return the connect predecessor subtask(s) as queryset (over which you can perform extended queries, or return via the serializers/viewsets)'''
+    def predecessors(self) -> QuerySet:
+        '''return the connect predecessor subtask(s) as queryset (over which you can perform extended queries, or return via the serializers/viewsets)
+        If you want the result, add .all() like so: my_subtask.predecessors.all()
+        '''
         # JS, 20200528: I couldn't make django do a "self-reference" query from the subtask table to the subtask table (via input, output), so I used plain SQL.
         return Subtask.objects.filter(id__in=RawSQL("SELECT predecessor_st.id FROM tmssapp_subtask as predecessor_st\n"
                                                     "INNER JOIN tmssapp_subtaskoutput as st_output on st_output.subtask_id = predecessor_st.id\n"
                                                     "INNER JOIN tmssapp_subtaskinput as st_input on st_input.producer_id = st_output.id\n"
                                                     "WHERE st_input.subtask_id = %s", params=[self.id]))
 
-    @property
-    def predecessors(self):
-        '''return the connect predecessor subtask(s) as result of predecessors_queryset'''
-        return self.predecessors_queryset.all()
-
     def save(self, force_insert=False, force_update=False, using=None, update_fields=None):
         creating = self._state.adding  # True on create, False on update
 
diff --git a/SAS/TMSS/src/tmss/tmssapp/viewsets/scheduling.py b/SAS/TMSS/src/tmss/tmssapp/viewsets/scheduling.py
index 14df21b1ada..2ab0580648d 100644
--- a/SAS/TMSS/src/tmss/tmssapp/viewsets/scheduling.py
+++ b/SAS/TMSS/src/tmss/tmssapp/viewsets/scheduling.py
@@ -198,10 +198,10 @@ class SubtaskNestedViewSet(LOFARNestedViewSet):
             subtask = get_object_or_404(models.Subtask, pk=self.kwargs['subtask_id'])
 
             if 'successors' in self.request._request.path:
-                return subtask.successors_queryset
+                return subtask.successors
 
             if 'predecessors' in self.request._request.path:
-                return subtask.predecessors_queryset
+                return subtask.predecessors
 
 class SubtaskInputViewSet(LOFARViewSet):
     queryset = models.SubtaskInput.objects.all()
diff --git a/SAS/TMSS/test/t_tmssapp_scheduling_django_API.py b/SAS/TMSS/test/t_tmssapp_scheduling_django_API.py
index e87ea0ed89d..6260ec3fbbb 100755
--- a/SAS/TMSS/test/t_tmssapp_scheduling_django_API.py
+++ b/SAS/TMSS/test/t_tmssapp_scheduling_django_API.py
@@ -215,10 +215,10 @@ class SubtaskTest(unittest.TestCase):
         subtask1:models.Subtask = models.Subtask.objects.create(**Subtask_test_data())
         subtask2:models.Subtask = models.Subtask.objects.create(**Subtask_test_data())
 
-        self.assertEqual(set(), set(subtask1.predecessors))
-        self.assertEqual(set(), set(subtask2.predecessors))
-        self.assertEqual(set(), set(subtask1.successors))
-        self.assertEqual(set(), set(subtask2.successors))
+        self.assertEqual(set(), set(subtask1.predecessors.all()))
+        self.assertEqual(set(), set(subtask2.predecessors.all()))
+        self.assertEqual(set(), set(subtask1.successors.all()))
+        self.assertEqual(set(), set(subtask2.successors.all()))
 
     def test_Subtask_predecessors_and_successors_simple(self):
         subtask1:models.Subtask = models.Subtask.objects.create(**Subtask_test_data())
@@ -227,8 +227,8 @@ class SubtaskTest(unittest.TestCase):
         output1 = models.SubtaskOutput.objects.create(subtask=subtask1)
         models.SubtaskInput.objects.create(**SubtaskInput_test_data(subtask=subtask2, producer=output1))
 
-        self.assertEqual(subtask1, subtask2.predecessors[0])
-        self.assertEqual(subtask2, subtask1.successors[0])
+        self.assertEqual(subtask1, subtask2.predecessors.all()[0])
+        self.assertEqual(subtask2, subtask1.successors.all()[0])
 
     def test_Subtask_predecessors_and_successors_complex(self):
         subtask1:models.Subtask = models.Subtask.objects.create(**Subtask_test_data())
@@ -255,16 +255,16 @@ class SubtaskTest(unittest.TestCase):
         models.SubtaskInput.objects.create(**SubtaskInput_test_data(subtask=subtask5, producer=output3))
         models.SubtaskInput.objects.create(**SubtaskInput_test_data(subtask=subtask6, producer=output5))
 
-        self.assertEqual(set((subtask1, subtask2)), set(subtask3.predecessors))
-        self.assertEqual(set((subtask4, subtask5)), set(subtask3.successors))
-        self.assertEqual(set((subtask3,)), set(subtask4.predecessors))
-        self.assertEqual(set((subtask3,)), set(subtask5.predecessors))
-        self.assertEqual(set((subtask3,)), set(subtask1.successors))
-        self.assertEqual(set((subtask3,)), set(subtask2.successors))
-        self.assertEqual(set(), set(subtask1.predecessors))
-        self.assertEqual(set(), set(subtask2.predecessors))
-        self.assertEqual(set(), set(subtask4.successors))
-        self.assertEqual(set((subtask6,)), set(subtask5.successors))
+        self.assertEqual(set((subtask1, subtask2)), set(subtask3.predecessors.all()))
+        self.assertEqual(set((subtask4, subtask5)), set(subtask3.successors.all()))
+        self.assertEqual(set((subtask3,)), set(subtask4.predecessors.all()))
+        self.assertEqual(set((subtask3,)), set(subtask5.predecessors.all()))
+        self.assertEqual(set((subtask3,)), set(subtask1.successors.all()))
+        self.assertEqual(set((subtask3,)), set(subtask2.successors.all()))
+        self.assertEqual(set(), set(subtask1.predecessors.all()))
+        self.assertEqual(set(), set(subtask2.predecessors.all()))
+        self.assertEqual(set(), set(subtask4.successors.all()))
+        self.assertEqual(set((subtask6,)), set(subtask5.successors.all()))
 
 class DataproductTest(unittest.TestCase):
     def test_Dataproduct_gets_created_with_correct_creation_timestamp(self):
-- 
GitLab