import logging

import lofar.calibration.common.datacontainers as datacontainers
import numpy
from scipy.constants import c as LIGHT_SPEED

# Module-level logger; _solve_gains uses it to report failed gain solutions.
logger = logging.getLogger(__name__)

# 2*pi / c. Multiplied by a path length and a frequency it yields a phase in
# radians (scipy.constants.c is in m/s, so this assumes lengths in metres and
# frequencies in Hz).
_FRQ_TO_PHASE = 2. * numpy.pi / LIGHT_SPEED


def _rotate_antenna_coordinates(dataset):
    """
    Rotate the antenna coordinates to be aligned to the station orientation
    :param dataset: input dataset
    :type dataset: datacontainers.HolographyDataset
    :return: the rotated coordinate system
    :rtype: numpy.matrix
    """
    station_position = numpy.array(dataset.target_station_position)
    antenna_field_position = numpy.array(dataset.antenna_field_position)

    antenna_position_offset = numpy.array([station_position - antenna_position for
                                           antenna_position in antenna_field_position])

    station_rotation_matrix = dataset.rotation_matrix
    return numpy.dot(station_rotation_matrix, antenna_position_offset.T).T


def _compute_expected_phase_delay(l, m, x, y, frequency):
    """
    Phase delay of an antenna displaced by (x, y) for the pointing cosines (l, m).

    The geometric path difference ``x * l + y * m`` is scaled by ``_FRQ_TO_PHASE``
    (2*pi/c) and by the frequency. Only arithmetic operators are used, so numpy
    arrays are combined element-wise (with broadcasting).

    :param l: l pointing cosine
    :param m: m pointing cosine
    :param x: x displacement of the antenna
    :param y: y displacement of the antenna
    :param frequency: frequency under consideration
    :return: the phase delay (radians, assuming x, y in metres and frequency in Hz)
    """
    path_difference = x * l + y * m
    return path_difference * _FRQ_TO_PHASE * frequency


def _compute_pointing_matrices_per_station_frequency_beam(dataset, datatable, central_beam,
                                                          frequency):
    """
    Compute the pointing matrix of a given station, at a given frequency and for a specific beam.

    Element (i, j) is ``exp(-1j * phase_ij) / n_antennas``. Here ``phase_ij`` is the
    expected phase delay of antenna j for the (l, m) offset between the central beam
    and beam i.

    :param dataset: datatable's dataset (provides the antenna positions and rotation)
    :type dataset: datacontainers.HolographyDataset
    :param datatable: input data, keyed by beam index as a string ('0' .. str(n_beams - 1)).
        The first element of each beam's ['mean']['l'] and ['mean']['m'] is used.
    :type datatable: dict(dict(numpy.ndarray))
    :param central_beam: key of the central beam in datatable
    :type central_beam: str
    :param frequency: frequency at which to compute the pointing matrix
    :type frequency: float
    :return: the pointing matrix, shape (n_beams, n_antennas)
    :rtype: numpy.matrix
    """
    # asarray: if the rotation produced a numpy.matrix, row slices stay 2-D and
    # cannot be unpacked/column-sliced like plain arrays
    rotated_coordinates = numpy.asarray(_rotate_antenna_coordinates(dataset))
    n_antennas = len(dataset.antenna_field_position)
    one_over_n_antennas = 1. / float(n_antennas)
    n_beams = len(datatable)

    central_l = datatable[central_beam]['mean']['l'][0]
    central_m = datatable[central_beam]['mean']['m'][0]

    # The direction-cosine offsets depend only on the beam: compute them once per
    # beam, as column vectors of shape (n_beams, 1)
    delta_l = numpy.array([central_l - datatable[str(i)]['mean']['l'][0]
                           for i in range(n_beams)]).reshape(-1, 1)
    delta_m = numpy.array([central_m - datatable[str(i)]['mean']['m'][0]
                           for i in range(n_beams)]).reshape(-1, 1)

    # Antenna x / y displacements as row vectors of shape (1, n_antennas)
    x = rotated_coordinates[:, 0].reshape(1, -1)
    y = rotated_coordinates[:, 1].reshape(1, -1)

    # Broadcasting (n_beams, 1) against (1, n_antennas) yields the full phase grid
    phase = _compute_expected_phase_delay(delta_l, delta_m, x, y, frequency)
    return numpy.matrix(numpy.exp(-1.j * phase) * one_over_n_antennas, dtype=complex)


