import datetime
from collections import OrderedDict

import coreapi
import coreschema
from django.db.models import Count

from math import ceil
from django.db.models import Window, F

from rest_framework import status
from rest_framework.response import Response
from rest_framework.schemas import ManualSchema
from rest_framework.views import APIView

from lofar.maintenance.monitoringdb.models.rtsm import RTSMObservation
from lofar.maintenance.monitoringdb.models.station import Station
from lofar.maintenance.monitoringdb.models.station_test import StationTest
from lofar.maintenance.monitoringdb.models.component_error import ComponentError
from lofar.maintenance.monitoringdb.models.rtsm import RTSMErrorSummary


class ControllerStationOverview(APIView):
    """
    Overview of the latest tests performed on the stations

    For every station of the selected group, lists the `n_station_tests` most
    recent station tests and the `n_rtsm` most recent RTSM observations (both
    ordered by end time, newest first), each with a summary of its errors.
    """

    DEFAULT_STATION_GROUP = 'A'
    DEFAULT_ONLY_ERRORS = True
    DEFAULT_N_STATION_TESTS = 4
    DEFAULT_N_RTSM = 4

    queryset = StationTest.objects.all()
    schema = ManualSchema(fields=[
        coreapi.Field(
            "station_group",
            required=False,
            location='query',
            schema=coreschema.Enum(['C', 'R', 'I', 'A'], description=
            'Station group to select for choices are [C|R|I|ALL]',
                                   )
        ),
        coreapi.Field(
            "n_station_tests",
            required=False,
            location='query',
            schema=coreschema.Integer(description='number of station tests to select',
                                      minimum=1)
        ),
        coreapi.Field(
            "n_rtsm",
            required=False,
            location='query',
            schema=coreschema.Integer(description='number of RTSM observations to select',
                                      minimum=1)
        ),
        coreapi.Field(
            "errors_only",
            required=False,
            location='query',
            schema=coreschema.Boolean(
                description='displays or not only the station with more than one error')
        )
    ]
    )

    @staticmethod
    def _flag_enabled(value):
        """
        Interpret a boolean query parameter.

        Falsy values and the string 'false' (any case) are off; everything else is on.
        """
        return bool(value) and str(value).lower() != 'false'

    @staticmethod
    def _summarize_component_errors(component_errors):
        """
        Group component errors by component type and error type.

        :param component_errors: queryset / related manager of component errors
        :return: OrderedDict {component type: OrderedDict {error type: count}},
                 filled in descending order of count
        """
        summary = OrderedDict()
        grouped = component_errors.values('component__type', 'type'). \
            annotate(total=Count('type')).order_by('-total')
        for item in grouped:
            per_component = summary.setdefault(item['component__type'], OrderedDict())
            per_component[item['type']] = item['total']
        return summary

    def _station_test_payload(self, station_test):
        """Serialize one station test together with its component error summary."""
        payload = OrderedDict()
        payload['total_component_errors'] = station_test.component_errors.count()
        payload['start_datetime'] = station_test.start_datetime
        payload['end_datetime'] = station_test.end_datetime
        payload['checks'] = station_test.checks
        payload['component_error_summary'] = self._summarize_component_errors(
            station_test.component_errors)
        return payload

    @staticmethod
    def _rtsm_payload(rtsm):
        """Serialize one RTSM observation together with its per-error-type counts."""
        payload = OrderedDict()
        payload['observation_id'] = rtsm.observation_id
        payload['start_datetime'] = rtsm.start_datetime
        payload['end_datetime'] = rtsm.end_datetime
        payload['mode'] = [item['mode'] for item in rtsm.errors.values('mode').distinct()]
        payload['total_component_errors'] = rtsm.errors_summary.count()

        # values() must come BEFORE annotate() so Django groups by error_type;
        # annotating first groups per row and yields a count of 1 for every type.
        # The empty order_by() keeps a default model ordering out of the GROUP BY.
        errors_summary_query = rtsm.errors_summary.values('error_type'). \
            annotate(total=Count('error_type')).order_by()
        errors_summary = OrderedDict()
        for error_summary in errors_summary_query:
            errors_summary[error_summary['error_type']] = error_summary['total']
        payload['error_summary'] = errors_summary
        return payload

    def get(self, request, format=None):
        """
        Build the per-station overview, sorted by station name.

        Responds 406 when n_station_tests / n_rtsm are not positive integers.
        """
        errors_only = request.query_params.get('errors_only', self.DEFAULT_ONLY_ERRORS)
        station_group = request.query_params.get('station_group', self.DEFAULT_STATION_GROUP)
        try:
            n_station_tests = int(
                request.query_params.get('n_station_tests', self.DEFAULT_N_STATION_TESTS))
            n_rtsm = int(request.query_params.get('n_rtsm', self.DEFAULT_N_RTSM))
            if n_station_tests < 1 or n_rtsm < 1:
                raise ValueError('n_station_tests and n_rtsm must be at least 1')
        except ValueError as e:
            # Report malformed numbers as 406, like the other views, instead of a 500.
            return Response(status=status.HTTP_406_NOT_ACCEPTABLE,
                            data='Please specify the correct parameters: %s' % (e,))

        station_entities = Station.objects.all()
        for group in station_group:
            # 'A' selects every station. Compare by value: `is not` on string
            # literals is an identity test that only works by accident of interning.
            if group != 'A':
                station_entities = station_entities.filter(type=group)

        response_payload = list()
        for station_entity in station_entities:
            station_test_list = StationTest.objects.filter(
                station__name=station_entity.name).order_by('-end_datetime')[:n_station_tests]
            rtsm_list = RTSMObservation.objects.filter(
                station__name=station_entity.name).order_by('-end_datetime')[:n_rtsm]

            station_payload = OrderedDict()
            station_payload['station_name'] = station_entity.name
            station_payload['station_tests'] = [self._station_test_payload(station_test)
                                                for station_test in station_test_list]
            station_payload['rtsm'] = [self._rtsm_payload(rtsm) for rtsm in rtsm_list]
            response_payload.append(station_payload)

        if self._flag_enabled(errors_only):
            # Keep only stations that have at least one station test or RTSM observation.
            response_payload = [entry for entry in response_payload
                                if len(entry['station_tests']) + len(entry['rtsm']) > 0]
        response_payload = sorted(response_payload, key=lambda item: item['station_name'])
        return Response(status=status.HTTP_200_OK, data=response_payload)


