Skip to content
Snippets Groups Projects
Code owners
Assign users and groups as approvers for specific file changes. Learn more.
solver.py 10.53 KiB
import logging

import lofar.calibration.common.datacontainers as datacontainers
import numpy
from scipy.constants import c as LIGHT_SPEED

# Module level logger, named after this module.
logger = logging.getLogger(__name__)

# 2 * pi / c: multiplied by a frequency and by a path length it gives the corresponding
# phase (this is how _compute_expected_phase_delay uses it).
# NOTE(review): LIGHT_SPEED comes from scipy.constants (m/s), so lengths are presumably
# expressed in metres and frequencies in Hz - confirm against the callers.
_FRQ_TO_PHASE = 2. * numpy.pi / LIGHT_SPEED


def _rotate_antenna_coordinates(dataset):
    """
    Express the antenna offsets in the reference frame given by the station rotation matrix.

    For every antenna the offset ``station position - antenna position`` is computed and
    the rotation matrix stored in the dataset is applied to it.
    :param dataset: input dataset
    :type dataset: datacontainers.HolographyDataset
    :return: the rotated offsets, one row per antenna (the result of numpy.dot, hence a
             numpy.matrix when the rotation matrix is a numpy.matrix)
    :rtype: numpy.matrix
    """
    reference_position = numpy.array(dataset.target_station_position)
    antenna_positions = numpy.array(dataset.antenna_field_position)

    # one offset vector per antenna, measured from the antenna towards the station position
    offsets = numpy.array([reference_position - position for position in antenna_positions])

    # the rotation acts on column vectors: transpose, rotate and transpose back
    rotated_columns = numpy.dot(dataset.rotation_matrix, offsets.T)
    return rotated_columns.T


def _compute_expected_phase_delay(l, m, x, y, frequency):
    """
    Phase delay expected for an antenna displaced by (x, y) when pointing at (l, m).

    The displacement is projected on the pointing cosines and the projection is scaled by
    ``_FRQ_TO_PHASE * frequency`` (i.e. 2 * pi * frequency / c).
    :param l: l pointing cosine
    :param m: m pointing cosine
    :param x: x displacement of the antenna
    :param y: y displacement of the antenna
    :param frequency: frequency under consideration
    :return: the phase delay (presumably in radians, given consistent units of the inputs)
    """
    projected_displacement = x * l + y * m
    return projected_displacement * _FRQ_TO_PHASE * frequency


def _compute_pointing_matrices_per_station_frequency_beam(dataset, datatable, central_beam,
                                                          frequency):
    """
    Build the pointing matrix of a station at a given frequency.

    The matrix has one row per beam and one column per antenna. Each element is the phasor
    ``exp(-j * phase) / n_antennas`` where the phase is the delay expected at that antenna
    for the pointing offset between the central beam and the beam of that row.
    The beams are looked up in the datatable through their index converted to string
    ('0', '1', ...) and the pointing is read from ``['mean']['l'][0]`` and
    ``['mean']['m'][0]``.
    :param dataset: datatable's dataset
    :type dataset: datacontainers.HolographyDataset
    :param datatable: input data
    :type datatable: dict(dict(numpy.ndarray))
    :param central_beam: central beam
    :type central_beam: str
    :param frequency: frequency at which to compute the pointing matrix
    :type frequency: float
    :return: the pointing matrix
    :rtype: numpy.matrix
    """
    antenna_coordinates = _rotate_antenna_coordinates(dataset)
    antenna_count = len(dataset.antenna_field_position)
    normalization = 1. / float(antenna_count)
    beam_count = len(datatable)

    pointing_matrix = numpy.matrix(numpy.zeros((beam_count, antenna_count), dtype=complex))

    for beam_index in range(beam_count):
        # the pointing offset only depends on the beam, not on the antenna
        central_mean = datatable[central_beam]['mean']
        beam_mean = datatable[str(beam_index)]['mean']
        delta_l = central_mean['l'][0] - beam_mean['l'][0]
        delta_m = central_mean['m'][0] - beam_mean['m'][0]

        for antenna_index in range(antenna_count):
            # only the first two coordinates of the rotated offset are used
            x, y = antenna_coordinates[antenna_index, 0:2]
            phase = _compute_expected_phase_delay(delta_l, delta_m, x, y, frequency)
            pointing_matrix[beam_index, antenna_index] = \
                numpy.exp(-1.j * phase) * normalization

    return pointing_matrix


