#!/usr/bin/env python
#
# Copyright (C) 2007
# ASTRON (Netherlands Institute for Radio Astronomy)
# P.O.Box 2, 7990 AA Dwingeloo, The Netherlands
#
# This file is part of the LOFAR software suite.
# The LOFAR software suite is free software: you can redistribute it and/or
# modify it under the terms of the GNU General Public License as published
# by the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# The LOFAR software suite is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License along
# with the LOFAR software suite. If not, see <http://www.gnu.org/licenses/>.
#
# $Id$

import sys
import lofar.parmdb as parmdb
import copy
import math
import numpy

import matplotlib
from matplotlib.backends.backend_qt4agg import FigureCanvasQTAgg as FigureCanvas
from matplotlib.backends.backend_qt4agg import NavigationToolbar2QTAgg as NavigationToolbar
from matplotlib.figure import Figure
from matplotlib.font_manager import FontProperties

from PyQt4.QtCore import *
from PyQt4.QtGui import *

__styles = ["%s%s" % (x, y) for y in ["-", ":"] for x in ["b", "g", "r", "c",
    "m", "y", "k"]]

def contains(container, item):
    try:
        return container.index(item) >= 0
    except ValueError:
        return False

def common_domain(parms):
    if len(parms) == 0:
        return None

    domain = [-1e30, 1e30, -1e30, 1e30]
    for parm in parms:
        tmp = parm.domain()
        domain = [max(domain[0], tmp[0]), min(domain[1], tmp[1]), max(domain[2], tmp[2]), min(domain, tmp[3])]

    if domain[0] >= domain[1] or domain[2] >= domain[3]:
        return None

    return domain

def parseFloat(text, lower, upper):
    value = float(text)
    if value < lower:
        value = lower
    elif value > upper:
        value = upper

    return value

def unwrap(phase, tol=0.25, delta_tol=0.25):
    """
    Unwrap phase by restricting phase[n] to fall within a range [-tol, tol]
    around phase[n - 1].

    If this is impossible, the closest phase (modulo 2*pi) is used and tol is
    increased by delta_tol (tol is capped at pi).
    """

    assert(tol < math.pi)

    # Allocate result.
    out = numpy.zeros(phase.shape)

    # Effective tolerance.
    eff_tol = tol

    ref = phase[0]
    for i in range(0, len(phase)):
        delta = math.fmod(phase[i] - ref, 2.0 * math.pi)

        if delta < -math.pi:
            delta += 2.0 * math.pi
        elif delta > math.pi:
            delta -= 2.0 * math.pi

        out[i] = ref + delta

        if abs(delta) <= eff_tol:
            # Update reference phase and reset effective tolerance.
            ref = out[i]
            eff_tol = tol
        elif eff_tol < math.pi:
            # Increase effective tolerance.
            eff_tol += delta_tol * tol
            if eff_tol > math.pi:
                eff_tol = math.pi

    return out

def unwrap_windowed(phase, window_size=5):
    """
    Unwrap phase by estimating the trend of the phase signal.
    """

    # Allocate result.
    out = numpy.zeros(phase.shape)

    windowl = numpy.array([math.fmod(phase[0], 2.0 * math.pi)] * window_size)

    delta = math.fmod(phase[1] - windowl[0], 2.0 * math.pi)
    if delta < -math.pi:
        delta += 2.0 * math.pi
    elif delta > math.pi:
        delta -= 2.0 * math.pi
    windowu = numpy.array([windowl[0] + delta] * window_size)

    out[0] = windowl[0]
    out[1] = windowu[0]

    meanl = windowl.mean()
    meanu = windowu.mean()
    slope = (meanu - meanl) / float(window_size)

    for i in range(2, len(phase)):
        ref = meanu + (1.0 + (float(window_size) - 1.0) / 2.0) * slope
        delta = math.fmod(phase[i] - ref, 2.0 * math.pi)

        if delta < -math.pi:
            delta += 2.0 * math.pi
        elif delta > math.pi:
            delta -= 2.0 * math.pi

        out[i] = ref + delta

        windowl[:-1] = windowl[1:]
        windowl[-1] = windowu[0]
        windowu[:-1] = windowu[1:]
        windowu[-1] = out[i]

        meanl = windowl.mean()
        meanu = windowu.mean()
        slope = (meanu - meanl) / float(window_size)

    return out

def normalize(phase):
    """
    Normalize phase to the range [-pi, pi].
    """

    # Convert to range [-2*pi, 2*pi].
    out = numpy.fmod(phase, 2.0 * numpy.pi)

    # Convert to range [-pi, pi]
    out[out < -numpy.pi] += 2.0 * numpy.pi
    out[out > numpy.pi] -= 2.0 * numpy.pi

    return out

