From 2c2ca5301cb61a2b3070bd948426f3aa7d1b70b0 Mon Sep 17 00:00:00 2001
From: Jorrit Schaap <schaap@astron.nl>
Date: Mon, 10 Aug 2020 14:42:10 +0200
Subject: [PATCH] TMSS-272: update the  fields for tmss common json schema
 fields to the tmss host

---
 .../src/tmss/tmssapp/models/specification.py  | 50 +++++++++++++++++++
 .../tmss/tmssapp/serializers/scheduling.py    |  4 +-
 .../tmss/tmssapp/serializers/specification.py | 42 +++-------------
 3 files changed, 60 insertions(+), 36 deletions(-)

diff --git a/SAS/TMSS/src/tmss/tmssapp/models/specification.py b/SAS/TMSS/src/tmss/tmssapp/models/specification.py
index 743963c4edf..36368a34a8b 100644
--- a/SAS/TMSS/src/tmss/tmssapp/models/specification.py
+++ b/SAS/TMSS/src/tmss/tmssapp/models/specification.py
@@ -10,6 +10,7 @@ from django.db.models.expressions import RawSQL
 from django.db.models.deletion import ProtectedError
 from lofar.sas.tmss.tmss.tmssapp.validation import validate_json_against_schema
 from django.core.exceptions import ValidationError
+from django.urls import reverse as reverse_path
 from rest_framework import status
 import datetime
 
@@ -217,6 +218,55 @@ class Template(NamedCommon):
         # TODO: remove all <class>_unique_name_version UniqueConstraint's from the subclasses and replace by this line below when we start using django 3.0
         # constraints = [UniqueConstraint(fields=['name', 'version'], name='%(class)s_unique_name_version')]
 
+    @staticmethod
+    def update_tmss_common_json_schema_refs(schema, base_url: str=None):
+        '''return the given schema with all $ref fields updated so they point to the given base_url'''
+        if base_url is None:
+            # assume tmms is running locally
+            base_url = 'http://localhost:8000'
+
+        if isinstance(schema, dict):
+            updated_schema = {}
+            for key, value in schema.items():
+                if key == "$ref":
+                    if value.startswith('#'):
+                        # reference to local document, no need for http injection
+                        updated_schema[key] = value
+                    else:
+                        try:
+                            # deduct referred schema name and version from ref-value
+                            head, hash, tail = value.partition('#')
+                            head_parts = head.rstrip('/').split('/')
+                            schema_name = head_parts[-2]
+                            schema_version = head_parts[-1]
+                            tail = hash+tail
+
+                            # construct the common json schema path for this ref
+                            schema_path = reverse_path('get_common_json_schema', kwargs={'name': schema_name, 'version': schema_version})
+
+                            # and construct the proper ref url
+                            updated_schema[key] = base_url + schema_path + tail
+                        except:
+                            # aparently the reference is not conform the expected lofar common json schema path...
+                            # so, just accept the original value and assume that the user uploaded a proper schema
+                            updated_schema[key] = value
+                elif isinstance(value, dict):
+                    updated_schema[key] = Template.update_tmss_common_json_schema_refs(value, base_url)
+                elif isinstance(value, list):
+                    updated_schema[key] = [Template.update_tmss_common_json_schema_refs(item, base_url) for item in value]
+                else:
+                    updated_schema[key] = value
+            return updated_schema
+
+        if isinstance(schema, list):
+            return [Template.update_tmss_common_json_schema_refs(item, base_url) for item in schema]
+
+        return schema
+
+    def save(self, force_insert=False, force_update=False, using=None, update_fields=None):
+        self.schema = Template.update_tmss_common_json_schema_refs(self.schema)
+        super().save(force_insert, force_update, using, update_fields)
+
 
 # concrete models
 
diff --git a/SAS/TMSS/src/tmss/tmssapp/serializers/scheduling.py b/SAS/TMSS/src/tmss/tmssapp/serializers/scheduling.py
index 0f899c61de3..fe883ddde14 100644
--- a/SAS/TMSS/src/tmss/tmssapp/serializers/scheduling.py
+++ b/SAS/TMSS/src/tmss/tmssapp/serializers/scheduling.py
@@ -58,7 +58,7 @@ class DefaultSubtaskTemplateSerializer(RelationalHyperlinkedModelSerializer):
         fields = '__all__'
 
 
-class DataproductSpecificationsTemplateSerializer(RelationalHyperlinkedModelSerializer):
+class DataproductSpecificationsTemplateSerializer(AbstractTemplateSerializer):
     class Meta:
         model = models.DataproductSpecificationsTemplate
         fields = '__all__'