def _convert_pointing_matrix_to_real(matrix):
    """
    Convert a complex matrix into the equivalent real one.

    Every complex element ``re + j * im`` is expanded into the 2x2 real block
    ``[[re, -im], [im, re]]``. In this way multiplying the real matrix by a vector with
    interleaved real and imaginary parts ``[re(g1), im(g1), re(g2), ...]`` gives the
    interleaved real and imaginary parts of the complex product.
    Ex.
    [re1 + j * im1  , re2 + j * im2 ;        [ re1, -im1,  re2, -im2]
                                        ~>    [ im1,  re1,  im2,  re2]
     re3 + j * im3  , re4 + j * im4 ]        [ re3, -im3,  re4, -im4]
                                             [ im3,  re3,  im4,  re4]

    :param matrix: input matrix complex valued
    :type matrix: numpy.matrix
    :return: output matrix real valued, with twice the rows and twice the columns
    :rtype: numpy.matrix
    """
    matrix_real = numpy.matrix(numpy.zeros((matrix.shape[0] * 2, matrix.shape[1] * 2)))
    # even rows produce the real part of the product: re * re(g) - im * im(g)
    matrix_real[0::2, 0::2] = matrix.real
    matrix_real[0::2, 1::2] = -matrix.imag
    # odd rows produce the imaginary part of the product: im * re(g) + re * im(g)
    matrix_real[1::2, 0::2] = matrix.imag
    matrix_real[1::2, 1::2] = matrix.real
    return matrix_real


def _convert_visibilities_to_real(visibilities):
    """
    Unfold the complex visibilities into a real valued column with interleaved real and
    imaginary parts.
    Ex.
    [re1 + j * im1  , re2 + j * im2]   ~>    [ re1, im1,  re2, im2]

    The output is a column matrix with twice as many rows as the input.
    :param visibilities: array with the visibilities
    :type visibilities: numpy.matrix
    :return: the real valued matrix equivalent to the input one
    :rtype: numpy.matrix
    """
    row_count = visibilities.shape[0] * 2
    unfolded = numpy.matrix(numpy.zeros((row_count, 1)))
    # even rows hold the real parts, odd rows the imaginary parts
    unfolded[0::2] = visibilities.real
    unfolded[1::2] = visibilities.imag

    return unfolded


def __convert_real_gains_to_complex(result):
    """
    Fold the solution of the real valued system back into complex gains.

    The gains are expected with interleaved real and imaginary parts and are recombined
    into complex values. The relative errors of the real and imaginary parts are summed
    in quadrature. A flagged result is returned untouched. The input dict is modified in
    place.
    :param result: gain computed with the complex to real transformation
    :return: the same dict, holding the complex array with the gains
    """
    if result['flag']:
        # nothing has been solved: leave the placeholders as they are
        return result

    interleaved_gains = result['gains']
    result['gains'] = interleaved_gains[0::2] + 1.j * interleaved_gains[1::2]

    interleaved_error = result['relative_error']
    error_on_real = interleaved_error[0::2].A1
    error_on_imaginary = interleaved_error[1::2].A1
    result["relative_error"] = numpy.sqrt(error_on_real ** 2. + error_on_imaginary ** 2.)
    return result


def _solve_gains_complex(visibilities, matrix, **kwargs):
    """
    Solve the complex system M * G = V through its equivalent real valued system.

    Both the pointing matrix and the visibilities are converted to their real valued
    counterparts, the real system is solved with _solve_gains and the solution is
    converted back to complex gains.
    :param visibilities: the complex visibilities
    :param matrix: the complex pointing matrix
    :param kwargs: additional parameters handed over to _solve_gains
    :return: the result of _solve_gains with the gains converted back to complex
    """
    real_system = _convert_pointing_matrix_to_real(matrix)
    real_visibilities = _convert_visibilities_to_real(visibilities)

    real_solution = _solve_gains(real_visibilities, real_system, **kwargs)
    return __convert_real_gains_to_complex(real_solution)


