#!/usr/bin/python

# Copyright (C) 2012-2015    ASTRON (Netherlands Institute for Radio Astronomy)
# P.O. Box 2, 7990 AA Dwingeloo, The Netherlands
#
# This file is part of the LOFAR software suite.
# The LOFAR software suite is free software: you can redistribute it and/or
# modify it under the terms of the GNU General Public License as published
# by the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# The LOFAR software suite is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.    See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License along
# with the LOFAR software suite. If not, see <http://www.gnu.org/licenses/>.

# $Id:  $
import unittest
import testing.postgresql
import psycopg2
import os
from datetime import datetime, timedelta
from dateutil import parser
import logging
import logging
import pprint

logger = logging.getLogger(__name__)

# Optional test-only dependencies. If one is missing, report it and exit with the special
# LOFAR test exit code 3 ('skipped test') instead of failing the test run with an ImportError.
# Note: mock is currently only referenced in commented-out code in the notification test.
try:
    import mock
except ImportError as e:
    print str(e)
    print 'Please install python package mock: sudo pip install mock'
    exit(3)  # special lofar test exit code: skipped test

# NOTE(review): testing.postgresql is also imported unconditionally at the top of this file,
# so when the package is missing the ImportError is raised there, before this guard is reached.
try:
    import testing.postgresql
except ImportError as e:
    print str(e)
    print 'Please install python package testing.postgresql: sudo pip install testing.postgresql'
    exit(3)  # special lofar test exit code: skipped test

from lofar.common.dbcredentials import Credentials
from lofar.sas.resourceassignment.database.radb import RADatabase
from lofar.common.postgres import PostgresListener


