Skip to content
Snippets Groups Projects
Code owners
Assign users and groups as approvers for specific file changes. Learn more.
tmss_test_environment_unittest_setup.py 11.20 KiB
#!/usr/bin/env python3

# Copyright (C) 2018    ASTRON (Netherlands Institute for Radio Astronomy)
# P.O. Box 2, 7990 AA Dwingeloo, The Netherlands
#
# This file is part of the LOFAR software suite.
# The LOFAR software suite is free software: you can redistribute it and/or
# modify it under the terms of the GNU General Public License as published
# by the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# The LOFAR software suite is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.    See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License along
# with the LOFAR software suite. If not, see <http://www.gnu.org/licenses/>.

'''
By importing this helper module in your unittest module you get a running TMSSTestEnvironment
(with a fresh test database and attached django and ldap servers), which is stopped again
by tearDownModule once the unittests are done.
'''

import logging
logger = logging.getLogger(__name__)

# before we import any django modules the DJANGO_SETTINGS_MODULE, TMSS_LDAPCREDENTIALS and TMSS_DBCREDENTIALS need to be known/set.
# import and start an isolated TMSSTestEnvironment (with fresh database and attached django and ldap server on free ports)
# this automagically sets the required  DJANGO_SETTINGS_MODULE, TMSS_LDAPCREDENTIALS and TMSS_DBCREDENTIALS envvars.
from lofar.sas.tmss.test.test_environment import TMSSTestEnvironment
import sys

# create and start the isolated test environment at import time, so it is available to all tests in the importing module.
tmss_test_env = TMSSTestEnvironment()
try:
    tmss_test_env.start()
except Exception as e:
    # log the cause, stop the environment, and abort the test run with a non-zero exit code.
    logger.exception(str(e))
    tmss_test_env.stop()
    # use sys.exit instead of the interactive-shell-only builtin 'exit' (which the site module may not provide)
    sys.exit(1)

# tell unittest to stop (and automagically cleanup) the test database once all testing is done.
def tearDownModule():
    """unittest module-level hook: stop the TMSSTestEnvironment which was started when this module was imported."""
    environment = tmss_test_env
    environment.stop()


################################################################################################
# the methods below can be used to do HTTP REST calls to the django server and check the results
################################################################################################

import json
import requests
import datetime
def _without_trailing_slash(url):
    # drop exactly one trailing '/' (if present); any other url is returned unchanged
    if url.endswith('/'):
        return url[:-1]
    return url

_ldap_dbcreds = tmss_test_env.ldap_server.dbcreds
AUTH = requests.auth.HTTPBasicAuth(_ldap_dbcreds.user, _ldap_dbcreds.password)
BASE_URL = _without_trailing_slash(tmss_test_env.django_server.url)
OIDC_URL = _without_trailing_slash(tmss_test_env.django_server.oidc_url)
from lofar.sas.tmss.test.test_utils import assertDataWithUrls
import lofar.sas.tmss.tmss.settings as TMSS_SETTINGS

# default maximum duration (in seconds) of a request; requests taking longer than this fail the check
DEFAULT_REQUEST_TIMEOUT = 10