@@ -71,7 +71,7 @@ class DefaultDataproductSpecificationsTemplateSerializer(RelationalHyperlinkedMo
 
 
 
-class DataproductFeedbackTemplateSerializer(RelationalHyperlinkedModelSerializer):
+class DataproductFeedbackTemplateSerializer(AbstractTemplateSerializer):
     class Meta:
         model = models.DataproductFeedbackTemplate
         fields = '__all__'
diff --git a/SAS/TMSS/src/tmss/tmssapp/serializers/specification.py b/SAS/TMSS/src/tmss/tmssapp/serializers/specification.py
index 21d4615e48a..e21ae2ab9bd 100644
--- a/SAS/TMSS/src/tmss/tmssapp/serializers/specification.py
+++ b/SAS/TMSS/src/tmss/tmssapp/serializers/specification.py
@@ -7,7 +7,6 @@ from .. import models
 from .widgets import JSONEditorField
 from django.contrib.auth.models import User
 from django.core.exceptions import ImproperlyConfigured
-from django.urls import reverse as reverse_path
 from rest_framework import decorators
 import json
 
@@ -88,38 +87,13 @@ class TagsSerializer(RelationalHyperlinkedModelSerializer):
         model = models.Tags
         fields = '__all__'
 
-class JSONSchemaField(serializers.JSONField):
-    @staticmethod
-    def fix_refs(schema, base_url):
-        if isinstance(schema, dict):
-            updated_schema = {}
-            for key, value in schema.items():
-                if key == "$ref":
-                    if value.startswith('#') or value.startswith('http'):
-                        updated_schema[key] = value
-                    else:
-                        # inject base_url, so the $ref can be followed using a full url
-                        parts = value.split('/')
-                        schema_path = reverse_path('get_common_json_schema', kwargs={'name': parts[0], 'version': parts[1]})
-                        updated_schema[key] = base_url + schema_path + '/'.join(parts[2:])
-                elif isinstance(value, dict):
-                    updated_schema[key] = JSONSchemaField.fix_refs(value, base_url)
-                elif isinstance(value, list):
-                    updated_schema[key] = [JSONSchemaField.fix_refs(item, base_url) for item in value]
-                else:
-                    updated_schema[key] = value
-            return updated_schema
-
-        if isinstance(schema, list):
-            return [JSONSchemaField.fix_refs(item, base_url) for item in schema]
-
-        return schema
 
+class JSONSchemaField(serializers.JSONField):
     def to_representation(self, value):
-        base_url = "%s://%s" % (self.context['request'].scheme,
-                                self.context['request'].get_host())
+        '''make sure the common json schema $ref fields point to the correct host'''
+        base_url = "%s://%s" % (self.context['request'].scheme, self.context['request'].get_host())
+        return models.Template.update_tmss_common_json_schema_refs(value, base_url)
 
-        return JSONSchemaField.fix_refs(value, base_url)
 
 class AbstractTemplateSerializer(RelationalHyperlinkedModelSerializer):
     schema = JSONSchemaField()
@@ -134,7 +108,7 @@ class CommonSchemaTemplateSerializer(AbstractTemplateSerializer):
         fields = '__all__'
 
 
-class GeneratorTemplateSerializer(RelationalHyperlinkedModelSerializer):
+class GeneratorTemplateSerializer(AbstractTemplateSerializer):
     class Meta:
         model = models.GeneratorTemplate
         fields = '__all__'
@@ -146,7 +120,7 @@ class DefaultGeneratorTemplateSerializer(RelationalHyperlinkedModelSerializer):
         fields = '__all__'
 
 
-class SchedulingUnitTemplateSerializer(RelationalHyperlinkedModelSerializer):
+class SchedulingUnitTemplateSerializer(AbstractTemplateSerializer):
     class Meta:
         model = models.SchedulingUnitTemplate
         fields = '__all__'
@@ -158,7 +132,7 @@ class DefaultSchedulingUnitTemplateSerializer(RelationalHyperlinkedModelSerializ
         fields = '__all__'
 
 
-class TaskTemplateSerializer(RelationalHyperlinkedModelSerializer):
+class TaskTemplateSerializer(AbstractTemplateSerializer):
     class Meta:
         model = models.TaskTemplate
         fields = '__all__'
@@ -170,7 +144,7 @@ class DefaultTaskTemplateSerializer(RelationalHyperlinkedModelSerializer):
         fields = '__all__'
 
 
-class TaskRelationSelectionTemplateSerializer(RelationalHyperlinkedModelSerializer):
+class TaskRelationSelectionTemplateSerializer(AbstractTemplateSerializer):
     class Meta:
         model = models.TaskRelationSelectionTemplate
         fields = '__all__'
-- 
GitLab