import logging

import numpy

# Module-level logger named after this module, following the standard
# ``logging.getLogger(__name__)`` convention.
logger = logging.getLogger(__name__)


def extract_crosscorrelation_matrices_from_HDS_datatable(datatable):
    """
    Assemble 2x2 complex cross-correlation matrices from a datatable.

    Each row of ``datatable`` contributes one matrix, laid out as::

        [[XX, XY],
         [YX, YY]]

    The matrices are stacked along the LAST axis of the result, so
    ``result[:, :, k]`` is the matrix built from row ``k``.

    :param datatable: table with one row per matrix and columns ``'XX'``,
        ``'XY'``, ``'YX'`` and ``'YY'``. It must expose ``shape`` and support
        indexing by column name (presumably a numpy structured array or a
        pandas DataFrame -- confirm against callers).
    :return: complex128 array of shape ``(2, 2, N)``, where ``N`` is
        ``datatable.shape[0]``
    :rtype: numpy.ndarray
    """
    n_rows = datatable.shape[0]
    # ``numpy.complex`` was merely an alias of the builtin ``complex``; it was
    # deprecated in NumPy 1.20 and removed in 1.24. ``numpy.complex128`` is
    # the exact same dtype and works on every NumPy version.
    matrices = numpy.zeros(shape=(2, 2, n_rows), dtype=numpy.complex128)
    matrices[0, 0, :] = datatable['XX']
    matrices[0, 1, :] = datatable['XY']
    matrices[1, 0, :] = datatable['YX']
    matrices[1, 1, :] = datatable['YY']

    return matrices


def invert_crosscorrelation_matrices(cross_correlation_matrices):
    """
    Invert a stack of square matrices that is indexed along the trailing axes.

    The first two axes index the matrix elements and the remaining axes
    (for a ``(2, 2, N)`` array, the last axis) index the individual matrices.
    This is the layout returned by
    ``extract_crosscorrelation_matrices_from_HDS_datatable``.

    ``numpy.linalg.inv`` expects the matrix axes to be the LAST two axes.
    Passing a ``(2, 2, N)`` array to it directly would therefore try to invert
    2xN matrices. The matrix axes are moved to the end for the inversion and
    moved back afterwards, so the result has the same layout as the input.

    :param cross_correlation_matrices: matrices to invert, shape ``(M, M, ...)``
        with at least three dimensions
    :type cross_correlation_matrices: numpy.ndarray
    :return: the inverted matrices, same shape and layout as the input
    :rtype: numpy.ndarray
    :raises ValueError: if the input has fewer than three dimensions
    :raises numpy.linalg.LinAlgError: if a matrix is singular or the first
        two axes differ in length
    """
    matrices = numpy.asarray(cross_correlation_matrices)
    # Explicit check instead of ``assert`` so it is not stripped under ``-O``.
    if matrices.ndim < 3:
        raise ValueError(
            "expected a stack of matrices with at least 3 dimensions, "
            "got shape %r" % (matrices.shape,)
        )

    # (M, M, ...) -> (..., M, M): the layout numpy.linalg.inv operates on.
    stacked = numpy.moveaxis(matrices, (0, 1), (-2, -1))
    inverted = numpy.linalg.inv(stacked)
    # (..., M, M) -> (M, M, ...): restore the caller's layout.
    return numpy.moveaxis(inverted, (-2, -1), (0, 1))