#!/usr/bin/env python3

# Copyright (C) 2018    ASTRON (Netherlands Institute for Radio Astronomy)
# P.O. Box 2, 7990 AA Dwingeloo, The Netherlands
#
# This file is part of the LOFAR software suite.
# The LOFAR software suite is free software: you can redistribute it and/or
# modify it under the terms of the GNU General Public License as published
# by the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# The LOFAR software suite is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.    See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License along
# with the LOFAR software suite. If not, see <http://www.gnu.org/licenses/>.

# $Id:  $

import os
import unittest
import datetime
import logging
import requests
import dateutil.parser
import astropy.coordinates
import json

logger = logging.getLogger(__name__)
# Configure root logging for this test module: timestamped messages at INFO level and above.
logging.basicConfig(format='%(asctime)s %(levelname)s %(message)s', level=logging.INFO)

# Bail out early when integration tests are disabled. NOTE(review): presumably this exits the
# process with the test runner's "skipped" exit code — confirm in lofar.common.test_utils.
from lofar.common.test_utils import exit_with_skipped_code_if_skip_integration_tests
exit_with_skipped_code_if_skip_integration_tests()

# Mandatory setup step:
# use the setup/teardown magic for the TMSS test database, LDAP server and Django server.
# (Ignore PyCharm's 'unused import' warning: the Python unittests use the tmss_test_environment_unittest_setup module at runtime.)
from lofar.sas.tmss.test.tmss_test_environment_unittest_setup import *

# The next import should be done after the 'tmss_test_environment_unittest_setup' magic !!!
from lofar.sas.tmss.tmss.tmssapp.conversions import local_sidereal_time_for_utc_and_station, local_sidereal_time_for_utc_and_longitude


class SiderealTime(unittest.TestCase):
    """Unit tests for the local sidereal time conversion helpers.

    Each helper is checked against a precomputed reference value and for its
    sensitivity to its inputs (timestamp, longitude or station).
    """

    # Fixed reference instants (naive datetimes): noon on 2020-01-01 and noon on the day after.
    REFERENCE_NOON = datetime.datetime(2020, 1, 1, 12, 0, 0)
    FOLLOWING_NOON = datetime.datetime(2020, 1, 2, 12, 0, 0)

    # --- local_sidereal_time_for_utc_and_longitude ---

    def test_local_sidereal_time_for_utc_and_longitude_returns_correct_result(self):
        # with the default longitude, the result must match the known reference value
        sidereal = local_sidereal_time_for_utc_and_longitude(timestamp=self.REFERENCE_NOON)
        self.assertEqual(str(sidereal), '19h09m54.9567s')

    def test_local_sidereal_time_for_utc_and_longitude_considers_timestamp(self):
        # the same time of day on consecutive days must yield different sidereal times
        first, second = (local_sidereal_time_for_utc_and_longitude(timestamp=moment)
                         for moment in (self.REFERENCE_NOON, self.FOLLOWING_NOON))
        self.assertNotEqual(str(first), str(second))

    def test_local_sidereal_time_for_utc_and_longitude_considers_longitude(self):
        # the same instant at two different longitudes must yield different sidereal times
        first, second = (local_sidereal_time_for_utc_and_longitude(timestamp=self.REFERENCE_NOON, longitude=lon)
                         for lon in (6.789, 6.123))
        self.assertNotEqual(str(first), str(second))

    # --- local_sidereal_time_for_utc_and_station ---

    def test_local_sidereal_time_for_utc_and_station_returns_correct_result(self):
        # with the default station, the result must match the known reference value
        sidereal = local_sidereal_time_for_utc_and_station(timestamp=self.REFERENCE_NOON)
        self.assertEqual(str(sidereal), '19h09m55.0856s')

    def test_local_sidereal_time_for_utc_and_station_considers_timestamp(self):
        # the same time of day on consecutive days must yield different sidereal times
        first, second = (local_sidereal_time_for_utc_and_station(timestamp=moment)
                         for moment in (self.REFERENCE_NOON, self.FOLLOWING_NOON))
        self.assertNotEqual(str(first), str(second))

    def test_local_sidereal_time_for_utc_and_station_considers_station(self):
        # the same instant at two different stations must yield different sidereal times
        first, second = (local_sidereal_time_for_utc_and_station(timestamp=self.REFERENCE_NOON, station=station_name)
                         for station_name in ("CS002", "DE602"))
        self.assertNotEqual(str(first), str(second))