def plot(fig, y, x=None, clf=True, sub=None, scatter=False, stack=False,
    sep=5.0, sep_abs=False, labels=None, show_legend=False, title=None,
    xlabel=None, ylabel=None):
    """
    Plot a list of signals.

    If 'fig' is equal to None, a new figure will be created. Otherwise, the
    specified figure number is used. The 'sub' argument can be used to create
    subplots.

    The 'scatter' argument selects between scatter and line plots.

    The 'stack', 'sep', and 'sep_abs' arguments can be used to control placement
    of the plots in the list. If 'stack' is set to True, each plot will be
    offset by the mean plus sep times the standard deviation of the previous
    plot. If 'sep_abs' is set to True, 'sep' is used as is.

    The 'labels' argument can be set to a list of labels and 'show_legend' can
    be set to True to show a legend inside the plot.

    The figure number of the figure used to plot in is returned.
    """
    global __styles

    if clf:
        fig.clf()

    if sub is not None:
        fig.add_subplot(sub)

    axes = fig.gca()
    if not title is None:
        axes.set_title(title)
    if not xlabel is None:
        axes.set_xlabel(xlabel)
    if not ylabel is None:
        axes.set_ylabel(ylabel)

    if x is None:
        x = [range(len(yi)) for yi in y]

    offset = 0.0
    for i in range(0,len(y)):
        if labels is None:
            if scatter:
                axes.scatter(x[i], y[i] + offset, edgecolors="None",
                    c=__styles[i % len(__styles)][0], marker="o")
            else:
                axes.plot(x[i], y[i] + offset, __styles[i % len(__styles)])
        else:
            if scatter:
                axes.scatter(x[i], y[i] + offset, edgecolors="None",
                    c=__styles[i % len(__styles)][0], marker="o",
                    label=labels[i])
            else:
                axes.plot(x[i], y[i] + offset, __styles[i % len(__styles)],
                    label=labels[i])

        if stack:
            if sep_abs:
                offset += sep
            else:
                offset += y[i].mean() + sep * y[i].std()

    if not labels is None and show_legend:
        axes.legend(prop=FontProperties(size="x-small"), markerscale=0.5)

class Parm:
    def __init__(self, db, name, elements=None, isPolar=False):
        self._db = db
        self._name = name
        self._elements = elements
        self._isPolar = isPolar
        self._value = None
        self._value_domain = None
        self._value_resolution = None

        self._readDomain()

    def name(self):
        return self._name

    def isPolar(self):
        return self._isPolar

    def empty(self):
        return self._empty

    def domain(self):
        return self._domain

    def value(self, domain=None, resolution=None, asPolar=True, unwrap_phase=False):
        if self.empty():
            assert(False)
            return (numpy.zeros((1,1)), numpy.zeros((1,1)))

        if self._value is None or self._value_domain != domain or self._value_resolution != resolution:
            self._readValue(domain, resolution)

        if asPolar:
            if self.isPolar():
                ampl = self._value[0]
                phase = normalize(self._value[1])
            else:
                ampl = numpy.sqrt(numpy.power(self._value[0], 2) + numpy.power(self._value[1], 2))
                phase = numpy.arctan2(self._value[1], self._value[0])

            if unwrap_phase:
                for i in range(0, phase.shape[1]):
                    phase[:, i] = unwrap(phase[:, i])

            return (ampl, phase)

        if not self.isPolar():
            re = self._value[0]
            im = self._value[1]
        else:
            re = self._value[0] * numpy.cos(self._value[1])
            im = self._value[0] * numpy.sin(self._value[1])

        return (re, im)

    def _readDomain(self):
        if self._elements is None:
            self._domain = self._db.getRange(self.name())
        else:
            domain_el0 = self._db.getRange(self._elements[0])
            domain_el1 = self._db.getRange(self._elements[1])
            self._domain = [max(domain_el0[0], domain_el1[0]), min(domain_el0[1], domain_el1[1]), max(domain_el0[2], domain_el1[2]), min(domain_el0[3], domain_el1[3])]

        self._empty = (self._domain[0] >= self._domain[1]) or (self._domain[2] >= self._domain[3])

    def _readValue(self, domain=None, resolution=None):
#        print "fetching:", self.name()

        if self._elements is None:
            value = numpy.array(self.__fetch_value(self.name(), domain, resolution))
            self._value = (value, numpy.zeros(value.shape))
        else:
            el0 = numpy.array(self.__fetch_value(self._elements[0], domain, resolution))
            el1 = numpy.array(self.__fetch_value(self._elements[1], domain, resolution))
            self._value = (el0, el1)

        self._value_domain = domain
        self._value_resolution = resolution

    def __fetch_value(self, name, domain=None, resolution=None):
        if domain is None:
            tmp = self._db.getValuesGrid(name)[name]
        else:
            if resolution is None:
                tmp = self._db.getValues(name, domain[0], domain[1], domain[2], domain[3])[name]
            else:
                tmp = self._db.getValuesStep(name, domain[0], domain[1], resolution[0], domain[2], domain[3], resolution[1])[name]

        if type(tmp) is dict:
            return tmp["values"]

        # Old parmdb interface.
        return tmp