class ResourceAssignmentDatabaseTest(unittest.TestCase):
    """Integration tests for the resource assignment database (radb).

    Every test gets a fresh, throw-away PostgreSQL instance (testing.postgresql) into which the
    radb sql scripts listed in ``sql_createdb_files`` are loaded. The tests exercise the database
    directly (tables, triggers, notifications) and through the RADatabase python module.
    """
    # todo: test shared db to improve test speed
    # share the generated database for faster tests
    # Postgresql = testing.postgresql.PostgresqlFactory(cache_initialized_db=True)

    # These are applied in given order to set up test db
    # Note: cannot use create_and_populate_database.sql since '\i' is not understood by cursor.execute()
    sql_basepath = os.environ['LOFARROOT'] + "/share/radb/sql/"
    sql_createdb_files = ["create_database.sql",
                          "add_resource_allocation_statics.sql",
                          "add_virtual_instrument.sql",
                          "add_notifications.sql",
                          "add_functions_and_triggers.sql"]

    def setUp(self):
        """Start a fresh postgres instance, load the radb schema and connect to it.

        Sets ``self.postgresql``, ``self.connection`` (psycopg2 connection as the
        'resourceassignment' user), ``self.listener`` (PostgresListener) and ``self.radb``
        (RADatabase). unittest does not call tearDown() when setUp() raises, so on any failure
        after the postgres instance was started it is stopped here before the error propagates.
        """
        logger.info('setting up test RA database...')

        self.postgresql = testing.postgresql.Postgresql()
        try:
            self._setup_database()
        except Exception:
            # don't leak the connection and the postgres server process when setup fails
            connection = getattr(self, 'connection', None)
            if connection is not None:
                connection.close()
            self.postgresql.stop()
            raise

    def _setup_database(self):
        """Create the test user, load all sql scripts and create the connections used by the tests."""
        dsn = self.postgresql.dsn()
        self.connection = psycopg2.connect(**dsn)

        database_credentials = Credentials()
        database_credentials.host = dsn['host']
        database_credentials.database = dsn['database']
        database_credentials.port = dsn['port']
        database_credentials.user = 'resourceassignment'
        database_credentials.password = 'blabla'  # cannot be empty...

        # Note: NOSUPERUSER currently raises "permission denied for schema virtual_instrument"
        # Maybe we want to sort out user creation and proper permissions in the sql scripts?
        # (user name and password are the fixed test values above, so plain formatting is safe here)
        query = "CREATE USER %s WITH SUPERUSER PASSWORD '%s'" % (
                database_credentials.user,
                database_credentials.password)
        self._execute_query(query)

        for sql_file in self.sql_createdb_files:
            sql_path = os.path.join(self.sql_basepath, sql_file)
            with open(sql_path) as sql:  # cursor.execute() does not accept '\i'
                self._execute_query(sql.read())

        logger.info('...finished setting up test RA database')

        # reconnect as the 'resourceassignment' user for the tests
        self.connection.close()
        self.connection = psycopg2.connect(host=database_credentials.host,
                                           user=database_credentials.user,
                                           password=database_credentials.password,
                                           dbname=database_credentials.database,
                                           port=database_credentials.port)

        # set up PostgresListener for notifications:
        self.listener = PostgresListener(host=database_credentials.host,
                                         username=database_credentials.user,
                                         password=database_credentials.password,
                                         database=database_credentials.database,
                                         port=database_credentials.port)

        # set up radb python module
        self.radb = RADatabase(database_credentials, log_queries=True)

    def tearDown(self):
        """Close the test connection and stop the throw-away postgres instance."""
        logger.info('removing test RA database...')
        try:
            self.connection.close()
        finally:
            # always stop the postgres server, even if closing the connection failed
            self.postgresql.stop()
        logger.info('removed test RA database')
        # self.Postgresql.clear_cache() # todo: use this when using shared db instead of stop(), or remove.

    def _execute_query(self, query, fetch=False, params=None):
        """Execute a single query on self.connection and commit it.

        :param query: sql string. Without params it is sent as-is (a literal '%' is fine);
                      with params use %s placeholders.
        :param fetch: if True, return all result rows (cursor.fetchall()), else return None.
        :param params: optional sequence/mapping of values that psycopg2 quotes into the query.
        :raises psycopg2.Error: if the query fails; the transaction is rolled back first so the
                                connection stays usable.
        """
        cursor = self.connection.cursor()
        try:
            cursor.execute(query, params)
            result = cursor.fetchall() if fetch else None
            self.connection.commit()
            return result
        except psycopg2.Error:
            # leave the connection out of the aborted-transaction state
            self.connection.rollback()
            raise
        finally:
            cursor.close()

    # --- tests start here

    #
    # integrity tests of postgres database itself
    #
    # Note: These are meant to make sure the setup generally works and all sql scripts were applied.
    # I don't see much benefit in full coverage here since it should all be tested through RADatabase functionality.
    # Of course new tests can be added here where db functionality like triggers should be tested separately from the
    # Python part of the job.

    # database created?
    def test_select_tables_contains_tables_for_each_schema(self):
        fetch = self._execute_query("SELECT table_schema,table_name FROM information_schema.tables", fetch=True)
        for schema in ('resource_allocation', 'resource_monitoring', 'virtual_instrument'):
            self.assertIn(schema, str(fetch))

    # resource allocation_statics there?
    def test_select_task_types_contains_obervation(self):
        fetch = self._execute_query("SELECT * FROM resource_allocation.task_type", fetch=True)
        self.assertIn('observation', str(fetch))

    # virtual instrument there?
    def test_select_virtualinstrument_units_contain_rcuboard(self):
        fetch = self._execute_query("SELECT * FROM virtual_instrument.unit", fetch=True)
        self.assertIn('rcu_board', str(fetch))

    def _insert_test_spec(self,
                          starttime='2017-05-10 13:00:00',
                          endtime='2017-05-10 14:00:00',
                          content='testcontent',
                          cluster='CEP4'):
        """Insert a row into resource_allocation.specification and return its new id."""
        query = "INSERT INTO resource_allocation.specification (starttime, endtime, content, cluster) " \
                "VALUES (%s, %s, %s, %s) RETURNING id"
        res = self._execute_query(query, fetch=True, params=(starttime, endtime, content, cluster))
        return res[0][0]

    def _select_spec_content(self, ident):
        """Return the rows of (content,) of the specification with the given id (empty if absent)."""
        query = "SELECT content FROM resource_allocation.specification WHERE id=%s"
        return self._execute_query(query, fetch=True, params=(ident,))

    def test_insert_specification_creates_new_entry(self):
        # insert spec
        content = 'testcontent'
        ident = self._insert_test_spec(content=content)

        # check it is there
        self.assertIn(content, str(self._select_spec_content(ident)))

    def test_update_specification_changes_entry(self):
        # insert spec
        ident = self._insert_test_spec()

        # update existing spec content (only this spec)
        newcontent = 'testcontent_new'
        query = "UPDATE resource_allocation.specification SET content = %s WHERE id = %s"
        self._execute_query(query, params=(newcontent, ident))

        # check updated content
        self.assertIn(newcontent, str(self._select_spec_content(ident)))

    def test_delete_specification(self):
        # insert spec
        content = 'deletecontent'
        ident = self._insert_test_spec(content=content)

        # make sure it's there
        self.assertIn(content, str(self._select_spec_content(ident)))

        # delete testspec again
        self._execute_query("DELETE FROM resource_allocation.specification WHERE id = %s", params=(ident,))

        # make sure it's gone
        self.assertNotIn(content, str(self._select_spec_content(ident)))

    # triggers in place?
    def test_insert_specification_swaps_startendtimes_if_needed(self):
        """Inserting a specification with starttime > endtime must be refused by the database.

        NOTE: despite the test name, the times are not expected to be swapped: the insert is
        expected to raise psycopg2.InternalError (presumably from a trigger).
        """
        with self.assertRaises(psycopg2.InternalError):
            self._insert_test_spec(starttime='2017-05-10 12:00:00', endtime='2017-05-10 10:00:00')

    # notifications in place?
    def test_insert_task_triggers_notification(self):
        """Inserting a task must produce a notification on the 'task_insert' channel."""
        # insert specification first so the task's specification_id is valid (no IntegrityError)
        ident = self._insert_test_spec()

        # listen on notification; _execute_query commits, so LISTEN is active before the insert
        self._execute_query("LISTEN task_insert;")

        # todo: use self.listener.subscribe('task_insert', callback) instead of polling the connection.
        # todo: ...Problem: For some reason the callback function is not called.

        # trigger notification
        query = "INSERT INTO resource_allocation.task (mom_id, otdb_id, status_id, type_id, specification_id) " \
                "VALUES (%s, %s, %s, %s, %s)"
        self._execute_query(query, params=(1, 1, 200, 0, ident))

        # collect the notification received by this connection
        self.connection.poll()
        self.assertTrue(self.connection.notifies, 'no notification received')
        notification = self.connection.notifies.pop(0)
        self.assertIn('task_insert', str(notification))

    #
    # radb functionality tests
    #

    def test_getTaskStatuses_contains_scheduled(self):
        self.assertIn('scheduled', str(self.radb.getTaskStatuses()))

    def _assert_claim_matches(self, expected_claim, db_claim):
        """Assert db_claim has the values of expected_claim for every key except 'status'.

        'status' is skipped because the database may change it on insert
        (a 'tentative' claim that does not fit becomes 'conflict').
        """
        for key, value in expected_claim.items():
            if key != 'status':
                self.assertEqual(value, db_claim[key])

    def _assert_no_conflicts(self, claim_id):
        """Assert the radb reports no conflicting overlapping claims and no such tasks for claim_id."""
        self.assertEqual(0, len(self.radb.get_conflicting_overlapping_claims(claim_id)))
        self.assertEqual(0, len(self.radb.get_conflicting_overlapping_tasks(claim_id)))

    def _assert_conflicts(self, claim_id, expected_claim_ids, expected_task_ids):
        """Assert the exact sets of conflicting overlapping claim ids and task ids for claim_id."""
        self.assertEqual(set(expected_claim_ids),
                         set(c['id'] for c in self.radb.get_conflicting_overlapping_claims(claim_id)))
        self.assertEqual(set(expected_task_ids),
                         set(t['id'] for t in self.radb.get_conflicting_overlapping_tasks(claim_id)))

    def test_task_and_claim_conflicts(self):
        """Walk three overlapping observations through the claim / conflict / schedule life cycle.

        Checks that claims which exceed the resource capacity become 'conflict' (together with their
        task), that conflicts are resolved by moving a task or shrinking a claim, that a claimed claim
        cannot be grown beyond capacity afterwards, and that the resource_usage table can be rebuilt
        from the claims.
        """
        # resource 117 is CEP4 storage; for testing purposes give it a total size of 100
        cep4_storage_id = 117
        self.assertTrue(self.radb.updateResourceAvailability(cep4_storage_id, available_capacity=100, total_capacity=100))
        self.assertEqual(100, self.radb.getResources(cep4_storage_id, include_availability=True)[0]['total_capacity'])

        def max_claimed_usage(lower, upper):
            """Maximum 'claimed' usage of the CEP4 storage resource between lower and upper."""
            return self.radb.get_max_resource_usage_between(cep4_storage_id, lower, upper, 'claimed')['usage']

        def new_claim(task, status, claim_size):
            """Claim dict on the CEP4 storage resource spanning the given task's time window."""
            return {'resource_id': cep4_storage_id,
                    'starttime': task['starttime'],
                    'endtime': task['endtime'],
                    'status': status,
                    'claim_size': claim_size}

        # truncate to whole minutes
        now = datetime.utcnow().replace(second=0, microsecond=0)

        result = self.radb.insertSpecificationAndTask(0, 0, 'approved', 'observation', now, now+timedelta(hours=1), 'foo', 'CEP4')
        self.assertTrue(result['inserted'])
        task_id1 = result['task_id']

        task1 = self.radb.getTask(task_id1)
        self.assertTrue(task1)
        self.assertEqual(task_id1, task1['id'])

        t1_claim1 = new_claim(task1, 'tentative', 40)

        # insert 1 claim
        t1_claim_ids = self.radb.insertResourceClaims(task_id1, [t1_claim1], 'foo', 1, 1)
        self.assertEqual(1, len(t1_claim_ids))

        # get claim using t1_claim_ids, and check if db version is equal to original
        t1_claims = self.radb.getResourceClaims(claim_ids=t1_claim_ids)
        self.assertEqual(1, len(t1_claims))
        self._assert_claim_matches(t1_claim1, t1_claims[0])

        # get claim again via task_id1, and check if db version is equal to original
        t1_claims = self.radb.getResourceClaims(task_ids=task_id1)
        self.assertEqual(1, len(t1_claims))
        self._assert_claim_matches(t1_claim1, t1_claims[0])

        # try to insert claims with a wrong initial status ('claimed', then 'conflict').
        # Should rollback, and return no ids.
        for faulty_status in ('claimed', 'conflict'):
            faulty_claim = new_claim(task1, faulty_status, 10)
            self.assertEqual(0, len(self.radb.insertResourceClaims(task_id1, [faulty_claim], 'foo', 1, 1)))

        # try to update the task status to scheduled, should not succeed, since its claims are not 'claimed' yet.
        self.assertFalse(self.radb.updateTask(task_id1, task_status='scheduled'))
        self.assertEqual('approved', self.radb.getTask(task_id1)['status'])

        # try to update the claim status to claimed, should succeed.
        self.assertTrue(self.radb.updateResourceClaims(t1_claim_ids, status='claimed'))
        self.assertEqual('claimed', self.radb.getResourceClaim(t1_claim_ids[0])['status'])

        # try to update the task status to scheduled again, should succeed this time.
        self.assertTrue(self.radb.updateTask(task_id1, task_status='scheduled'))
        self.assertEqual('scheduled', self.radb.getTask(task_id1)['status'])

        self._assert_no_conflicts(t1_claim_ids[0])

        self.assertEqual(40, max_claimed_usage(task1['starttime'], task1['starttime']))
        self.assertEqual(0, max_claimed_usage(task1['starttime']-timedelta(hours=2), task1['starttime']-timedelta(hours=1)))

        logger.info('------------------------------ concludes task 1 ------------------------------')
        logger.info('-- now test with a 2nd task, and test resource availability, conflicts etc. --')

        # another task, fully overlapping with task1
        result = self.radb.insertSpecificationAndTask(1, 1, 'approved', 'observation', now, now+timedelta(hours=1), 'foo', 'CEP4')
        self.assertTrue(result['inserted'])
        task_id2 = result['task_id']

        task2 = self.radb.getTask(task_id2)
        self.assertTrue(task2)

        # insert a claim which won't fit (40 + 90 exceeds the total of 100)
        t2_claim1 = new_claim(task2, 'tentative', 90)
        t2_claim_ids = self.radb.insertResourceClaims(task_id2, [t2_claim1], 'foo', 1, 1)
        self.assertEqual(1, len(t2_claim_ids))

        # claim status after previous insert should be 'conflict' instead of 'tentative'
        t2_claims = self.radb.getResourceClaims(claim_ids=t2_claim_ids)
        self.assertEqual('conflict', t2_claims[0]['status'])
        # and the task's status should be conflict as well
        self.assertEqual('conflict', self.radb.getTask(task_id2)['status'])

        self._assert_conflicts(t2_claim_ids[0], [t1_claim_ids[0]], [task_id1])

        # try to update the task status to scheduled, should not succeed, since its claims are not 'claimed' yet.
        self.assertFalse(self.radb.updateTask(task_id2, task_status='scheduled'))
        self.assertEqual('conflict', self.radb.getTask(task_id2)['status'])

        # try to update the claim status to claimed, should not succeed, since it still won't fit
        self.assertFalse(self.radb.updateResourceClaims(t2_claim_ids, status='claimed'))
        self.assertEqual('conflict', self.radb.getResourceClaim(t2_claim_ids[0])['status'])

        # do conflict resolution, shift task and claims
        self.assertTrue(self.radb.updateTaskAndResourceClaims(task_id2, starttime=now+timedelta(hours=2), endtime=now+timedelta(hours=3)))
        # now the task and claim status should not be conflict anymore
        self.assertEqual('tentative', self.radb.getResourceClaim(t2_claim_ids[0])['status'])
        self.assertEqual('approved', self.radb.getTask(task_id2)['status'])

        self._assert_no_conflicts(t2_claim_ids[0])

        # try to update the claim status to claimed, should succeed now
        self.assertTrue(self.radb.updateResourceClaims(t2_claim_ids, status='claimed'))
        self.assertEqual('claimed', self.radb.getResourceClaim(t2_claim_ids[0])['status'])

        # and try to update the task status to scheduled, should succeed now
        self.assertTrue(self.radb.updateTask(task_id2, task_status='scheduled'))
        self.assertEqual('scheduled', self.radb.getTask(task_id2)['status'])

        self._assert_no_conflicts(t2_claim_ids[0])

        logger.info('------------------------------ concludes task 2 ------------------------------')
        logger.info('-- now test with a 3rd task, and test resource availability, conflicts etc. --')

        # make sure we work with the latest info
        task1 = self.radb.getTask(task_id1)
        task2 = self.radb.getTask(task_id2)

        # another task, partially overlapping with both task1 & task2 (from the middle of task1 to the middle of task2)
        result = self.radb.insertSpecificationAndTask(2, 2, 'approved', 'observation',
                                                      task1['starttime'] + (task1['endtime']-task1['starttime'])/2,
                                                      task2['starttime'] + (task2['endtime']-task2['starttime'])/2,
                                                      'foo', 'CEP4')
        self.assertTrue(result['inserted'])
        task_id3 = result['task_id']

        task3 = self.radb.getTask(task_id3)
        self.assertTrue(task3)

        # insert a claim which won't fit (40 + 80 and 90 + 80 both exceed the total of 100)
        t3_claim1 = new_claim(task3, 'tentative', 80)
        t3_claim_ids = self.radb.insertResourceClaims(task_id3, [t3_claim1], 'foo', 1, 1)
        self.assertEqual(1, len(t3_claim_ids))

        # claim status after previous insert should be 'conflict' instead of 'tentative'
        t3_claims = self.radb.getResourceClaims(claim_ids=t3_claim_ids)
        self.assertEqual('conflict', t3_claims[0]['status'])
        # and the task's status should be conflict as well
        self.assertEqual('conflict', self.radb.getTask(task_id3)['status'])
        self._assert_conflicts(t3_claim_ids[0], [t1_claim_ids[0], t2_claim_ids[0]], [task_id1, task_id2])

        # try to update the task status to scheduled, should not succeed, since its claims are not 'claimed' yet.
        self.assertFalse(self.radb.updateTask(task_id3, task_status='scheduled'))
        self.assertEqual('conflict', self.radb.getTask(task_id3)['status'])

        # try to update the claim status to claimed, should not succeed, since it still won't fit
        self.assertFalse(self.radb.updateResourceClaims(t3_claim_ids, status='claimed'))
        self.assertEqual('conflict', self.radb.getResourceClaim(t3_claim_ids[0])['status'])

        # do conflict resolution, shift task away from task1 only (but keep overlapping with task2)
        self.assertTrue(self.radb.updateTaskAndResourceClaims(task_id3, starttime=task1['endtime'] + (task2['starttime']-task1['endtime'])/2))

        # now the task and claim status should still be in conflict
        self.assertEqual('conflict', self.radb.getResourceClaim(t3_claim_ids[0])['status'])
        self.assertEqual('conflict', self.radb.getTask(task_id3)['status'])

        self._assert_conflicts(t3_claim_ids[0], [t2_claim_ids[0]], [task_id2])

        # do conflict resolution, reduce claim size (but keep overlapping with task2; 90 + 5 fits in 100)
        self.assertTrue(self.radb.updateResourceClaim(t3_claim_ids[0], claim_size=5))

        # now the task and claim status should not be conflict anymore
        self.assertEqual('tentative', self.radb.getResourceClaim(t3_claim_ids[0])['status'])
        self.assertEqual('approved', self.radb.getTask(task_id3)['status'])

        self._assert_no_conflicts(t3_claim_ids[0])

        # try to update the claim status to claimed, should succeed now
        self.assertTrue(self.radb.updateResourceClaims(t3_claim_ids, status='claimed'))
        self.assertEqual('claimed', self.radb.getResourceClaim(t3_claim_ids[0])['status'])

        # and try to update the task status to scheduled, should succeed now
        self.assertTrue(self.radb.updateTask(task_id3, task_status='scheduled'))
        self.assertEqual('scheduled', self.radb.getTask(task_id3)['status'])

        # try to trick the radb by resetting the claim_size back to 80 now that it was claimed. Should fail.
        self.assertFalse(self.radb.updateResourceClaim(t3_claim_ids[0], claim_size=80))
        # check if still 5, not 80
        self.assertEqual(5, self.radb.getResourceClaim(t3_claim_ids[0])['claim_size'])
        # and statuses should still be claimed/scheduled
        self.assertEqual('claimed', self.radb.getResourceClaim(t3_claim_ids[0])['status'])
        self.assertEqual('scheduled', self.radb.getTask(task_id3)['status'])

        # suppose the resource_usages table is broken for some reason, fix it....
        # break it first...
        self._execute_query('TRUNCATE TABLE resource_allocation.resource_usage;')
        # check that it's broken
        self.assertNotEqual(40, max_claimed_usage(task1['starttime'], task1['starttime']))
        # fix it
        self.radb.rebuild_resource_usages_table_from_claims()
        # and test again that it's ok
        self.assertEqual(40, max_claimed_usage(task1['starttime'], task1['starttime']))
        self.assertEqual(0, max_claimed_usage(task1['starttime']-timedelta(hours=2), task1['starttime']-timedelta(hours=1)))


if __name__ == "__main__":
    # Put TZ=UTC into the process environment before any test starts.
    # NOTE(review): presumably meant for the postgres server spawned by testing.postgresql, which
    # inherits this environment; the running python process itself only picks it up after time.tzset().
    os.environ['TZ'] = 'UTC'
    logging.basicConfig(level=logging.INFO,
                        format='%(asctime)s %(levelname)s %(message)s')
    unittest.main()