# Cobbled together from lofarimaging and lofty
from collections import defaultdict
import numpy as np

from tango import DeviceProxy, DevSource
from lofarantpos.db import LofarAntennaDatabase
import numexpr as ne
import matplotlib.pyplot as plt

# Configurations for HBA observations with a single dipole activated per tile.
# Each entry is the dipole number (0-15) inside one tile, listed tile by tile;
# get_station_pqr turns entry k into the flat dipole index k * 16 + entry.
# The "201512" suffix presumably dates the configuration (Dec 2015) — TODO confirm.
# 96 entries; selected for station type 'intl'.
GENERIC_INT_201512 = [0, 5, 3, 1, 8, 3, 12, 15, 10, 13, 11, 5, 12, 12, 5, 2, 10, 8, 0, 3, 5, 1, 4, 0, 11, 6, 2, 4, 9,
                      14, 15, 3, 7, 5, 13, 15, 5, 6, 5, 12, 15, 7, 1, 1, 14, 9, 4, 9, 3, 9, 3, 13, 7, 14, 7, 14, 2, 8,
                      8, 0, 1, 4, 2, 2, 12, 15, 5, 7, 6, 10, 12, 3, 3, 12, 7, 4, 6, 0, 5, 9, 1, 10, 10, 11, 5, 11, 7, 9,
                      7, 6, 4, 4, 15, 4, 1, 15]
# 48 entries; selected for station type 'core'. The second 24 entries repeat the first 24.
GENERIC_CORE_201512 = [0, 10, 4, 3, 14, 0, 5, 5, 3, 13, 10, 3, 12, 2, 7, 15, 6, 14, 7, 5, 7, 9, 0, 15, 0, 10, 4, 3, 14,
                       0, 5, 5, 3, 13, 10, 3, 12, 2, 7, 15, 6, 14, 7, 5, 7, 9, 0, 15]
# 48 entries; selected for station type 'remote'.
GENERIC_REMOTE_201512 = [0, 13, 12, 4, 11, 11, 7, 8, 2, 7, 11, 2, 10, 2, 6, 3, 8, 3, 1, 7, 1, 15, 13, 1, 11, 1, 12, 7,
                         10, 15, 8, 2, 12, 13, 9, 13, 4, 5, 5, 12, 5, 5, 9, 11, 15, 12, 2, 15]
# Speed of light in vacuum (m/s); nearfield_imager divides it by a frequency to get a wavelength.
SPEED_OF_LIGHT = 299792458.0

def get_station_type(station_name: str) -> str:
    """
    Classify a LOFAR station by the prefix of its name.

    Names starting with "C" are core stations. Names starting with "R", plus
    anything starting with "PL611", are remote stations. All other names are
    treated as international.

    Args:
        station_name: Station name, with or without antenna-field suffix,
            e.g. "DE603LBA" or just "DE603"

    Returns:
        str: station type, one of 'intl', 'core' or 'remote'

    Raises:
        IndexError: if ``station_name`` is empty.

    Example:
        >>> get_station_type("DE603")
        'intl'
        >>> get_station_type("CS001HBA0")
        'core'
    """
    first_letter = station_name[0]
    if first_letter == "C":
        return "core"
    if first_letter == "R" or station_name.startswith("PL611"):
        return "remote"
    return "intl"


def get_station_pqr(full_station_name: str, activation_pattern: bool | str = False,
                    db: "LofarAntennaDatabase | None" = None):
    """
    Get PQR coordinates for the relevant subset of antennas in a station.

    Args:
        full_station_name: Station name including the antenna field,
            e.g. 'DE603LBA', 'CS001HBA0' or 'CS001HBA1'.
        activation_pattern: Falsy (default) to use ``db.antenna_pqr`` directly.
            For HBA fields, one of "HBA_SINGLE", "HBA0_SINGLE" or
            "HBA1_SINGLE" selects one dipole per tile, using the
            GENERIC_*_201512 table for the station type. It is ignored for
            names containing 'LBA'.
        db: Optional LofarAntennaDatabase instance to reuse. A new one is
            created when omitted.

    Returns:
        np.ndarray: float32 array with one PQR row per selected antenna/dipole.

    Raises:
        RuntimeError: if a truthy activation pattern is given and the name
            contains neither 'LBA' nor 'HBA', or the HBA pattern is not one of
            the supported values.

    Example:
        >>> pqr = get_station_pqr("CS001HBA1", "HBA1_SINGLE")
    """
    if db is None:
        db = LofarAntennaDatabase()
    station_type = get_station_type(full_station_name)

    if 'LBA' in full_station_name or not activation_pattern:
        # No dipole selection: use the positions the database reports for this field.
        station_pqr = db.antenna_pqr(full_station_name)
    elif 'HBA' in full_station_name:
        dipole_per_tile = {
            'intl': GENERIC_INT_201512, 'remote': GENERIC_REMOTE_201512, 'core': GENERIC_CORE_201512
        }[station_type]
        # Turn "dipole d of tile k" into a flat dipole index (16 dipoles per tile).
        selected_dipoles = np.asarray(dipole_per_tile) + np.arange(len(dipole_per_tile)) * 16
        dipole_pqr = db.hba_dipole_pqr(full_station_name)

        if activation_pattern == "HBA_SINGLE":
            station_pqr = dipole_pqr[selected_dipoles]
        elif activation_pattern == "HBA0_SINGLE":
            # First 24 tiles of the selection table.
            station_pqr = dipole_pqr[selected_dipoles[:24]]
        elif activation_pattern == "HBA1_SINGLE":
            # Remaining tiles, shifted back by 24 tiles of 16 dipoles. This presumably
            # assumes the HBA1 field's dipole list starts at its own first tile — confirm.
            station_pqr = dipole_pqr[selected_dipoles[24:] - 16 * 24]
        else:
            raise RuntimeError(
                f"Unsupported activation pattern {activation_pattern!r} for {full_station_name}; "
                "expected 'HBA_SINGLE', 'HBA0_SINGLE' or 'HBA1_SINGLE'"
            )
    else:
        raise RuntimeError("Station name did not contain LBA or HBA, could not load antenna positions")

    return station_pqr.astype('float32')

