# Copyright (C) 2012-2015  ASTRON (Netherlands Institute for Radio Astronomy)
# P.O. Box 2, 7990 AA Dwingeloo, The Netherlands
#
# This file is part of the LOFAR software suite.
# The LOFAR software suite is free software: you can redistribute it and/or
# modify it under the terms of the GNU General Public License as published
# by the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# The LOFAR software suite is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License along
# with the LOFAR software suite. If not, see <http://www.gnu.org/licenses/>.

import json
import jsonschema
from copy import deepcopy
import requests
from datetime import datetime, timedelta

DEFAULT_MAX_SCHEMA_CACHE_AGE = timedelta(minutes=1)

def _extend_with_default(validator_class):
    """
    Extend the properties validation so that it adds missing properties with their default values (where one is defined
    in the schema).
    traverse down and add enclosed properties.
    see: <https://python-jsonschema.readthedocs.io/en/stable/faq/#why-doesn-t-my-schema-s-default-property-set-the-default-on-my-instance>
    """
    validate_properties = validator_class.VALIDATORS["properties"]

    def set_defaults(validator, properties, instance, schema):
        for property, subschema in properties.items():
            if "default" in subschema:
                instance.setdefault(property, subschema["default"])
            elif "type" not in subschema:
                # could be anything, probably a $ref.
                pass
            elif subschema["type"] == "object":
                # giving objects the {} default causes that default to be populated by the properties of the object
                instance.setdefault(property, {})
            elif subschema["type"] == "array":
                # giving arrays the [] default causes that default to be populated by the items of the array
                instance.setdefault(property, [])

        for error in validate_properties(
            validator, properties, instance, schema,
        ):
            yield error

    return jsonschema.validators.extend(
        validator_class, {"properties" : set_defaults},
    )


def _extend_with_required(validator_class):
    """
    Extend the required properties validation so that it adds missing required properties with their default values,
    (where one is defined in the schema).
    (Note: the check for required properties happens before property validation, so this is required even though the
           override in _extend_with_default would as well add the property.)
    see: <https://python-jsonschema.readthedocs.io/en/stable/faq/#why-doesn-t-my-schema-s-default-property-set-the-default-on-my-instance>
    """
    validate_required = validator_class.VALIDATORS["required"]

    def set_required_properties(validator, properties, instance, schema):
        for property in properties:
            subschema = schema['properties'].get(property, {})
            if "default" in subschema:
                instance.setdefault(property,  subschema["default"])
        for error in validate_required(
            validator, properties, instance, schema,
        ):
            yield error

    return jsonschema.validators.extend(
        validator_class, {"required" : set_required_properties},
    )

# define a custom validator that fills in missing properties with their schema defaults during validation
_DefaultValidatingDraft6Validator = _extend_with_default(jsonschema.Draft6Validator)
_DefaultValidatingDraft6Validator = _extend_with_required(_DefaultValidatingDraft6Validator)
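
# Example (illustrative sketch, not part of the module API; the schema and property names below are hypothetical):
#   schema = {"type": "object",
#             "properties": {"duration": {"type": "number", "default": 600},
#                            "pointing": {"type": "object",
#                                         "properties": {"angle1": {"type": "number", "default": 0.0}}}}}
#   doc = {}
#   _DefaultValidatingDraft6Validator(schema).validate(doc)
#   # doc has been filled in-place with the defaults: {'duration': 600, 'pointing': {'angle1': 0.0}}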

# caches of validators keyed by schema '$id', so that validator construction and ref resolving happen only once per schema.
_schema_validators = {}
_schema_defaults_adding_validators = {}

def get_validator_for_schema(schema: dict, add_defaults: bool=False):
    '''get a json validator for the given schema.
    If the schema is already known in the cache by its $id, then the cached validator is returned.
    This saves many lookups and much ref resolving.
    The 'add_defaults' parameter indicates whether the validator should add default values while validating.'''
    if isinstance(schema, str):
        schema = json.loads(schema)

    validators_cache = _schema_defaults_adding_validators if add_defaults else _schema_validators

    if '$id' in schema:
        if schema['$id'] not in validators_cache:
            validators_cache[schema['$id']] = _DefaultValidatingDraft6Validator(schema) if add_defaults else jsonschema.Draft6Validator(schema=schema)
        validator = validators_cache[schema['$id']]
    else:
        validator = _DefaultValidatingDraft6Validator(schema) if add_defaults else jsonschema.Draft6Validator(schema=schema)

    validator.schema = schema
    return validator
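
# Example (sketch; the '$id' below is hypothetical):
#   schema = {"$id": "http://example.org/schemas/task.json", "type": "object"}
#   v1 = get_validator_for_schema(schema)
#   v2 = get_validator_for_schema(schema)
#   assert v1 is v2                 # same cached validator instance, keyed by the schema's '$id'
#   v1.validate({"foo": "bar"})     # raises jsonschema.exceptions.ValidationError on invalid documents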

def get_default_json_object_for_schema(schema: dict) -> dict:
    '''return a valid json object for the given schema with all properties set to their default values'''
    return add_defaults_to_json_object_for_schema({}, schema)

def add_defaults_to_json_object_for_schema(json_object: dict, schema: dict, cache: dict=None, max_cache_age: timedelta=DEFAULT_MAX_SCHEMA_CACHE_AGE) -> dict:
    '''return a copy of the given json object with defaults filled in according to the schema for all missing properties'''
    copy_of_json_object = deepcopy(json_object)

    # add a $schema to the json doc if needed
    if '$schema' not in copy_of_json_object and '$id' in schema:
        copy_of_json_object['$schema'] = schema['$id']

    # resolve $refs to fill in defaults for those, too
    schema = resolved_refs(schema, cache=cache, max_cache_age=max_cache_age)

    # run validator, which populates the properties with defaults.
    get_validator_for_schema(schema, add_defaults=True).validate(copy_of_json_object)
    return copy_of_json_object
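
# Example (sketch; hypothetical, $ref-free schema, so no HTTP requests are needed to resolve refs):
#   schema = {"$id": "http://example.org/schemas/observation.json",
#             "type": "object",
#             "properties": {"duration": {"type": "number", "default": 600}}}
#   add_defaults_to_json_object_for_schema({}, schema)
#   # -> {'$schema': 'http://example.org/schemas/observation.json', 'duration': 600}
#   add_defaults_to_json_object_for_schema({"duration": 300}, schema)
#   # -> {'duration': 300, '$schema': 'http://example.org/schemas/observation.json'}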

def replace_host_in_urls(schema, new_base_url: str, keys=['$id', '$ref', '$schema']):
    '''return the given schema with the URL values of the given keys rewritten so that they point to the given new_base_url (keeping the original path and fragment)'''
    if isinstance(schema, dict):
        updated_schema = {}
        for key, value in schema.items():
            if key in keys:
                if isinstance(value,str) and (value.startswith('http://') or value.startswith('https://')) and 'json-schema.org' not in value:
                    try:
                        # deconstruct the old url into its parts
                        head, anchor, tail = value.partition('#')
                        # strip the scheme and host, keeping only the path (note: lstrip('http://') would strip a character set, not a prefix)
                        host, slash, path = head.split('://', 1)[-1].partition('/')

                        # and reconstruct the proper new url
                        updated_schema[key] = (new_base_url.rstrip('/') + '/' + path + anchor + tail.rstrip('/')).replace(' ', '%20')
                    except Exception:
                        # just accept the original value and assume that the user uploaded a proper schema
                        updated_schema[key] = value
                else:
                    updated_schema[key] = value
            else:
                updated_schema[key] = replace_host_in_urls(value, new_base_url, keys)
        return updated_schema

    if isinstance(schema, list):
        return [replace_host_in_urls(item, new_base_url, keys) for item in schema]

    return schema
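
# Example (sketch; hypothetical URLs):
#   schema = {"$id": "https://old.host.example.org/schemas/task.json#version=1",
#             "properties": {"pointing": {"$ref": "https://old.host.example.org/schemas/common.json#/definitions/pointing"}}}
#   replace_host_in_urls(schema, "http://localhost:8000")
#   # -> {'$id': 'http://localhost:8000/schemas/task.json#version=1',
#   #     'properties': {'pointing': {'$ref': 'http://localhost:8000/schemas/common.json#/definitions/pointing'}}}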

def get_referenced_subschema(ref_url, cache: dict=None, max_cache_age: timedelta=DEFAULT_MAX_SCHEMA_CACHE_AGE):
    '''fetch the schema given by the ref_url, and return the sub-schema that the #/ fragment in the ref_url points to'''
    # split the ref_url into the schema url and the json-pointer fragment
    head, anchor, tail = ref_url.partition('#')
    if isinstance(cache, dict) and head in cache:
        # use cached value
        referenced_schema, last_update_timestamp = cache[head]

        # refresh cache if outdated
        if datetime.utcnow() - last_update_timestamp > max_cache_age:
            referenced_schema = json.loads(requests.get(ref_url).text)
            cache[head] = referenced_schema, datetime.utcnow()
    else:
        # fetch url, and store in cache
        referenced_schema = json.loads(requests.get(ref_url).text)
        if isinstance(cache, dict):
            cache[head] = referenced_schema, datetime.utcnow()

    # extract sub-schema
    tail = tail.strip('/')
    if tail:
        parts = tail.split('/')
        for part in parts:
            referenced_schema = referenced_schema[part]

    return referenced_schema
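
# Example (sketch; hypothetical URL; performs an HTTP GET when the schema is not in the cache or the cached copy is too old):
#   cache = {}
#   pointing = get_referenced_subschema("http://example.org/schemas/common.json#/definitions/pointing", cache=cache)
#   # fetches common.json, stores it in the cache keyed by its URL (without the fragment),
#   # and returns common.json["definitions"]["pointing"]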


def resolved_refs(schema, cache: dict=None, max_cache_age: timedelta=DEFAULT_MAX_SCHEMA_CACHE_AGE):
    '''return the given schema with all $ref fields replaced by the referred json (sub)schema that they point to.'''
    if cache is None:
        cache = {}

    if isinstance(schema, dict):
        updated_schema = {}
        keys = list(schema.keys())
        if "$ref" in keys and isinstance(schema['$ref'], str) and schema['$ref'].startswith('http'):
            keys.remove("$ref")
            referenced_subschema = get_referenced_subschema(schema['$ref'], cache=cache, max_cache_age=max_cache_age)
            updated_schema = resolved_refs(referenced_subschema, cache, max_cache_age=max_cache_age)

        for key in keys:
            updated_schema[key] = resolved_refs(schema[key], cache, max_cache_age=max_cache_age)
        return updated_schema

    if isinstance(schema, list):
        return [resolved_refs(item, cache, max_cache_age=max_cache_age) for item in schema]

    return schema
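
# Example (sketch; hypothetical URL; sibling keys of a '$ref' are merged on top of the resolved sub-schema):
#   schema = {"type": "object",
#             "properties": {"pointing": {"$ref": "http://example.org/schemas/common.json#/definitions/pointing",
#                                         "default": {}}}}
#   resolved = resolved_refs(schema)
#   # resolved['properties']['pointing'] is now the fetched 'pointing' definition itself, with the sibling
#   # 'default' key added to it; only $refs starting with 'http' are resolved, local '#/...' refs are left as-is.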

def get_refs(schema) -> set:
    '''return a set of all $refs in the schema'''
    refs = set()
    if isinstance(schema, dict):
        for key, value in schema.items():
            if key == "$ref":
                refs.add(value)
            else:
                refs.update(get_refs(value))

    if isinstance(schema, list):
        for item in schema:
            refs.update(get_refs(item))

    return refs
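
# Example (sketch; hypothetical refs):
#   schema = {"properties": {"a": {"$ref": "http://example.org/schemas/common.json#/definitions/a"},
#                            "b": {"items": {"$ref": "#/definitions/b"}}}}
#   get_refs(schema)
#   # -> {'http://example.org/schemas/common.json#/definitions/a', '#/definitions/b'}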


def validate_json_against_its_schema(json_object: dict):
    '''validate the given json object against its own schema (the URI/URL that its property $schema points to)'''
    schema_url = json_object['$schema']
    response = requests.get(schema_url, headers={"Accept":"application/json"})
    if response.status_code == 200:
        return validate_json_against_schema(json_object, response.text)
    raise jsonschema.exceptions.ValidationError("Could not get schema from '%s'\n%s" % (schema_url, str(response.text)))
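
# Example (sketch; hypothetical URL; fetches the schema that the document's own '$schema' points to):
#   doc = {"$schema": "http://example.org/schemas/task.json", "duration": 600}
#   validate_json_against_its_schema(doc)
#   # raises jsonschema.exceptions.ValidationError if the schema cannot be fetched,
#   # or if doc does not validate against it; passes silently otherwise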

def validate_json_against_schema(json_string: str, schema: str):
    '''validate the given json_string against the given schema.
       If no exception is thrown, then the given json_string validates against the given schema.
       :raises jsonschema.exceptions.ValidationError: if the json_string does not validate against the schema
    '''

    # ensure the given arguments are strings
    if not isinstance(json_string, str):
        json_string = json.dumps(json_string)
    if not isinstance(schema, str):
        schema = json.dumps(schema)

    # ensure the specification and schema are both valid json in the first place
    try:
        json_object = json.loads(json_string)
    except json.decoder.JSONDecodeError as e:
        raise jsonschema.exceptions.ValidationError("Invalid JSON: %s\n%s" % (str(e), json_string))

    try:
        schema_object = json.loads(schema)
    except json.decoder.JSONDecodeError as e:
        raise jsonschema.exceptions.ValidationError("Invalid JSON: %s\n%s" % (str(e), schema))

    # now do the actual validation
    try:
        validate_json_object_with_schema(json_object, schema_object)
    except jsonschema.ValidationError as e:
        raise jsonschema.exceptions.ValidationError(str(e))
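
# Example (sketch; a $ref-free schema given as a JSON string, so no HTTP requests are involved):
#   schema = '{"type": "object", "properties": {"duration": {"type": "number"}}}'
#   validate_json_against_schema('{"duration": 600}', schema)    # passes silently
#   validate_json_against_schema('{"duration": "ten"}', schema)  # raises jsonschema.exceptions.ValidationError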


def validate_json_object_with_schema(json_object, schema):
    """
    Validate the given json_object against the given schema.
    :raises jsonschema.exceptions.ValidationError: if the json_object does not validate against the schema
    """
    get_validator_for_schema(schema, add_defaults=False).validate(json_object)