#!/usr/bin/python

# Copyright (C) 2012-2015    ASTRON (Netherlands Institute for Radio Astronomy)
# P.O. Box 2, 7990 AA Dwingeloo, The Netherlands
#
# This file is part of the LOFAR software suite.
# The LOFAR software suite is free software: you can redistribute it and/or
# modify it under the terms of the GNU General Public License as published
# by the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# The LOFAR software suite is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.    See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License along
# with the LOFAR software suite. If not, see <http://www.gnu.org/licenses/>.

# $Id$
import unittest
import psycopg2
import os
from datetime import datetime, timedelta
from dateutil import parser
import logging

logger = logging.getLogger(__name__)

# testing.postgresql is an optional test dependency: without it the tests cannot run at all,
# so report how to install it and exit with the lofar "skipped test" code.
try:
    import testing.postgresql
except ImportError as e:
    import sys

    print(e)
    print('Please install python package testing.postgresql: sudo pip install testing.postgresql')
    # special lofar test exit code: skipped test
    # (sys.exit rather than the builtin exit(), which is injected by the site module for
    # interactive use and is not available when python runs with -S)
    sys.exit(3)

from lofar.common.dbcredentials import Credentials
from lofar.sas.resourceassignment.database.radb import RADatabase


# Module-wide shared state, filled in by setUpModule(): the credentials used to connect to the
# test database and the testing.postgresql factory. Sharing one factory keeps the tests fast.
database_credentials = Postgresql = None

def setUpModule():
    """Create the state shared by all tests in this module.

    Binds the module-level ``database_credentials`` to a fresh ``Credentials()`` object
    (its connection fields are filled in per test) and ``Postgresql`` to a
    testing.postgresql factory created with ``cache_initialized_db=True``.
    """
    global database_credentials
    global Postgresql

    database_credentials = Credentials()

    factory_options = {'cache_initialized_db': True}
    Postgresql = testing.postgresql.PostgresqlFactory(**factory_options)


def tearDownModule():
    """Clear the database cache kept by the module-level ``Postgresql`` factory.

    Does nothing beyond logging when the factory was never created (``Postgresql`` is
    still ``None``), e.g. if setUpModule() did not run.
    """
    logger.info('tearDownModule')
    # Without this guard a missing factory turns teardown into an AttributeError on None.
    if Postgresql is not None:
        # clear cached database at end of tests
        Postgresql.clear_cache()