class ControllerStationTestsSummary(APIView):
    """
    Overview of the station tests started during the last `lookback_time` days
    (default 7), most recent first
    """

    DEFAULT_STATION_GROUP = 'A'
    DEFAULT_ONLY_ERRORS = True
    DEFAULT_LOOKBACK_TIME_IN_DAYS = 7

    queryset = StationTest.objects.all()
    schema = ManualSchema(fields=[
        coreapi.Field(
            "station_group",
            required=False,
            location='query',
            schema=coreschema.Enum(['C', 'R', 'I', 'A'],
                                   description='Station group to select for choices are [C|R|I|ALL]')
        ),
        coreapi.Field(
            "errors_only",
            required=False,
            location='query',
            schema=coreschema.Boolean(
                description='displays or not only the station with more than one error')
        ),
        coreapi.Field(
            "lookback_time",
            required=False,
            location='query',
            schema=coreschema.Integer(description='number of days from now (default 7)',
                                      minimum=1)
        )
    ]
    )

    @staticmethod
    def parse_date(date):
        """
        Parse a YYYY-MM-DD string into a datetime.

        :raises ValueError: if `date` cannot be parsed with that format
        """
        expected_format = '%Y-%m-%d'
        try:
            parsed_date = datetime.datetime.strptime(date, expected_format)
            return parsed_date
        except Exception as e:
            raise ValueError('cannot parse %s with format %s - %s' % (date, expected_format, e))

    @staticmethod
    def _flag_enabled(value):
        """
        Interpret a boolean query parameter.

        Falsy values and the string 'false' (any case) are off; everything else is on.
        """
        return bool(value) and str(value).lower() != 'false'

    @staticmethod
    def _summarize_component_errors(component_errors):
        """
        Group component errors by component type and error type.

        :return: OrderedDict {component type: OrderedDict {error type: count}},
                 filled in descending order of count
        """
        summary = OrderedDict()
        grouped = component_errors.values('component__type', 'type'). \
            annotate(total=Count('type')).order_by('-total')
        for item in grouped:
            per_component = summary.setdefault(item['component__type'], OrderedDict())
            per_component[item['type']] = item['total']
        return summary

    def validate_query_parameters(self, request):
        """
        Read the query parameters into attributes of the view.

        :raises ValueError: if lookback_time is not a positive integer
        """
        self.errors_only = request.query_params.get('errors_only', self.DEFAULT_ONLY_ERRORS)
        self.station_group = request.query_params.get('station_group', self.DEFAULT_STATION_GROUP)

        lookback_days = int(request.query_params.get('lookback_time',
                                                     self.DEFAULT_LOOKBACK_TIME_IN_DAYS))
        if lookback_days < 1:
            raise ValueError('lookback_time must be at least 1 day')
        self.lookback_time = datetime.timedelta(days=lookback_days)

    def get(self, request, format=None):
        """Return one entry per station test in the lookback window."""
        try:
            self.validate_query_parameters(request)
        except ValueError as e:
            return Response(status=status.HTTP_406_NOT_ACCEPTABLE,
                            data='Please specify the correct parameters: %s' % (e,))
        except KeyError as e:
            return Response(status=status.HTTP_406_NOT_ACCEPTABLE,
                            data='Please specify all the required parameters: %s' % (e,))

        station_test_list = StationTest.objects \
            .filter(start_datetime__gte=datetime.date.today() - self.lookback_time) \
            .order_by('-start_datetime', 'station__name')
        for group in self.station_group:
            # 'A' selects every station. The group filter must be applied to
            # station_test_list itself and go through the related station.
            if group != 'A':
                station_test_list = station_test_list.filter(station__type=group)

        response_payload = list()
        for station_test in station_test_list:
            station_test_payload = OrderedDict()
            station_test_payload['station_name'] = station_test.station.name
            station_test_payload['total_component_errors'] = station_test.component_errors.count()
            station_test_payload['date'] = station_test.start_datetime.strftime('%Y-%m-%d')
            station_test_payload['start_datetime'] = station_test.start_datetime
            station_test_payload['end_datetime'] = station_test.end_datetime
            station_test_payload['checks'] = station_test.checks
            station_test_payload['component_error_summary'] = \
                self._summarize_component_errors(station_test.component_errors)
            response_payload.append(station_test_payload)

        if self._flag_enabled(self.errors_only):
            response_payload = [entry for entry in response_payload
                                if entry['total_component_errors'] > 0]

        return Response(status=status.HTTP_200_OK, data=response_payload)