def _invert_matrix_lstsqr(matrix, visibilities, rcond=None):
    """
    Solve M * G = V for G in the least squares sense using numpy.linalg.lstsq.

    :param matrix: pointing matrix
    :param visibilities: visibilities for each pointing
    :param rcond: cut-off for small singular values, handed over to numpy.linalg.lstsq
    :return: the gains and the residuals as returned by numpy.linalg.lstsq
    :raises numpy.linalg.LinAlgError: also when numpy.linalg.lstsq raises a ValueError
    """
    try:
        solution = numpy.linalg.lstsq(matrix, visibilities, rcond=rcond)
    except ValueError as error:
        # callers only deal with LinAlgError: report a ValueError in the same way
        raise numpy.linalg.LinAlgError(error)
    # the rank and the singular values are not of interest here
    return solution[0], solution[1]


def _invert_matrix_direct(matrix, visibilities, rcond=None):
    """
    Solve M * G = V for G through the normal equations.

    The gains are computed as ``(M^H * M)^-1 * M^H * V`` where ``M^H`` is the conjugate
    transpose of the pointing matrix; for a matrix with full column rank this is the same
    as ``matrix.I * visibilities``. The conjugate transpose is needed for complex valued
    matrices: with the plain transpose the result is not the least squares solution and
    ``M^T * M`` can be singular even if the system is solvable (e.g. M = [1, j]^T).
    For real valued matrices the conjugate transpose and the transpose coincide.
    :param matrix: pointing matrix
    :type matrix: numpy.matrix
    :param visibilities: visibilities for each pointing
    :type visibilities: numpy.matrix
    :param rcond: not used, accepted to share the signature of the other solution methods
    :return: the gains and the residual, which is always 0 for this method
    :raises numpy.linalg.LinAlgError: if the normal matrix cannot be inverted or if the
                                      computation raises a ValueError
    """
    try:
        matrix_h = matrix.H
        gains = (matrix_h * matrix).I * matrix_h * visibilities
        # this method does not provide an estimate of the residual
        residual = 0

    except ValueError as e:
        raise numpy.linalg.LinAlgError(e)
    return gains, residual


def _invert_matrix_mcmc(matrix, visibilities, **kwargs):
    """
    Placeholder for a solution of M * G = V based on a Markov chain Monte Carlo sampling.

    The method is not available yet. The approach that was sketched samples the gains
    together with the noise sigmas, using as log probability
    ``-abs(matrix * gains - visibilities) ** 2 / sigmas ** 2``.
    :param matrix: pointing matrix
    :param visibilities: visibilities for each pointing
    :param kwargs: additional parameters for the solution method (currently ignored)
    :raises NotImplementedError: always
    """
    raise NotImplementedError('the MCMC solution method is not implemented yet')


def _invert_matrix(matrix, visibilities, type='LSTSQR', **kwargs):
    """
    Invert the pointing to find the gains from the visibilities.
    It is possible to specify the type of solution method applied through the parameter type
    :param matrix: pointing matrix
    :param visibilities: visibilities for each pointing
    :param type: type of methodology used to invert the matrix, one of 'LSTSQR', 'MCMC'
                 or 'DIRECT'
    :param kwargs: additional parameters for the solution method
    :return: the gains and the residual as returned by the selected solution method
    :raises ValueError: if the solution method is not one of the supported ones
    """
    SOLUTION_METHODS = ['LSTSQR', 'MCMC', 'DIRECT']
    if type not in SOLUTION_METHODS:
        raise ValueError('wrong type of solution method specified. Alternatives are: %s'
                         % SOLUTION_METHODS)

    # the method name has to be compared by value: an identity check ("is") only works
    # by accident for interned strings and lets an equal string built at runtime
    # (e.g. read from a file or from the command line) fall through returning None
    if type == 'LSTSQR':
        return _invert_matrix_lstsqr(matrix, visibilities, **kwargs)
    elif type == 'MCMC':
        return _invert_matrix_mcmc(matrix, visibilities, **kwargs)
    else:
        # the validation above guarantees that only 'DIRECT' is left
        return _invert_matrix_direct(matrix, visibilities, **kwargs)