def _convert_pointing_matrix_to_real(matrix):
    """
    Convert a complex matrix in the equivalent real one.
    Ex.
    [re1 + j * im1  , re2 + j * im2 ;        [ re1, -im1,  re2, -im2]
                                        ~>    [ im1,  re1,  im2,  re2]
     re3 + j * im3  , re4 + j * im4 ]        [ re3, -im3,  re4, -im4]
                                             [ re3,  re3,  re$, -im4]

    :param matrix: input matrix complex valued
    :type matrix: numpy.matrix
    :return: output matrix real valued
    :rtype: numpy.matrix
    """
    matrix_real = numpy.matrix(numpy.zeros((matrix.shape[0] * 2, matrix.shape[1] * 2)))
    matrix_real[0::2, 0::2] = matrix.real
    matrix_real[0::2, 1::2] = -matrix.imag
    matrix_real[1::2, 0::2] = matrix.real
    matrix_real[1::2, 1::2] = matrix.imag
    return matrix_real


def _convert_visibilities_to_real(visibilities):
    """
    Convert a complex matrix in the equivalent real one.
    Ex.
    [re1 + j * im1  , re2 + j * im2]   ~>    [ re1, im1,  re2, im2]

    :param visibilities: array with the visibilities
    :type visibilities: numpy.matrix
    :return: the real valued matrix equivalent to the input one
    :rtype: numpy.matrix
    """
    visibilities_real = numpy.matrix(numpy.zeros(visibilities.shape[0] * 2)).T
    visibilities_real[0::2] = visibilities.real
    visibilities_real[1::2] = visibilities.imag

    return visibilities_real


def __convert_real_gains_to_complex(result):
    """
    Convert the real values of the gains back into the complex and computes
    the relative error
    :param result: gain computed with the complex to real transformation
    :return: the complex array with the gains
    """
    if not result['flag']:
        gains = result['gains']
        gains = gains[0::2] + 1.j * gains[1::2]
        result['gains'] = gains

        noise = result['relative_error']
        noise = numpy.sqrt(noise[0::2].A1 ** 2. + noise[1::2].A1 ** 2.)
        result["relative_error"] = noise
    return result


def _solve_gains_complex(visibilities, matrix, **kwargs):
    """
    Solve the complex system ``matrix * gains = visibilities`` through an equivalent real system.

    The pointing matrix and the visibilities are expanded into real form and solved
    with _solve_gains. When that solve is not flagged, the gains and relative error
    are folded back into complex form.

    :param visibilities: complex visibilities, one per pointing
    :type visibilities: numpy.matrix
    :param matrix: complex pointing matrix
    :type matrix: numpy.matrix
    :param kwargs: forwarded to _solve_gains (e.g. the solution method ``type``)
    :return: the result dict produced by _solve_gains, converted back to complex gains
    """
    real_system = _convert_pointing_matrix_to_real(matrix)
    real_rhs = _convert_visibilities_to_real(visibilities)

    real_solution = _solve_gains(real_rhs, real_system, **kwargs)
    return __convert_real_gains_to_complex(real_solution)


def _invert_matrix_lstsqr(matrix, visibilities, rcond=None):
    try:
        gains, residual, _, _ = numpy.linalg.lstsq(matrix, visibilities, rcond=rcond)
    except ValueError as e:
        raise numpy.linalg.LinAlgError(e)
    return gains, residual


def _invert_matrix_direct(matrix, visibilities, rcond=None):
    try:

        gains = (matrix.T * matrix).I * matrix.T * visibilities
        # IS the same as : gains = matrix.I * visibilities
        residual = 0

    except ValueError as e:
        raise numpy.linalg.LinAlgError(e)
    return gains, residual


def _invert_matrix_mcmc(matrix, visibilities, **kwargs):
    def lnprob(parameters):
        gains, sigmas = parameters
        return -abs(matrix * gains - visibilities) ** 2 / sigmas ** 2

    ndim, nwalkers = 3, 100
    # pos = [result["x"] + 1e-4 * np.random.randn(ndim) for i in range(nwalkers)]
    raise NotImplementedError()