def nearfield_imager(visibilities, baseline_indices, freqs, npix_p, npix_q, extent, station_pqr, height=1.5,
                     max_memory_mb=200):
    """
    Nearfield imager

    Images a horizontal plane at ``height``. For every pixel, each visibility is
    phased with the exact difference in distance from the pixel to the two
    antennas of its baseline. The result is averaged over baselines and
    frequencies.

    Args:
        visibilities: Numpy array with visibilities, shape [num_visibilities x num_frequencies]
        baseline_indices: Antenna index pairs, array-like of shape [num_visibilities x 2]
            (e.g. a list of (i, j) tuples). Row k holds the two rows of ``station_pqr``
            that form visibility k.
        freqs: List of frequencies (Hz)
        npix_p: Number of pixels in p-direction
        npix_q: Number of pixels in q-direction
        extent: Extent (in m) that the image should span, (p_min, p_max, q_min, q_max)
        station_pqr: PQR coordinates of the antennas, shape [num_antennas x 3]
        height: Height of image in metre
        max_memory_mb: Maximum amount of memory to use for the biggest array. Higher may improve performance.

    Returns:
        np.array(complex): Complex valued array of shape [npix_q, npix_p]
    """
    # Accept the documented list-of-tuples form as well as an ndarray.
    baseline_indices = np.asarray(baseline_indices)
    num_vis = baseline_indices.shape[0]

    z = height
    x = np.linspace(extent[0], extent[1], npix_p)
    y = np.linspace(extent[2], extent[3], npix_q)

    posx, posy = np.meshgrid(x, y)
    posxyz = np.transpose(np.array([posx, posy, z * np.ones_like(posx)]), [1, 2, 0])

    # distances[a, q, p]: distance from antenna a to pixel (q, p)
    diff_vectors = (station_pqr[:, None, None, :] - posxyz[None, :, :, :])
    distances = np.linalg.norm(diff_vectors, axis=3)

    # Visibilities per chunk so that bl_diff (float64) stays within the budget.
    # int() allows a float budget. max(1, ...) avoids a zero range step when a single
    # image plane already exceeds the budget. Never allocate more rows than visibilities.
    vis_chunksize = max(1, int(max_memory_mb * 1024 * 1024 // (8 * npix_p * npix_q)))
    vis_chunksize = min(vis_chunksize, max(num_vis, 1))

    bl_diff = np.zeros((vis_chunksize, npix_q, npix_p), dtype=np.float64)
    img = np.zeros((npix_q, npix_p), dtype=np.complex128)
    j2pi = 1j * 2 * np.pi
    for vis_chunkstart in range(0, num_vis, vis_chunksize):
        vis_chunkend = min(vis_chunkstart + vis_chunksize, num_vis)
        # For the last chunk, bl_diff_chunk is a bit smaller than bl_diff
        bl_diff_chunk = bl_diff[:vis_chunkend - vis_chunkstart]
        # Path-length difference between the two antennas of each baseline, per pixel
        np.subtract(distances[baseline_indices[vis_chunkstart:vis_chunkend, 0]],
                    distances[baseline_indices[vis_chunkstart:vis_chunkend, 1]], out=bl_diff_chunk)

        for ifreq, freq in enumerate(freqs):
            v = visibilities[vis_chunkstart:vis_chunkend, ifreq][:, None, None]
            lamb = SPEED_OF_LIGHT / freq
            # numexpr picks up v, j2pi, bl_diff_chunk and lamb from this frame by name.
            img += np.sum(ne.evaluate("v * exp(j2pi * bl_diff_chunk / lamb)"), axis=0)
    img /= len(freqs) * num_vis

    return img

def get_xst_statistics(antennafield: str):
    """
    Read the current cross-correlation (XST) matrix of one antenna field.

    The matrix comes from ``stat/xst/<antennafield>``, read with DevSource.DEV
    so values come from the device rather than the polling cache. The
    frequency of the configured XST subband is looked up in row 0 of
    ``subband_frequency_r`` on ``stat/sdp/<antennafield>``.

    Args:
        antennafield: Antenna field name used in the Tango device names, e.g. "LBA".

    Returns:
        tuple: (complex64 XST matrix built from xst_real_r + 1j * xst_imag_r,
        frequency of the XST subband).
    """
    xst_proxy = DeviceProxy(f"stat/xst/{antennafield}")
    xst_proxy.set_source(DevSource.DEV)

    xst_subband = xst_proxy.xst_subband_r
    real_part = xst_proxy.xst_real_r.astype(np.complex64)
    imag_part = xst_proxy.xst_imag_r.astype(np.complex64)
    xst_matrix = real_part + 1.0j * imag_part

    sdp_proxy = DeviceProxy(f"stat/sdp/{antennafield}")
    xst_frequency = sdp_proxy.subband_frequency_r[0][xst_subband]

    return xst_matrix, xst_frequency


def wrap_nearfield_imager(antennafield: str, activation_pattern: bool | str = False,
                          heights_m: float | list[float] = 1.5, npix_p: int = 96, npix_q: int = 96,
                          extent_m: tuple[float, float, float, float] = (-100.0, 100.0, -100.0, 100.0),
                          max_memory_mb: float = 200):
    """
    Wrap nearfield imager for single/multi-subband captures.

    Reads the current XST matrix of ``antennafield`` and looks up the antenna
    positions for "<station name><antennafield>". It runs nearfield_imager once
    per height on the lower triangle (diagonal included) of the XST matrix, then
    shows the magnitude of every image with matplotlib.

    Args:
        antennafield: Antenna field name, e.g. "LBA", "HBA0" or "HBA1".
        activation_pattern: Forwarded to get_station_pqr (False, or e.g. "HBA1_SINGLE").
        heights_m: Image-plane height(s) in metres; a number or a list/tuple of numbers.
        npix_p: Number of image pixels in p-direction.
        npix_q: Number of image pixels in q-direction.
        extent_m: (p_min, p_max, q_min, q_max) extent of the image in metres.
        max_memory_mb: Memory budget forwarded to nearfield_imager.

    Returns:
        defaultdict: maps each height to the complex image returned by nearfield_imager.

    Raises:
        TypeError: if heights_m is neither a number nor a list/tuple of numbers.
    """
    if isinstance(heights_m, (int, float)) and not isinstance(heights_m, bool):
        heights_m = [heights_m]
    elif isinstance(heights_m, (list, tuple)):
        heights_m = list(heights_m)
    else:
        raise TypeError(f"Expected a float or list of floats for heights_m, got {heights_m=}")

    results = defaultdict(dict)
    dataset, frequency = get_xst_statistics(antennafield)

    # nearfield_imager indexes baseline_indices as [num_visibilities x 2] and the
    # visibilities as [num_visibilities x num_frequencies]. Build both from the
    # (row, col) pairs of the lower triangle.
    rows, cols = np.tril_indices_from(dataset)
    baseline_indices = np.stack((rows, cols), axis=1)
    dataset_selected = dataset[rows, cols][:, np.newaxis]
    # NOTE(review): the XST matrix indices are used directly as rows of station_pqr,
    # which assumes one matrix row/column per selected antenna — confirm.

    manager_device = DeviceProxy("stat/stationmanager/1")
    station_pqr = get_station_pqr(f"{manager_device.station_name_r}{antennafield}".upper(),
                                  activation_pattern=activation_pattern)

    for height in heights_m:
        results[height] = nearfield_imager(dataset_selected, baseline_indices, freqs=[frequency],
                                           npix_p=npix_p, npix_q=npix_q, extent=extent_m,
                                           station_pqr=station_pqr, height=height,
                                           max_memory_mb=max_memory_mb)

    for key, result in results.items():
        plt.figure()
        plt.title(f"{antennafield} @ {str(key)}")
        plt.imshow(np.abs(result))

    plt.show()
    return results


if __name__ == '__main__':
    # Image each antenna field at four heights over a 100 m x 100 m area. The HBA
    # fields are imaged both with the default positions (activation_pattern=False)
    # and with one dipole activated per tile.
    demo_extent_m = [-50, 50, -50, 50]
    demo_heights_m = [0.5, 1.0, 1.5, 10.0]
    demo_runs = (
        ("LBA", False),
        ("HBA1", False),
        ("HBA1", "HBA1_SINGLE"),
        ("HBA0", False),
        ("HBA0", "HBA0_SINGLE"),
    )
    for field_name, pattern in demo_runs:
        wrap_nearfield_imager(field_name, activation_pattern=pattern,
                              extent_m=demo_extent_m, heights_m=demo_heights_m)