def _call_API_and_assert_expected_response(test_instance, url, call, data, expected_code, expected_content, auth=AUTH, timeout:float=DEFAULT_REQUEST_TIMEOUT):
    """
    Call API method on the provided url and assert the expected code is returned and the expected content is in the response content.

    :param test_instance: the unittest.TestCase used for the assertions (its id() is used in log messages)
    :param url: the url to call
    :param call: the HTTP method, one of 'PUT', 'POST', 'GET', 'PATCH' or 'DELETE'
    :param data: the json payload for PUT, POST and PATCH. It is not sent for GET and DELETE.
    :param expected_code: the HTTP status code the response must have
    :param expected_content: dict of key/value pairs which must be present in a 2xx (json) response, or None to skip this check.
                             Django model instances are checked by their (url-encoded) primary key, lists are compared
                             ignoring order, datetimes via isoformat, and dicts only on their expected sub-keys
                             (ignoring '$schema' and '$id').
    :param auth: the authentication object passed to requests
    :param timeout: the maximum allowed duration of the request in seconds
    :return: response as dict. This either contains the data of an entry or error details. If JSON cannot be parsed, return string.
    :raises ValueError: if call is not one of the supported HTTP methods
    :raises TimeoutError: if the request took longer than timeout seconds
    """
    import time

    # measure the duration with a monotonic clock: wall-clock time can jump when the system clock is adjusted.
    _start_request_timestamp = time.monotonic()
    if call == 'PUT':
        response = requests.put(url, json=data, auth=auth, timeout=timeout)
    elif call == 'POST':
        response = requests.post(url, json=data, auth=auth, timeout=timeout)
    elif call == 'GET':
        # redirects are not followed, so the status code of the requested url itself is checked.
        response = requests.get(url, auth=auth, allow_redirects=False, timeout=timeout)
    elif call == 'PATCH':
        response = requests.patch(url, json=data, auth=auth, timeout=timeout)
    elif call == 'DELETE':
        response = requests.delete(url, auth=auth, timeout=timeout)
    else:
        raise ValueError("The provided call '%s' is not a valid API method choice" % call)

    _request_duration = time.monotonic() - _start_request_timestamp

    # the requests timeout applies to the connect and read phases separately, not to the total request,
    # so the total duration is checked explicitly as well.
    if _request_duration > timeout:
        raise TimeoutError("request to '%s' did not respond within the allowed %.3f[s]. Actual: %.3f[s]" % (url, timeout, _request_duration))

    content = response.content.decode('utf-8')

    if response.status_code != expected_code:
        logger.error("!!! Unexpected: [%s] - %s %s: %s", test_instance.id(), call, url, content.strip())
    test_instance.assertEqual(response.status_code, expected_code)

    from django.db import models

    if response.status_code in range(200, 300) and expected_content is not None:
        r_dict = json.loads(content)
        for key, value in expected_content.items():
            if key not in r_dict.keys():
                logger.error('!!! Missing key: %s in %s', key, r_dict.keys())
            test_instance.assertIn(key, r_dict.keys())
            if isinstance(value, models.Model):
                # presumably the response refers to related objects by url: check that the (url-encoded) primary key is part of it
                value = str(value.pk).replace(' ', '%20')
                if value not in r_dict[key]:
                    logger.error('!!! Unexpected value of key=%s: expected=%s got=%s', key, value, r_dict[key])
                test_instance.assertIn(value, r_dict[key])
            elif type(value) is list:
                # compare lists independent of ordering
                test_instance.assertEqual(sorted(value), sorted(r_dict[key]), msg="lists differ for key=%s"%key)
            elif isinstance(value, datetime.datetime):
                # the response value (r_dict[key]) is a string but the expected value is a datetime, so compare via isoformat
                test_instance.assertEqual(value.isoformat(), r_dict[key])
            elif isinstance(value, dict):
                # only look for expected (sub)keys. More key/value pairs in the response dict are allowed.
                for sub_key, sub_value in value.items():
                    if sub_key not in ('$schema', '$id'):
                        # assert presence first, so a missing sub_key is a test failure instead of a KeyError
                        test_instance.assertIn(sub_key, r_dict[key])
                        test_instance.assertEqual(sub_value, r_dict[key][sub_key])
            else:
                test_instance.assertEqual(value, r_dict[key])
        return r_dict

    try:
        return json.loads(content)
    except ValueError:
        # json.JSONDecodeError is a subclass of ValueError: the content is not json, so return it as plain text
        return content


def PUT_and_assert_expected_response(test_instance, url, data, expected_code, expected_content, auth=AUTH, timeout:float=DEFAULT_REQUEST_TIMEOUT):
    """
    PUT data (as json) on url and assert that the response has the expected status code and,
    for a 2xx response, that the expected content is in the response content.
    :return: the response content as returned by _call_API_and_assert_expected_response
    """
    return _call_API_and_assert_expected_response(test_instance, url, call='PUT', data=data,
                                                  expected_code=expected_code, expected_content=expected_content,
                                                  auth=auth, timeout=timeout)


def POST_and_assert_expected_response(test_instance, url, data, expected_code, expected_content, auth=AUTH, timeout:float=DEFAULT_REQUEST_TIMEOUT):
    """
    POST data (as json) on url and assert that the response has the expected status code and,
    for a 2xx response, that the expected content is in the response content.
    :return: the response content as returned by _call_API_and_assert_expected_response
    """
    return _call_API_and_assert_expected_response(test_instance, url, call='POST', data=data,
                                                  expected_code=expected_code, expected_content=expected_content,
                                                  auth=auth, timeout=timeout)


def GET_and_assert_equal_expected_code(test_instance, url, expected_code, auth=AUTH, timeout:float=DEFAULT_REQUEST_TIMEOUT):
    """
    GET from url and assert that the response has the expected status code. The response content itself is not checked.
    :return: the response content as returned by _call_API_and_assert_expected_response
    """
    return _call_API_and_assert_expected_response(test_instance, url, call='GET', data={},
                                                  expected_code=expected_code, expected_content=None,
                                                  auth=auth, timeout=timeout)