def _invert_matrix(matrix, visibilities, type='LSTSQR', **kwargs):
    """
    Invert the pointing to find the gains from the visibilities.
    It is possible to specify the type of solution method applied through the parameter type
    :param matrix: pointing matrix
    :param visibilities: visibilities for each pointing
    :param type: type of methodology used to invert the matrix
    :param kwargs: additional parameters for the solution method
    :return:
    """
    SOLUTION_METHODS = ['LSTSQR', 'MCMC', 'DIRECT']
    if type not in SOLUTION_METHODS:
        raise ValueError('wrong type of solution method specified. Alternatives are: %s'
                         % SOLUTION_METHODS)

    if type is 'LSTSQR':
        return _invert_matrix_lstsqr(matrix, visibilities, **kwargs)
    elif type is 'MCMC':
        return _invert_matrix_mcmc(matrix, visibilities, **kwargs)
    elif type is 'DIRECT':
        return _invert_matrix_direct(matrix, visibilities, **kwargs)


def _solve_gains(visibilities, matrix, **kwargs):
    """
    SOLVE THE EQUATION M * G = V FOR G
    where M is the pointing matrix
          G are the gains per antenna
          V are the visibilities

    :param visibilities: the visibility computed for each pointing
    :type visibilities: numpy.matrix
    :param matrix: the pointing matrix containing the phase delay for each pointing and antenna
    :type matrix: numpy.matrix
    :param kwargs: forwarded to _invert_matrix (e.g. the solution method ``type``)
    :return: dict with keys 'gains', 'residual', 'relative_error' and 'flag'.
        On success 'flag' is numpy.array(False), and 'relative_error' is
        |M * G - V| / |V| element-wise. If the inversion raises
        numpy.linalg.LinAlgError, a warning is logged, every other entry is
        numpy.array(nan), and 'flag' is numpy.array(True).
    """
    try:
        gains, residual = _invert_matrix(matrix, visibilities, **kwargs)
    except numpy.linalg.LinAlgError as e:
        logger.warning('error solving for the gains: %s', e)
        return dict(gains=numpy.array(numpy.nan),
                    residual=numpy.array(numpy.nan),
                    relative_error=numpy.array(numpy.nan),
                    flag=numpy.array(True))

    relative_error = abs(matrix * gains - visibilities) / abs(visibilities)
    return dict(gains=gains, residual=residual, relative_error=relative_error,
                flag=numpy.array(False))


def _solve_gains_per_frequency(dataset, datatable, frequency, direct_complex=True, **kwargs):
    """
    SOLVE THE EQUATION M * G = V FOR G, for every polarization at one frequency.

    :param dataset: dataset providing the central beamlet per frequency and the antenna geometry
    :type dataset: datacontainers.HolographyDataset
    :param datatable: per-beam data for this frequency, keyed by beam index as a string
        ('0' .. str(n_beams - 1)). Each beam's ['mean'][polarization] is used as its visibility.
    :param frequency: frequency under consideration; ``str(frequency)`` is the key
        into ``dataset.central_beamlets``
    :type frequency: float
    :param direct_complex: if truthy, solve the complex system directly (_solve_gains);
        otherwise solve the equivalent real system (_solve_gains_complex)
    :param kwargs: forwarded to the solver (e.g. the solution method ``type``)
    :return: dict mapping polarization ('XX', 'XY', 'YX', 'YY') to the solver's result dict
    """
    central_beam = dataset.central_beamlets[str(frequency)]
    matrix = _compute_pointing_matrices_per_station_frequency_beam(dataset,
                                                                   datatable,
                                                                   central_beam,
                                                                   frequency)
    n_beams = len(datatable)

    # Truthiness rather than `is True`, so that e.g. numpy.bool_(True) or 1 also
    # select the direct complex solver
    solver = _solve_gains if direct_complex else _solve_gains_complex

    result = dict()
    for polarization in ['XX', 'XY', 'YX', 'YY']:
        visibilities = numpy.matrix(
            [datatable[str(i)]['mean'][polarization] for i in range(n_beams)])
        result[polarization] = solver(visibilities, matrix, **kwargs)

    return result


def solve_gains_per_datatable(dataset, datatable, **kwargs):
    """
    Solve for the gains of every station and frequency in the given datatable.

    :param dataset: dataset containing the specified datatable
    :type dataset: datacontainers.HolographyDataset
    :param datatable: table containing the data, indexed as datatable[station][frequency]
    :type datatable: dict(dict(numpy.ndarray))
    :param kwargs: forwarded to the per-frequency solver (e.g. ``direct_complex``, ``type``)
    :return: a nested dict containing the gain matrix per station, frequency (as str), polarization
    :rtype: dict(dict(dict(numpy.ndarray)))
    """
    return {
        station: {
            str(frequency): _solve_gains_per_frequency(dataset,
                                                       datatable[station][frequency],
                                                       float(frequency),
                                                       **kwargs)
            for frequency in datatable[station]
        }
        for station in datatable
    }