class ControllerLatestObservations(APIView):
    """
    Overview of the latest observations performed on the stations
    """

    DEFAULT_STATION_GROUP = 'A'
    DEFAULT_ONLY_ERRORS = True

    queryset = StationTest.objects.all()
    schema = ManualSchema(fields=[
        coreapi.Field(
            "station_group",
            required=False,
            location='query',
            schema=coreschema.Enum(['C', 'R', 'I', 'A'], description=
            'Station group to select for choices are [C|R|I|ALL]',
                                   )
        ),
        coreapi.Field(
            "errors_only",
            required=False,
            location='query',
            schema=coreschema.Boolean(
                description='displays or not only the station with more than one error')
        ),
        coreapi.Field(
            "from_date",
            required=True,
            location='query',
            schema=coreschema.String(
                description='select rtsm from date (ex. YYYY-MM-DD)')
        )
    ]
    )

    @staticmethod
    def parse_date(date):
        """
        Parse a YYYY-MM-DD string into a datetime.

        :raises ValueError: if `date` cannot be parsed with that format
                            (including when it is None)
        """
        expected_format = '%Y-%m-%d'
        try:
            parsed_date = datetime.datetime.strptime(date, expected_format)
            return parsed_date
        except Exception as e:
            raise ValueError('cannot parse %s with format %s - %s' % (date, expected_format, e))

    @staticmethod
    def _flag_enabled(value):
        """
        Interpret a boolean query parameter.

        Falsy values and the string 'false' (any case) are off; everything else is on.
        """
        return bool(value) and str(value).lower() != 'false'

    def compute_rtsm_observation_summary(self, rtsm_errors):
        """
        Map each error type to the number of RTSM error summary rows with a non-null rcu.

        A window Count('rcu') partitioned by error_type puts the per-type count
        on every row. distinct() over (error_type, total) then leaves one row per type.
        """
        errors_summary = OrderedDict()

        errors_summary_query = rtsm_errors.annotate(total=
                                                    Window(expression=Count('rcu'),
                                                           partition_by=[F(
                                                               'error_type')])).values(
            'error_type', 'total').distinct()

        for error_summary in errors_summary_query:
            errors_summary[error_summary['error_type']] = error_summary['total']
        return errors_summary

    def validate_query_parameters(self, request):
        """
        Read the query parameters into attributes of the view.

        :raises ValueError: if from_date is missing or not YYYY-MM-DD
        """
        self.errors_only = request.query_params.get('errors_only', self.DEFAULT_ONLY_ERRORS)
        self.station_group = request.query_params.get('station_group', self.DEFAULT_STATION_GROUP)

        start_date = request.query_params.get('from_date')
        self.from_date = ControllerLatestObservations.parse_date(start_date)

    def _station_summary(self, rtsm_entry):
        """Summarize the errors that one station reported in one RTSM observation."""
        station_summary = OrderedDict()
        station_summary['station_name'] = rtsm_entry.station.name
        station_summary['n_errors'] = rtsm_entry.errors.values('rcu').distinct().count()
        station_summary['component_error_summary'] = self.compute_rtsm_observation_summary(
            rtsm_entry.errors_summary)
        return station_summary

    def get(self, request, format=None):
        """
        List the observations started since from_date, with the stations involved.

        Observations are sorted by total number of errors, highest first.
        """
        try:
            self.validate_query_parameters(request)
        except ValueError as e:
            return Response(status=status.HTTP_406_NOT_ACCEPTABLE,
                            data='Please specify the date in the format YYYY-MM-DD: %s' % (e,))
        except KeyError as e:
            return Response(status=status.HTTP_406_NOT_ACCEPTABLE,
                            data='Please specify the from_date parameter: %s' % (e,))

        rtsm_observation_entities = RTSMObservation.objects.order_by('-start_datetime').filter(
            start_datetime__gte=self.from_date)
        for group in self.station_group:
            # 'A' selects every station; compare by value, not identity.
            if group != 'A':
                rtsm_observation_entities = rtsm_observation_entities.filter(station__type=group)

        response_payload = list()
        observations = rtsm_observation_entities.values('observation_id',
                                                        'start_datetime',
                                                        'end_datetime').distinct()
        for rtsm_observation_entity in observations:
            observation_payload = OrderedDict()
            observation_payload['observation_id'] = rtsm_observation_entity['observation_id']
            observation_payload['start_datetime'] = rtsm_observation_entity['start_datetime']
            observation_payload['end_datetime'] = rtsm_observation_entity['end_datetime']

            rtsm_list = RTSMObservation.objects.filter(
                start_datetime__gte=self.from_date,
                observation_id=rtsm_observation_entity['observation_id']). \
                order_by('-start_datetime')

            # Clear the ordering before values().distinct(). Otherwise start_datetime
            # is added to the SELECT, and the same mode or station can come back twice.
            observation_payload['mode'] = [
                item['errors__mode']
                for item in rtsm_list.order_by().values('errors__mode').distinct()]

            station_involved_list = [
                self._station_summary(
                    rtsm_list.filter(station__pk=station['station']).first())
                for station in rtsm_list.order_by().values('station').distinct()]
            station_involved_list.sort(key=lambda rtsm_per_station: rtsm_per_station['n_errors'],
                                       reverse=True)

            observation_payload['total_component_errors'] = sum(
                station_summary['n_errors'] for station_summary in station_involved_list)
            observation_payload['station_involved'] = station_involved_list

            response_payload.append(observation_payload)

        if self._flag_enabled(self.errors_only):
            # total_component_errors is an int: compare it directly (len() would raise).
            response_payload = [entry for entry in response_payload
                                if entry['total_component_errors'] > 0]
        response_payload = sorted(response_payload, key=lambda item: item['total_component_errors'],
                                  reverse=True)
        return Response(status=status.HTTP_200_OK, data=response_payload)