class PlotWindow(QFrame):
    def __init__(self, parms, resolution=None, parent=None):
        QFrame.__init__(self, parent)

        self.parms = parms
        self.resolution = resolution

        self.fig = Figure((5, 4), dpi=100)

        self.canvas = FigureCanvas(self.fig)
        self.canvas.setParent(self)

        self.toolbar = NavigationToolbar(self.canvas, self)

        self.axis = 0
        self.index = 0
        axisSelector = QComboBox()
        axisSelector.addItem("Frequency")
        axisSelector.addItem("Time")
        self.connect(axisSelector, SIGNAL('activated(int)'), self.handle_axis)

        self.show_legend = False
        legendCheck = QCheckBox("Legend")
        self.connect(legendCheck, SIGNAL('stateChanged(int)'), self.handle_legend)

        self.polar = True
        polarCheck = QCheckBox("Polar")
        polarCheck.setChecked(True)
        self.connect(polarCheck, SIGNAL('stateChanged(int)'), self.handle_polar)

        self.unwrap_phase = False
        unwrapCheck = QCheckBox("Unwrap phase")
        self.connect(unwrapCheck, SIGNAL('stateChanged(int)'), self.handle_unwrap)

#        self.slider = QSlider(Qt.Horizontal)
#        self.slider.setMinimum(0)
#        self.slider.setMaximum(159)
#        self.connect(self.slider, SIGNAL('sliderReleased()'), self.handle_slider)

        self.spinner = QSpinBox()
        self.connect(self.spinner, SIGNAL('valueChanged(int)'), self.handle_spinner)
        hbox = QHBoxLayout()
        hbox.addWidget(axisSelector)
        hbox.addWidget(self.spinner)
        hbox.addWidget(legendCheck)
        hbox.addWidget(polarCheck)
        hbox.addWidget(unwrapCheck)
        hbox.addStretch(1)
        hbox.addWidget(self.toolbar)

        layout = QVBoxLayout()
        layout.addWidget(self.canvas, 1)
        layout.addLayout(hbox);
        self.setLayout(layout)

        self.domain = common_domain(self.parms)

        self.shape = (1, 1)
        if not self.domain is None:
            self.shape = (self.parms[0].value(self.domain, self.resolution)[0].shape)
            assert(len(self.shape) == 2)

        self.spinner.setRange(0, self.shape[1 - self.axis] - 1)

        self.plot()

    def plot(self):
        el0 = []
        el1 = []
        labels = []

        if not self.domain is None:
            for parm in self.parms:
                value = parm.value(self.domain, self.resolution, self.polar, self.unwrap_phase)

                if value[0].shape != self.shape or value[1].shape != self.shape:
                    print "warning: non-consistent result shape; will skip parameter:", parm.name()
                    continue

                if self.axis == 0:
                    el0.append(value[0][:, self.index])
                    el1.append(value[1][:, self.index])
                else:
                    el0.append(value[0][self.index, :])
                    el1.append(value[1][self.index, :])
                labels.append(parm.name())

        legend = self.show_legend and len(labels) > 0
        xlabel = ["Time (sample)", "Freq (sample)"][self.axis]
        if self.polar:
            plot(self.fig, el0, sub="211", labels=labels, show_legend=legend, xlabel=xlabel, ylabel="Amplitude")
            plot(self.fig, el1, clf=False, sub="212", stack=True, scatter=True, labels=labels, show_legend=legend, xlabel=xlabel, ylabel="Phase (rad)")
        else:
            plot(self.fig, el0, sub="211", labels=labels, show_legend=legend, xlabel=xlabel, ylabel="Real")
            plot(self.fig, el1, clf=False, sub="212", labels=labels, show_legend=legend, xlabel=xlabel, ylabel="Imaginary")

        # Set x-axis scale in number of samples.
        for ax in self.fig.axes:
            ax.set_xlim(0, self.shape[self.axis] - 1)

        self.canvas.draw()

    def handle_spinner(self, index):
        self.index = index
        self.plot()

    def handle_axis(self, axis):
        if axis != self.axis:
            self.axis = axis
            self.spinner.setRange(0, self.shape[1 - self.axis] - 1)
            self.spinner.setValue(0)
            self.plot()

    def handle_legend(self, state):
        self.show_legend = (state == 2)
        self.plot()

    def handle_unwrap(self, state):
        self.unwrap_phase = (state == 2)
        self.plot()

    def handle_polar(self, state):
        self.polar = (state == 2)
        self.plot()