class RADBCommonTest(unittest.TestCase):
    """Integrity tests for a freshly created and populated RA database (radb).

    For every test, setUp() starts a PostgreSQL server from the module-level ``Postgresql``
    factory, creates the 'resourceassignment' role, applies the radb SQL scripts found in
    ``$LOFARROOT/share/radb/sql/`` and provides:

    * ``self.connection``: a psycopg2 connection logged in as the 'resourceassignment' role;
    * ``self.radb``: a RADatabase instance built from the same credentials.

    Closing the connection and stopping the server are registered with addCleanup(), so they
    also happen when setUp() fails halfway (in which case unittest skips tearDown()).
    """

    def setUp(self):
        logger.info('setting up test RA database...')
        # connect to shared test db
        # (for a fresh db instead of the shared one use: testing.postgresql.Postgresql())
        self.postgresql = Postgresql()
        # Register the stop right away: unittest always runs cleanups, even when setUp() raises,
        # so a failure further down no longer leaves a postgres server running.
        self.addCleanup(self.postgresql.stop)

        # set up fixtures
        # Note: In theory, this can be moved to the PostgresqlFactory call as kwarg 'on_initialized=populatedb'
        # ...but for some reason that was much slower than keeping it here.
        self._setup_database()

        # update credentials (e.g. port changes for each test)
        dsn = self.postgresql.dsn()
        database_credentials.host = dsn['host']
        database_credentials.database = dsn['database']
        database_credentials.port = dsn['port']

        # connect with the 'resourceassignment' role created in _setup_database()
        self.connection = psycopg2.connect(host=database_credentials.host,
                                           user=database_credentials.user,
                                           password=database_credentials.password,
                                           dbname=database_credentials.database,
                                           port=database_credentials.port)
        # cleanups run in LIFO order: this connection is closed before the server is stopped
        self.addCleanup(self.connection.close)

        # set up radb python module
        # NOTE(review): if RADatabase keeps its own connection, it is not closed here — confirm.
        self.radb = RADatabase(database_credentials, log_queries=True)
        logger.info('...finished setting up test RA database')

    def tearDown(self):
        """Log the teardown.

        The actual cleanup (closing ``self.connection``, stopping the postgres server) is done
        by the callbacks registered in setUp(), which unittest runs right after this method.
        """
        logger.info('removing test RA database...')

    def _setup_database(self):
        """Create the test role and apply the radb SQL scripts, connected as the postgres superuser.

        Sets ``database_credentials.user``/``password`` to the role created here. All statements
        run in one transaction that is committed at the end. The root connection is always
        closed, also when a statement or a file read fails.

        Raises:
            KeyError: if the LOFARROOT environment variable is not set.
        """
        # set credentials to be used during tests
        database_credentials.user = 'resourceassignment'
        database_credentials.password = 'blabla'  # cannot be empty...

        # populate db tables
        # These are applied in given order to set up test db
        # Note: cannot use create_and_populate_database.sql since '\i' is not understood by cursor.execute()
        sql_basepath = os.path.join(os.environ['LOFARROOT'], 'share', 'radb', 'sql')
        sql_createdb_paths = [os.path.join(sql_basepath, filename)
                              for filename in ('create_database.sql',
                                               'add_resource_allocation_statics.sql',
                                               'add_virtual_instrument.sql',
                                               'add_notifications.sql',
                                               'add_functions_and_triggers.sql')]

        # connect to db as root
        conn = psycopg2.connect(**self.postgresql.dsn())
        try:
            with conn.cursor() as cursor:
                # create user role
                # Note: NOSUPERUSER currently raises "permission denied for schema virtual_instrument"
                # Maybe we want to sort out user creation and proper permissions in the sql scripts?
                # (plain string formatting is acceptable here: both values are the constants set above)
                query = "CREATE USER %s WITH SUPERUSER PASSWORD '%s'" % (
                    database_credentials.user,
                    database_credentials.password)
                cursor.execute(query)

                for sql_path in sql_createdb_paths:
                    logger.debug("setting up database. applying sql file: %s", sql_path)
                    # explicit encoding: do not depend on the locale of the test machine
                    with open(sql_path, encoding='utf-8') as sql:
                        cursor.execute(sql.read())
            conn.commit()
        finally:
            conn.close()

    def _execute_query(self, query, fetch=False):
        """Execute ``query`` on ``self.connection`` and commit.

        Args:
            query: SQL statement to execute.
            fetch: when True, return all result rows (``cursor.fetchall()``).

        Returns:
            The fetched rows when ``fetch`` is True, otherwise None.
        """
        # the context manager closes the cursor even when execute() raises
        with self.connection.cursor() as cursor:
            cursor.execute(query)
            ret = cursor.fetchall() if fetch else None
        self.connection.commit()
        return ret

    # --- tests start here

    # integrity tests of postgres database itself
    #
    # Note: These are meant to make sure the setup generally works and all sql scripts were applied.
    # I don't see much benefit in full coverage here since it should all be tested through the RADatabase functionality.
    # Of course new tests can be added here where db functionality like triggers should be tested separately from the
    # Python part of the job.

    # database created?
    def test_select_tables_contains_tables_for_each_schema(self):
        query = "SELECT table_schema,table_name FROM information_schema.tables"
        fetch = self._execute_query(query, fetch=True)
        # compare against the schema column itself instead of a substring of the whole result
        schemas = {row[0] for row in fetch}
        for schema in ('resource_allocation', 'resource_monitoring', 'virtual_instrument'):
            self.assertIn(schema, schemas)

    # resource allocation_statics there?
    def test_select_task_types_contains_obervation(self):
        query = "SELECT * FROM resource_allocation.task_type"
        fetch = self._execute_query(query, fetch=True)
        self.assertIn('observation', str(fetch))

    # virtual instrument there?
    def test_select_virtualinstrument_units_contain_rcuboard(self):
        query = "SELECT * FROM virtual_instrument.unit"
        fetch = self._execute_query(query, fetch=True)
        self.assertIn('rcu_board', str(fetch))


if __name__ == "__main__":
    import time

    # Run in UTC. The TZ variable is presumably inherited by the postgres server that
    # testing.postgresql spawns; time.tzset() makes this python process apply it too, since
    # changing os.environ['TZ'] alone is not guaranteed to affect the local timezone.
    os.environ['TZ'] = 'UTC'
    if hasattr(time, 'tzset'):  # tzset() only exists on Unix
        time.tzset()
    logging.basicConfig(format='%(asctime)s %(levelname)s %(message)s', level=logging.INFO)
    unittest.main()