class ControllerStationTestStatistics(APIView):
    """
    Error statistics of station tests and/or RTSM observations, binned in time.

    /views/ctrl_stationtest_statistics:

    parameters:
    station_group [C|R|I|A] (optional, default A = all stations)
    test_type [R|S|B] (optional, default B): R = RTSM, S = station test, B = both
    from_date #DATE (YYYY-MM-DD)
    to_date #DATE (YYYY-MM-DD)
    averaging_interval #DAYS (positive integer)

    result:
    {
      start_date: #DATE,
      end_date: #DATE,
      averaging_interval: #INTERVAL,
      errors_per_station: [          # one list per time bin
        [{station_name: <station_name>, n_errors: #int, time: #DATE}, ...],
        ...
      ],
      errors_per_type: [             # one list per time bin
        [{error_type: <error_type>, n_errors: #int, time: #DATE}, ...],
        ...
      ]
    }

    Bin i covers [from_date + i * interval, from_date + (i + 1) * interval).
    Each of its entries is stamped with the bin's central time.
    """
    DEFAULT_STATION_GROUP = 'A'
    DEFAULT_TEST_TYPE = 'B'

    queryset = StationTest.objects.all()
    schema = ManualSchema(fields=[
        coreapi.Field(
            "test_type",
            required=False,
            location='query',
            schema=coreschema.Enum(['R', 'S', 'B'],
                                   description='select the type of test possible values are (R, RTSM),'
                                               ' (S, Station test), (B, both)[DEFAULT=B]',
                                   )
        ),
        coreapi.Field(
            "station_group",
            required=False,
            location='query',
            schema=coreschema.Enum(['C', 'R', 'I', 'A'], description=
            'Station group to select for choices are [C|R|I|ALL]',
                                   )
        ),
        coreapi.Field(
            "from_date",
            required=True,
            location='query',
            schema=coreschema.String(
                description='select tests from date (ex. YYYY-MM-DD)')
        ),
        coreapi.Field(
            "to_date",
            required=True,
            location='query',
            schema=coreschema.String(
                description='select tests to date (ex. YYYY-MM-DD)')
        ),
        coreapi.Field(
            "averaging_interval",
            required=True,
            location='query',
            schema=coreschema.Integer(
                description='averaging interval in days')
        )
    ]
    )

    @staticmethod
    def parse_date(date):
        """
        Parse a YYYY-MM-DD string into a datetime.

        :raises ValueError: if `date` cannot be parsed with that format
                            (including when it is None)
        """
        expected_format = '%Y-%m-%d'
        try:
            parsed_date = datetime.datetime.strptime(date, expected_format)
            return parsed_date
        except Exception as e:
            raise ValueError('cannot parse %s with format %s - %s' % (date, expected_format, e))

    def validate_query_parameters(self, request):
        """
        Read and check the query parameters, storing them on the view.

        :raises ValueError: on an unknown station_group or test_type, an
                            unparsable date, to_date before from_date, or a
                            missing / non-positive averaging_interval
        """
        self.station_group = request.query_params.get('station_group', self.DEFAULT_STATION_GROUP)
        if self.station_group not in ['C', 'R', 'I', 'A']:
            raise ValueError('station_group is not one of [C,R,I,A]')

        self.from_date = self.parse_date(request.query_params.get('from_date'))
        self.to_date = self.parse_date(request.query_params.get('to_date'))
        if self.to_date < self.from_date:
            raise ValueError('to_date must not be before from_date')

        self.test_type = request.query_params.get('test_type', self.DEFAULT_TEST_TYPE)
        if self.test_type not in ['R', 'S', 'B']:
            raise ValueError('test_type is not one of [R,S,B]')

        # A missing value made int(None) raise TypeError (an uncaught 500).
        # Zero made the bin count divide by zero.
        averaging_interval = request.query_params.get('averaging_interval')
        if averaging_interval is None:
            raise ValueError('averaging_interval is required')
        averaging_days = int(averaging_interval)
        if averaging_days < 1:
            raise ValueError('averaging_interval must be at least 1 day')
        self.averaging_interval = datetime.timedelta(days=averaging_days)

    @staticmethod
    def _add_to_bin(bin_entries, key_name, key, n_errors, central_time):
        """Add `n_errors` to the entry for `key` in `bin_entries`, creating it if needed."""
        if key in bin_entries:
            bin_entries[key]['n_errors'] += n_errors
        else:
            bin_entries[key] = {key_name: key, 'n_errors': n_errors, 'time': central_time}

    def compute_errors_per_station(self, from_date, to_date, central_time, station_group, test_type):
        """
        Count the errors per station for tests started in [from_date, to_date).

        :param central_time: time stamp attached to every returned entry
        :param station_group: station type to select, or a falsy value for all stations
        :param test_type: 'S' station tests, 'R' RTSM observations, 'B' both
        :return: list of dicts with keys station_name, n_errors, time
        """
        errors_per_station_in_bin = dict()

        if test_type in ('S', 'B'):
            component_errors = ComponentError.objects.filter(
                station_test__start_datetime__gte=from_date,
                station_test__start_datetime__lt=to_date)
            if station_group:
                component_errors = component_errors.filter(
                    station_test__station__type=station_group)
            # order_by() keeps any default model ordering out of the GROUP BY.
            station_test_results = component_errors. \
                values('station_test__station__name'). \
                annotate(n_errors=Count('station_test__station__name')).order_by()
            for result in station_test_results:
                self._add_to_bin(errors_per_station_in_bin, 'station_name',
                                 result['station_test__station__name'],
                                 result['n_errors'], central_time)

        if test_type in ('R', 'B'):
            rtsm_summary_errors = RTSMErrorSummary.objects.filter(
                observation__start_datetime__gte=from_date,
                observation__start_datetime__lt=to_date)
            if station_group:
                rtsm_summary_errors = rtsm_summary_errors.filter(
                    observation__station__type=station_group)
            rtsm_results = rtsm_summary_errors. \
                values('observation__station__name'). \
                annotate(n_errors=Count('observation__station__name')).order_by()
            for result in rtsm_results:
                self._add_to_bin(errors_per_station_in_bin, 'station_name',
                                 result['observation__station__name'],
                                 result['n_errors'], central_time)

        return list(errors_per_station_in_bin.values())

    def compute_errors_per_type(self, from_date, to_date, central_time, station_group, test_type):
        """
        Count the errors per error type for tests started in [from_date, to_date).

        :param central_time: time stamp attached to every returned entry
        :param station_group: station type to select, or a falsy value for all stations
        :param test_type: 'S' station tests, 'R' RTSM observations, 'B' both
        :return: list of dicts with keys error_type, n_errors, time
        """
        errors_per_error_type_in_bin = dict()

        if test_type in ('S', 'B'):
            component_errors = ComponentError.objects.filter(
                station_test__start_datetime__gte=from_date,
                station_test__start_datetime__lt=to_date)
            if station_group:
                component_errors = component_errors.filter(
                    station_test__station__type=station_group)
            # order_by() keeps any default model ordering out of the GROUP BY.
            station_test_results = component_errors. \
                values('type'). \
                annotate(n_errors=Count('type')).order_by()
            for result in station_test_results:
                self._add_to_bin(errors_per_error_type_in_bin, 'error_type',
                                 result['type'], result['n_errors'], central_time)

        if test_type in ('R', 'B'):
            rtsm_summary_errors = RTSMErrorSummary.objects.filter(
                observation__start_datetime__gte=from_date,
                observation__start_datetime__lt=to_date)
            if station_group:
                rtsm_summary_errors = rtsm_summary_errors.filter(
                    observation__station__type=station_group)
            rtsm_results = rtsm_summary_errors. \
                values('error_type'). \
                annotate(n_errors=Count('error_type')).order_by()
            for result in rtsm_results:
                self._add_to_bin(errors_per_error_type_in_bin, 'error_type',
                                 result['error_type'], result['n_errors'], central_time)

        return list(errors_per_error_type_in_bin.values())

    def get(self, request, format=None):
        """Return the binned error statistics described in the class docstring."""
        try:
            self.validate_query_parameters(request)
        except ValueError as e:
            return Response(status=status.HTTP_406_NOT_ACCEPTABLE,
                            data='Error wrong format: %s' % (e,))
        except KeyError as e:
            return Response(status=status.HTTP_406_NOT_ACCEPTABLE,
                            data='Please specify all the correct parameters: %s' % (e,))

        response_payload = OrderedDict()

        response_payload['start_date'] = self.from_date
        response_payload['end_date'] = self.to_date
        response_payload['averaging_interval'] = self.averaging_interval

        # 'A' (all stations) is expressed to the helpers as "no group filter".
        station_group = None if self.station_group == 'A' else self.station_group

        errors_per_station = []
        errors_per_type = []
        n_bins = int(ceil((self.to_date - self.from_date) / self.averaging_interval))

        for i in range(n_bins):
            bin_start = self.from_date + i * self.averaging_interval
            bin_end = self.from_date + (i + 1) * self.averaging_interval
            bin_center = self.from_date + (i + .5) * self.averaging_interval
            errors_per_station.append(
                self.compute_errors_per_station(from_date=bin_start,
                                                to_date=bin_end,
                                                central_time=bin_center,
                                                station_group=station_group,
                                                test_type=self.test_type))
            errors_per_type.append(
                self.compute_errors_per_type(from_date=bin_start,
                                             to_date=bin_end,
                                             central_time=bin_center,
                                             station_group=station_group,
                                             test_type=self.test_type))

        response_payload['errors_per_station'] = errors_per_station
        response_payload['errors_per_type'] = errors_per_type

        return Response(status=status.HTTP_200_OK, data=response_payload)


class ControllerAllComponentErrorTypes(APIView):
    """
    Returns every distinct error type recorded on component errors
    """

    def get(self, request, format=None):
        # One row per distinct 'type' value; unwrap each row dict to its value.
        distinct_rows = ComponentError.objects.values('type').distinct()
        error_types = list(map(lambda row: row['type'], distinct_rows))
        return Response(data=error_types, status=status.HTTP_200_OK)