class MainWindow(QFrame):
    def __init__(self, db):
        QFrame.__init__(self)
        self.db = db
        self.figures = []
        self.parms = []

#        self.setWindowTitle("parmdbplot")

        layout = QVBoxLayout()

        self.list = QListWidget()
        self.list.setSelectionMode(QAbstractItemView.ExtendedSelection)
        layout.addWidget(self.list, 1)

        self.useResolution = True
        checkResolution = QCheckBox("Use resolution")
        checkResolution.setChecked(True)
        self.connect(checkResolution, SIGNAL('stateChanged(int)'), self.handle_resolution)

        self.resolution = [QLineEdit(), QLineEdit()]
#        validator = QDoubleValidator(self.resolution[0])
#        validator.setRange(1.0, 2.0)
#        self.resolution[0].setValidator(validator)
        self.resolution[0].setAlignment(Qt.AlignRight)
        self.resolution[1].setAlignment(Qt.AlignRight)

        hbox = QHBoxLayout()
        hbox.addWidget(checkResolution)
        hbox.addWidget(self.resolution[0])
        hbox.addWidget(QLabel("Hz"))
        hbox.addWidget(self.resolution[1])
        hbox.addWidget(QLabel("s"))
        layout.addLayout(hbox)

        self.button = QPushButton("Plot")
        layout.addWidget(self.button)
        self.connect(self.button, SIGNAL('clicked()'), self.handle_plot)

        self.button = QPushButton("Close all figures")
        layout.addWidget(self.button)
        self.connect(self.button, SIGNAL('clicked()'), self.handle_close)

        self.setLayout(layout)

        self.populate()

    def populate(self):
        for parm in self.db.getNames():
            split = parm.split(":")

            if contains(split, "Real") or contains(split, "Imag"):
                if contains(split, "Real"):
                    idx = split.index("Real")
                    split[idx] = "Imag"
                    elements = [parm, ":".join(split)]
                else:
                    idx = split.index("Imag")
                    split[idx] = "Real"
                    elements = [":".join(split), parm]

                split.pop(idx)
                name = ":".join(split)

                found = False
                for i in range(len(self.parms)):
                    if self.parms[i].name() == name and not self.parms[i].isPolar():
                        found = True
                        break

                if not found:
                    self.parms.append(Parm(self.db, name, elements))
            elif contains(split, "Ampl") or contains(split, "Phase"):
                if contains(split, "Ampl"):
                    idx = split.index("Ampl")
                    split[idx] = "Phase"
                    elements = [parm, ":".join(split)]
                else:
                    idx = split.index("Phase")
                    split[idx] = "Ampl"
                    elements = [":".join(split), parm]

                split.pop(idx)
                name = ":".join(split)

                found = False
                for i in range(len(self.parms)):
                    if self.parms[i].name() == name and self.parms[i].isPolar():
                        found = True
                        break

                if not found:
                    self.parms.append(Parm(self.db, name, elements, True))
            else:
                self.parms.append(Parm(self.db, parm))

        self.parms = [parm for parm in self.parms if not parm.empty()]
        self.parms.sort(cmp=lambda x, y: cmp(x.name(), y.name()))

        domain = common_domain(self.parms)
        if not domain is None:
            self.resolution[0].setText("%.6f" % ((domain[1] - domain[0]) / 100.0))
            self.resolution[1].setText("%.6f" % ((domain[3] - domain[2]) / 100.0))

        for parm in self.parms:
            name = parm.name()
            if parm.isPolar():
                name = "%s (polar)" % name

            QListWidgetItem(name, self.list)

    def handle_resolution(self, state):
        self.useResolution = (state == 2)

    def handle_plot(self):
        parms = []
        tmp = self.list.selectedItems()
        tmp.sort()
        for item in tmp:
            idx = self.list.row(item)
            parms.append(copy.copy(self.parms[idx]))

        resolution = None
        domain = common_domain(parms)

        if domain is not None and self.useResolution:
            resolution = [parseFloat(self.resolution[0].text(), 1.0, domain[1] - domain[0]),
                parseFloat(self.resolution[1].text(), 1.0, domain[3] - domain[2])]

        self.figures.append(PlotWindow(parms, resolution))
        self.figures[-1].show()

    def handle_close(self):
        for fig in self.figures:
            fig.close()

if __name__ == "__main__":
    if len(sys.argv) <= 1 or sys.argv[1] == "--help":
        print "usage: parmdbplot.py <parmdb>"
        sys.exit(1)

    db = parmdb.parmdb(sys.argv[1])

    app = QApplication(sys.argv)
    window = MainWindow(db)
    window.show()

#    app.connect(app, SIGNAL('lastWindowClosed()'), app, SLOT('quit()'))
    app.exec_()