class UtilREST(unittest.TestCase):
    """Integration tests for the TMSS '/util/...' REST endpoints.

    Every test issues HTTP GET requests to BASE_URL (authenticated with AUTH) and checks the
    returned values for structure, defaults and sensitivity to the query parameters.
    """

    def _get_json(self, query):
        """GET BASE_URL + query, assert an HTTP 200 response, and return the decoded JSON body."""
        r = requests.get(BASE_URL + query, auth=AUTH)
        self.assertEqual(r.status_code, 200)
        return json.loads(r.content.decode('utf-8'))

    # utc

    def test_util_utc_returns_timestamp(self):

        # assert local clock differs not too much from returned TMSS system clock
        r = requests.get(BASE_URL + '/util/utc', auth=AUTH)
        self.assertEqual(r.status_code, 200)
        returned_datetime = dateutil.parser.parse(r.content.decode('utf8'))
        current_datetime = datetime.datetime.utcnow()
        delta = abs((returned_datetime - current_datetime).total_seconds())
        self.assertLess(delta, 60.0)

    # lst

    def test_util_lst_returns_longitude(self):

        # assert returned value is a parseable hms value
        for query in ['/util/lst',
                      '/util/lst?timestamp=2020-01-01T12:00:00',
                      '/util/lst?timestamp=2020-01-01T12:00:00&longitude=54.321',
                      '/util/lst?timestamp=2020-01-01T12:00:00&station=DE609']:
            r = requests.get(BASE_URL + query, auth=AUTH)
            self.assertEqual(r.status_code, 200)
            lon_str = r.content.decode('utf8')
            lon_obj = astropy.coordinates.Longitude(lon_str)
            self.assertEqual(str(lon_obj), lon_str)

    def test_util_lst_considers_timestamp(self):

        # assert returned value matches known result for given timestamp
        r = requests.get(BASE_URL + '/util/lst?timestamp=2020-01-01T12:00:00', auth=AUTH)
        self.assertEqual(r.status_code, 200)
        lon_str = r.content.decode('utf8')
        self.assertEqual('19h09m55.0856s', lon_str)

    def test_util_lst_considers_station(self):

        # assert returned value differs when a different station is given
        r1 = requests.get(BASE_URL + '/util/lst', auth=AUTH)
        r2 = requests.get(BASE_URL + '/util/lst?station=DE602', auth=AUTH)
        self.assertEqual(r1.status_code, 200)
        self.assertEqual(r2.status_code, 200)
        lon_str1 = r1.content.decode('utf8')
        lon_str2 = r2.content.decode('utf8')
        self.assertNotEqual(lon_str1, lon_str2)

    def test_util_lst_considers_longitude(self):
        # assert returned value differs when a different longitude is given
        r1 = requests.get(BASE_URL + '/util/lst', auth=AUTH)
        r2 = requests.get(BASE_URL + '/util/lst?longitude=12.345', auth=AUTH)
        self.assertEqual(r1.status_code, 200)
        self.assertEqual(r2.status_code, 200)
        lon_str1 = r1.content.decode('utf8')
        lon_str2 = r2.content.decode('utf8')
        self.assertNotEqual(lon_str1, lon_str2)

    # sun_rise_and_set

    def test_util_sun_rise_and_set_returns_json_structure_with_defaults(self):
        r_dict = self._get_json('/util/sun_rise_and_set')

        # assert defaults to core and today
        self.assertIn('CS002', r_dict)
        sunrise_start = dateutil.parser.parse(r_dict['CS002']['sunrise'][0]['start'])
        # Compare with the UTC date, like the other tests here that compare with utcnow().
        # date.today() uses the local timezone, and that can differ from UTC around midnight.
        self.assertEqual(datetime.datetime.utcnow().date(), sunrise_start.date())

    def test_util_sun_rise_and_set_considers_stations(self):
        stations = ['CS005', 'RS305', 'DE609']
        r_dict = self._get_json('/util/sun_rise_and_set?stations=%s' % ','.join(stations))

        # assert each station is included in the response and consecutive sunset times differ
        sunset_start_last = None
        for station in stations:
            self.assertIn(station, r_dict)
            sunset_start = dateutil.parser.parse(r_dict[station]['sunset'][0]['start'])
            if sunset_start_last is not None:
                self.assertNotEqual(sunset_start, sunset_start_last)
            sunset_start_last = sunset_start

    def test_util_sun_rise_and_set_considers_timestamps(self):
        timestamps = ['2020-01-01', '2020-02-22T16-00-00', '2020-3-11', '2020-01-01']
        r_dict = self._get_json('/util/sun_rise_and_set?timestamps=%s' % ','.join(timestamps))

        # assert all requested timestamps are included in response (sunrise on same day)
        for i, timestamp in enumerate(timestamps):
            expected_date = dateutil.parser.parse(timestamp).date()
            response_date = dateutil.parser.parse(r_dict['CS002']['sunrise'][i]['start']).date()
            self.assertEqual(expected_date, response_date)

    def test_util_sun_rise_and_set_returns_correct_date_of_day_sunrise_and_sunset(self):
        timestamps = ['2020-01-01T02-00-00']
        r_dict = self._get_json('/util/sun_rise_and_set?timestamps=%s' % ','.join(timestamps))

        # assert day of timestamp matches day of returned values
        expected_date = dateutil.parser.parse(timestamps[0]).date()
        for period in ('sunrise', 'day', 'sunset'):
            for edge in ('start', 'end'):
                with self.subTest(period=period, edge=edge):
                    returned_date = dateutil.parser.parse(r_dict['CS002'][period][0][edge]).date()
                    self.assertEqual(expected_date, returned_date)

    def test_util_sun_rise_and_set_returns_correct_date_of_night(self):
        timestamps = ['2020-01-01T02-00-00', '2020-01-01T12-00-00']
        r_dict = self._get_json('/util/sun_rise_and_set?timestamps=%s' % ','.join(timestamps))

        # assert timestamp before sunrise returns night ending on day of timestamp (last night)
        expected_date = dateutil.parser.parse(timestamps[0]).date()
        response_date = dateutil.parser.parse(r_dict['CS002']['night'][0]['end']).date()
        self.assertEqual(expected_date, response_date)

        # assert timestamp after sunrise returns night starting on day of timestamp (next night)
        expected_date = dateutil.parser.parse(timestamps[1]).date()
        response_date = dateutil.parser.parse(r_dict['CS002']['night'][1]['start']).date()
        self.assertEqual(expected_date, response_date)

    # angular_separation

    def test_util_angular_separation_yields_error_when_no_pointing_is_given(self):
        r = requests.get(BASE_URL + '/util/angular_separation', auth=AUTH)

        # assert error
        self.assertEqual(r.status_code, 500)
        self.assertIn("celestial coordinates", r.content.decode('utf-8'))

    def test_util_angular_separation_returns_json_structure_with_defaults(self):
        r_dict = self._get_json('/util/angular_separation?angle1=1&angle2=1')

        # assert default bodies
        for key in ['sun', 'jupiter', 'moon']:
            self.assertIn(key, r_dict)

        # assert timestamp is now and has a value
        returned_timestamp, angle = next(iter(r_dict['jupiter'].items()))
        returned_datetime = dateutil.parser.parse(returned_timestamp)
        current_datetime = datetime.datetime.utcnow()
        delta = abs((returned_datetime - current_datetime).total_seconds())
        self.assertLess(delta, 60.0)
        self.assertIsInstance(angle, float)

    def test_util_angular_separation_considers_bodies(self):
        bodies = ['sun', 'neptune', 'mercury']
        r_dict = self._get_json('/util/angular_separation?angle1=1&angle2=1&bodies=%s' % ','.join(bodies))

        # assert each body is included in the response and consecutive angles differ
        angle_last = None
        for body in bodies:
            self.assertIn(body, r_dict)
            angle = next(iter(r_dict[body].values()))
            if angle_last is not None:
                self.assertNotEqual(angle, angle_last)
            angle_last = angle

    def test_util_angular_separation_considers_timestamps(self):
        timestamps = ['2020-01-01', '2020-02-22T16-00-00', '2020-3-11', '2020-01-01']
        r_dict = self._get_json('/util/angular_separation?angle1=1&angle2=1&timestamps=%s' % ','.join(timestamps))

        # assert all requested timestamps yield a response and angles differ
        angle_last = None
        for timestamp in timestamps:
            expected_timestamp = dateutil.parser.parse(timestamp, ignoretz=True).isoformat()
            self.assertIn(expected_timestamp, r_dict['jupiter'])
            angle = r_dict['jupiter'][expected_timestamp]
            if angle_last is not None:
                self.assertNotEqual(angle, angle_last)
            angle_last = angle

    def test_util_angular_separation_considers_coordinates(self):
        test_coords = [(1, 1, "J2000"), (1.1, 1, "J2000"), (1.1, 1.1, "J2000")]
        # Initialized once, before the loop: each coordinate's angle is compared with the previous one.
        angle_last = None
        for coords in test_coords:
            r_dict = self._get_json('/util/angular_separation?angle1=%s&angle2=%s&direction_type=%s' % coords)

            # assert all requested coordinates yield a response and angles differ
            angle = next(iter(r_dict['jupiter'].values()))
            if angle_last is not None:
                self.assertNotEqual(angle, angle_last)
            angle_last = angle

    # target_rise_and_set

    def test_util_target_rise_and_set_returns_json_structure_with_defaults(self):
        r_dict = self._get_json('/util/target_rise_and_set?angle1=0.5&angle2=0.5')

        # defaults are CS002 and today
        self.assertIn('CS002', r_dict)

        # assert target sets within 24h after now and rises within 24h before it sets
        now = datetime.datetime.utcnow()
        target_rise = dateutil.parser.parse(r_dict['CS002'][0]['rise'])
        target_set = dateutil.parser.parse(r_dict['CS002'][0]['set'])
        self.assertTrue(0 < (target_set - now).total_seconds() < 86400)
        self.assertTrue(0 < (target_set - target_rise).total_seconds() < 86400)

    def test_util_target_rise_and_set_considers_stations(self):
        stations = ['CS005', 'RS305', 'DE609']
        r_dict = self._get_json('/util/target_rise_and_set?angle1=0.5&angle2=0.5&stations=%s' % ','.join(stations))

        # assert each station is included in the response and consecutive rise times differ
        target_rise_last = None
        for station in stations:
            self.assertIn(station, r_dict)
            target_rise = dateutil.parser.parse(r_dict[station][0]['rise'])
            if target_rise_last is not None:
                self.assertNotEqual(target_rise, target_rise_last)
            target_rise_last = target_rise

    def test_util_target_rise_and_set_considers_timestamps(self):
        timestamps = ['2020-01-01', '2020-02-22T16-00-00', '2020-3-11', '2020-01-01']
        r_dict = self._get_json('/util/target_rise_and_set?angle1=0.5&angle2=0.5&timestamps=%s' % ','.join(timestamps))

        # assert all requested timestamps are included in response and either rise or set are in that day
        for i, timestamp in enumerate(timestamps):
            expected_date = dateutil.parser.parse(timestamp).date()
            response_rise_date = dateutil.parser.parse(r_dict['CS002'][i]['rise']).date()
            response_set_date = dateutil.parser.parse(r_dict['CS002'][i]['set']).date()
            self.assertIn(expected_date, (response_rise_date, response_set_date))

    def test_util_target_rise_and_set_returns_correct_date_of_target_rise_and_set(self):
        timestamps = ['2020-01-01T02-00-00']
        r_dict = self._get_json('/util/target_rise_and_set?angle1=0.5&angle2=0.5&timestamps=%s' % ','.join(timestamps))

        # assert day of timestamp matches day of returned rise or set
        expected_date = dateutil.parser.parse(timestamps[0]).date()
        target_rise = dateutil.parser.parse(r_dict['CS002'][0]['rise'])
        target_set = dateutil.parser.parse(r_dict['CS002'][0]['set'])
        self.assertIn(expected_date, (target_rise.date(), target_set.date()))

        # assert set time falls in the 24h after rise time
        self.assertTrue(datetime.timedelta(0) < target_set - target_rise < datetime.timedelta(days=1))

    def test_util_target_rise_and_set_considers_coordinates(self):
        test_coords = [(0.5, 0.5, "J2000"), (0.6, 0.5, "J2000"), (0.6, 0.6, "J2000")]
        # Initialized once, before the loop: each coordinate's rise time is compared with the previous one.
        rise_last = None
        for coords in test_coords:
            r_dict = self._get_json('/util/target_rise_and_set?angle1=%s&angle2=%s&direction_type=%s' % coords)

            # assert all requested coordinates yield a response and times differ
            rise = r_dict['CS002'][0]['rise']
            if rise_last is not None:
                self.assertNotEqual(rise, rise_last)
            rise_last = rise

    def test_util_target_rise_and_set_considers_horizon(self):
        test_horizons = [0.1, 0.2, 0.3]
        rise_last = None
        for horizon in test_horizons:
            r_dict = self._get_json('/util/target_rise_and_set?angle1=0.5&angle2=0.5&horizon=%s' % horizon)

            # assert all requested horizons yield a response and times differ
            rise = r_dict['CS002'][0]['rise']
            if rise_last is not None:
                self.assertNotEqual(rise, rise_last)
            rise_last = rise

    def test_util_target_rise_and_set_detects_when_target_above_horizon(self):

        # assert always below and always above are usually false
        r_dict = self._get_json('/util/target_rise_and_set?angle1=0.5&angle2=0.8&timestamps=2020-01-01&horizon=0.2')
        self.assertIsNotNone(r_dict['CS002'][0]['rise'])
        self.assertIsNotNone(r_dict['CS002'][0]['set'])
        self.assertFalse(r_dict['CS002'][0]['always_below_horizon'])
        self.assertFalse(r_dict['CS002'][0]['always_above_horizon'])

        # assert rise and set are None and flag is true when target is always above horizon
        r_dict = self._get_json('/util/target_rise_and_set?angle1=0.5&angle2=0.8&timestamps=2020-01-01&horizon=0.1')
        self.assertIsNone(r_dict['CS002'][0]['rise'])
        self.assertIsNone(r_dict['CS002'][0]['set'])
        self.assertTrue(r_dict['CS002'][0]['always_above_horizon'])
        self.assertFalse(r_dict['CS002'][0]['always_below_horizon'])

        # assert rise and set are None and flag is true when target is always below horizon
        r_dict = self._get_json('/util/target_rise_and_set?angle1=0.5&angle2=-0.5&timestamps=2020-01-01&horizon=0.2')
        self.assertIsNone(r_dict['CS002'][0]['rise'])
        self.assertIsNone(r_dict['CS002'][0]['set'])
        self.assertFalse(r_dict['CS002'][0]['always_above_horizon'])
        self.assertTrue(r_dict['CS002'][0]['always_below_horizon'])

    # target transit

    def test_util_target_transit_returns_json_structure_with_defaults(self):
        r_dict = self._get_json('/util/target_transit?angle1=0.5&angle2=0.5')

        # defaults are CS002 and today
        self.assertIn('CS002', r_dict)

        # assert returned timestamp is no further than 12h away from now
        expected_time = datetime.datetime.utcnow()
        returned_time = dateutil.parser.parse(r_dict['CS002'][0])
        time_diff = abs(expected_time - returned_time)
        self.assertLessEqual(time_diff, datetime.timedelta(days=0.5))

    def test_util_target_transit_considers_stations(self):
        stations = ['CS005', 'RS305', 'DE609']
        r_dict = self._get_json('/util/target_transit?angle1=0.5&angle2=0.5&stations=%s' % ','.join(stations))

        # assert each station is included in the response and consecutive transit times differ
        target_transit_last = None
        for station in stations:
            self.assertIn(station, r_dict)
            target_transit = dateutil.parser.parse(r_dict[station][0])
            if target_transit_last is not None:
                self.assertNotEqual(target_transit, target_transit_last)
            target_transit_last = target_transit

    def test_util_target_transit_considers_timestamps(self):
        timestamps = ['2020-01-01', '2020-02-22T16-00-00', '2020-3-11']
        r_dict = self._get_json('/util/target_transit?angle1=0.5&angle2=0.5&timestamps=%s' % ','.join(timestamps))

        # assert all requested timestamps yield a different response
        transit_last = None
        for i, _timestamp in enumerate(timestamps):
            transit = r_dict['CS002'][i]
            if transit_last is not None:
                self.assertNotEqual(transit, transit_last)
            transit_last = transit

    def test_util_target_transit_returns_correct_date_of_target_transit(self):
        timestamps = ['2020-01-01T02-00-00']
        r_dict = self._get_json('/util/target_transit?angle1=0.5&angle2=0.5&timestamps=%s' % ','.join(timestamps))

        # assert transit time is no further than 12h from requested time
        requested_time = dateutil.parser.parse(timestamps[0]).replace(tzinfo=None)
        returned_time = dateutil.parser.parse(r_dict['CS002'][0])
        time_diff = abs(requested_time - returned_time)
        self.assertLessEqual(time_diff, datetime.timedelta(days=0.5))

    def test_util_target_transit_considers_coordinates(self):
        test_coords = [(0.5, 0.5, "J2000"), (0.6, 0.5, "J2000"), (0.6, 0.6, "J2000")]
        transit_last = None
        for coords in test_coords:
            r_dict = self._get_json('/util/target_transit?angle1=%s&angle2=%s&direction_type=%s' % coords)

            # assert all requested coordinates yield a response and times differ
            transit = r_dict['CS002'][0]
            if transit_last is not None:
                self.assertNotEqual(transit, transit_last)
            transit_last = transit

if __name__ == "__main__":
    import time

    # Run the tests in UTC. Assigning os.environ['TZ'] alone does not update this process's
    # time-conversion rules; time.tzset() (available on Unix only) applies the new TZ value.
    os.environ['TZ'] = 'UTC'
    if hasattr(time, 'tzset'):
        time.tzset()
    unittest.main()