def GET_and_assert_in_expected_response_result_list(test_instance, url, expected_content, expected_nbr_results, expected_id=None, auth=AUTH, timeout:float=DEFAULT_REQUEST_TIMEOUT):
    """
    GET from url and assert the expected code (200) is returned and the expected content is in the response content.
    Use this check when multiple results (a list response with 'count', 'next' and 'results') are returned.

    :param expected_content: dict of expected key/values. Every result must contain all of its keys, and (when all
                             results fit on one page) one selected result is compared against it with assertDataWithUrls.
    :param expected_nbr_results: the total number of results the server should report
    :param expected_id: optional 'id' of the result to compare with expected_content. If None, the last result is used.
    :return: the response as dict
    """
    r_dict = _call_API_and_assert_expected_response(test_instance, url, 'GET', {}, 200, None, auth=auth, timeout=timeout)
    results = r_dict["results"]
    page_size = TMSS_SETTINGS.REST_FRAMEWORK.get('PAGE_SIZE')
    if page_size is not None and expected_nbr_results > page_size:
        # only one page of results is returned, so no individual result is compared below
        logger.warning("Limited result length due to pagination setting (%d)", page_size)
        test_instance.assertEqual(page_size, len(results))
        test_instance.assertEqual(expected_nbr_results, r_dict["count"])
        test_instance.assertNotEqual(None, r_dict['next'])
        url_check = False
    else:
        test_instance.assertEqual(expected_nbr_results, len(results))
        test_instance.assertEqual(r_dict["count"], len(results))
        test_instance.assertEqual(None, r_dict['next'])
        url_check = True

    for item in results:
        for key in expected_content.keys():
            test_instance.assertIn(key, item.keys())

    # with an empty result list there is nothing to compare
    if url_check and results:
        if expected_id is not None:
            # find the result with the expected id (the old assumption that it is the last one is not reliable)
            expected_idx = next((idx for idx, item in enumerate(results) if item['id'] == expected_id), None)
            test_instance.assertIsNotNone(expected_idx, msg="expected id=%s not found in results" % (expected_id,))
        else:
            # this is the 'old' assumption that the last object added will also be the last one in the results
            expected_idx = len(results) - 1
        assertDataWithUrls(test_instance, results[expected_idx], expected_content)
    return r_dict


def GET_OK_and_assert_equal_expected_response(test_instance, url, expected_content, auth=AUTH, timeout:float=DEFAULT_REQUEST_TIMEOUT):
    """
    GET from url and assert that the response has status code 200 and that the expected content is in the response content.
    The content comparison itself is done by _call_API_and_assert_expected_response.
    :return: the response content as returned by _call_API_and_assert_expected_response
    """
    return _call_API_and_assert_expected_response(test_instance, url, call='GET', data={},
                                                  expected_code=200, expected_content=expected_content,
                                                  auth=auth, timeout=timeout)


def PATCH_and_assert_expected_response(test_instance, url, data, expected_code, expected_content, auth=AUTH, timeout:float=DEFAULT_REQUEST_TIMEOUT):
    """
    PATCH data (as json) on url and assert that the response has the expected status code and,
    for a 2xx response, that the provided values have changed according to the server response.
    :return: the response content as returned by _call_API_and_assert_expected_response
    """
    return _call_API_and_assert_expected_response(test_instance, url, call='PATCH', data=data,
                                                  expected_code=expected_code, expected_content=expected_content,
                                                  auth=auth, timeout=timeout)


def DELETE_and_assert_gone(test_instance, url, auth=AUTH, timeout:float=DEFAULT_REQUEST_TIMEOUT):
    """
    DELETE the item at the provided url and assert that it is gone.

    Asserts that the DELETE request returns 204 and that a subsequent GET on the same url returns 404.
    An unexpected status code is logged (together with the response content) before the assertion fails.
    :return: None
    """
    # (request function, method name for logging, required status code), checked in this order
    steps = ((requests.delete, 'DELETE', 204),
             (requests.get, 'GET', 404))
    for send, method_name, wanted_status in steps:
        response = send(url, auth=auth, timeout=timeout)
        if response.status_code != wanted_status:
            logger.error("!!! Unexpected: [%s] - %s %s: %s", test_instance.id(), method_name, url, response.content)
        test_instance.assertEqual(response.status_code, wanted_status)