def _solve_gains(visibilities, matrix, **kwargs):
    """
    Solve the equation M * G = V for G
    where M is the pointing matrix
          G are the gains per antenna
          V are the visibilities

    Together with the gains the relative error ``abs(M * G - V) / abs(V)`` is computed.
    If the inversion fails with a LinAlgError a warning is logged and a flagged result
    holding NaN is returned instead.
    :param visibilities: the visibility computed for each pointing
    :type visibilities: numpy.matrix
    :param matrix: the pointing matrix containing the phase delay for each pointing and antenna
    :type matrix: numpy.matrix
    :param kwargs: additional parameters handed over to _invert_matrix
    :return: a dict with the keys gains, residual, relative_error and flag
    """
    try:
        gains, residual = _invert_matrix(matrix, visibilities, **kwargs)
        model_mismatch = abs(matrix * gains - visibilities)
        relative_error = model_mismatch / abs(visibilities)
    except numpy.linalg.LinAlgError as error:
        logger.warning('error solving for the gains: %s', error)
        # flagged placeholder: no solution is available
        return {'gains': numpy.array(numpy.nan),
                'residual': numpy.array(numpy.nan),
                'relative_error': numpy.array(numpy.nan),
                'flag': numpy.array(True)}

    return {'gains': gains,
            'residual': residual,
            'relative_error': relative_error,
            'flag': numpy.array(False)}


def _solve_gains_per_frequency(dataset, datatable, frequency, direct_complex=True, **kwargs):
    """
    Solve the equation M * G = V for G at a given frequency, once per polarization.

    The central beam of the frequency is read from ``dataset.central_beamlets`` and used to
    compute the pointing matrix. For each of the polarizations XX, XY, YX and YY the
    visibilities are collected from the 'mean' entry of every beam and the gains are solved.
    :param dataset: datatable's dataset
    :type dataset: datacontainers.HolographyDataset
    :param datatable: the data of the beams, indexed by the beam number as a string
    :param frequency: frequency at which to solve for the gains
    :type frequency: float
    :param direct_complex: if True the complex system is solved directly, otherwise it is
                           solved through the equivalent real valued system
    :param kwargs: additional parameters handed over to the solver
    :return: a dict with the result of the solver for each polarization
    """
    central_beam = dataset.central_beamlets[str(frequency)]
    pointing_matrix = _compute_pointing_matrices_per_station_frequency_beam(
        dataset, datatable, central_beam, frequency)
    beam_keys = [str(beam_index) for beam_index in range(len(datatable))]

    # only the literal True selects the direct solver, any other value selects the
    # real valued formulation
    solver = _solve_gains if direct_complex is True else _solve_gains_complex

    # NOTE: the flags of the single beams are not taken into account here
    gains_per_polarization = dict()
    for polarization in ('XX', 'XY', 'YX', 'YY'):
        visibilities = numpy.matrix(
            [datatable[beam_key]['mean'][polarization] for beam_key in beam_keys])
        gains_per_polarization[polarization] = solver(visibilities, pointing_matrix,
                                                      **kwargs)

    return gains_per_polarization


def solve_gains_per_datatable(dataset, datatable, **kwargs):
    """
    Solve for the gains of every station and frequency contained in the given datatable.

    The frequencies are used as they are to look up the data, converted to float to solve
    for the gains and converted to string to store the outcome.
    :param dataset: dataset containing the specified datatable
    :type dataset: datacontainers.HolographyDataset
    :param datatable: table containing the data
    :type datatable: dict(dict(numpy.ndarray))
    :param kwargs: additional parameters handed over to the solver
    :return: a nested dict containing the gain matrix per station, frequency, polarization
    :rtype: dict(dict(dict(numpy.ndarray)))
    """
    gains = dict()

    for station in datatable:
        station_data = datatable[station]
        gains[station] = {
            str(frequency): _solve_gains_per_frequency(dataset, station_data[frequency],
                                                       float(frequency), **kwargs)
            for frequency in station_data
        }

    return gains