# -*- coding: utf-8 -*-
"""
Created on Mon May 23 11:14:44 2016

@author: shomroni
"""


from matplotlib.collections import LineCollection
from matplotlib.lines import Line2D

import matplotlib.pyplot as plt
import numpy as np



def combine_plots(figs):
    """
    Merge the leading curve of several figures into one figure.

    figs - iterable of figure identifiers, each handed to plt.figure()

    For every entry, the first line of the first axes is re-plotted (same
    data, label, color, line style and marker) into figure number 99, which
    is cleared beforehand. A legend is added, the figure is shown and then
    returned.
    """
    merged = plt.figure(99)
    merged.clf()
    target = merged.add_subplot(111)

    for number in figs:
        source = plt.figure(number).axes[0].lines[0]

        # Carry over the visual properties of the source line.
        appearance = dict(
            label=source.get_label(),
            color=source.get_color(),
            linestyle=source.get_linestyle(),
            marker=source.get_marker(),
        )
        target.plot(source.get_xdata(), source.get_ydata(), **appearance)

    target.legend()
    merged.show()

    return merged



def plotyy(*args, **kwargs):
    """
    plotyy([ax,] x, y1, y2)

    Plot y1 against the left y axis (blue) and y2 against a twin right
    y axis (red), both sharing x. If no axes is given as first argument,
    the current axes is used.

    Optional keywords: xlabel, y1label, y2label (axis labels) and
    label1, label2 (line labels).

    Returns the pair (left axes, right axes).
    """
    if isinstance(args[0], plt.Axes):
        left, x, y1, y2 = args
    else:
        left = plt.gca()
        x, y1, y2 = args

    right = left.twinx()

    # (axes, data, color, line label, y-axis label) for each side
    sides = (
        (left, y1, 'b', kwargs.get('label1'), kwargs.get('y1label')),
        (right, y2, 'r', kwargs.get('label2'), kwargs.get('y2label')),
    )

    for axes, ydata, color, line_label, _ in sides:
        axes.plot(x, ydata, color + 'o-', label=line_label)
        axes.tick_params(axis='y', colors=color)

    xlabel = kwargs.get('xlabel')
    if xlabel:
        left.set_xlabel(xlabel)

    for axes, _, color, _, axis_label in sides:
        if axis_label:
            axes.set_ylabel(axis_label, color=color)

    return left, right




def waterfall(xdata, ydata, labels=None, **kwargs):
    """
    Draw the columns of ydata as a stack of curves in a 3D axes.

    xdata    - x values shared by all curves
    ydata    - 2D array; column k is drawn as one curve at depth y = k
    labels   - optional legend labels, labels[k] describing column k
    colormap - (keyword) name of the colormap the curve colors are sampled
               from; defaults to 'viridis'

    Any other keyword is passed on when a new 3D axes has to be created.
    If the current axes of the current figure is already a 3D axes it is
    reused and those keywords are not applied.

    Returns the 3D axes.

    For colormaps see:
        http://matplotlib.org/examples/color/colormaps_reference.html
    """
    colormap = kwargs.pop('colormap', None)

    if colormap is None:
        colormap = 'viridis'

    n = ydata.shape[1]

    zs = np.arange(n)[::-1]     # reverse list so drawing from back to front

    # Sample the colormap once; it is needed for the curves and the legend.
    colors = plt.get_cmap(colormap)(np.linspace(0, 1, n))

    verts = [np.column_stack((xdata, ydata[:, z])) for z in zs]

    poly = LineCollection(verts, colors=colors)

    # Figure.gca() no longer accepts keyword arguments in recent Matplotlib
    # releases, so reuse a current 3D axes if there is one and otherwise
    # create the axes explicitly.
    fig = plt.gcf()
    ax = fig.gca() if fig.axes else None
    if getattr(ax, 'name', None) != '3d':
        ax = fig.add_subplot(111, projection='3d', **kwargs)

    ax.add_collection3d(poly, zs=zs, zdir='y')
    ax.set_xlim3d(xdata[0], xdata[-1])
    ax.set_ylim3d(-1, n+1)
    ax.set_zlim3d(0, ydata.max())

    # https://stackoverflow.com/questions/19877666/add-legends-to-linecollection-plot
    if labels is not None:
        # The curves were colored in back-to-front order, so indexing the
        # colors with zs gives proxy k the color of column k.
        proxies = [Line2D([0, 1], [0, 1], color=colors[z]) for z in zs]
        ax.legend(proxies, labels)

    plt.show()

    return ax



def myzip(n, *iterables):
    """
    Same as the built-in zip function, except that some of the 'iterables' can
    be non-iterables (i.e. scalars). These scalars keep their values throughout
    the iteration.

    n is a list of booleans, with len(n) == len(iterables), that specifies
    which of the arguments are iterable (True) and which are not (False).

    Iteration stops as soon as one of the iterated arguments is exhausted.
    If none of the arguments is iterated the result never ends; with no
    arguments at all it is empty.

    Raises ValueError (on the first iteration step, since this is a
    generator) if n and the arguments differ in length.

    Example:

    >>> a = myzip([False,True,False], 'a', ['1','2','3'], '*')
    >>> for x,y,z in a: print(x,y,z)
    a 1 *
    a 2 *
    a 3 *

    """
    from itertools import repeat

    flags = list(n)
    if len(flags) != len(iterables):
        # Pairing the two with zip would silently drop the surplus entries.
        raise ValueError('myzip: got %d flags for %d arguments'
                         % (len(flags), len(iterables)))

    # A scalar is turned into an endless repetition of itself, so the
    # built-in zip can do the lockstep iteration.
    columns = [iter(item) if iterate else repeat(item)
               for item, iterate in zip(iterables, flags)]

    yield from zip(*columns)



class tabbed:
    """
    Plot multiple figures in a single window, with each figure in its own tab.

    This is useful, e.g., for an experiment where a parameter is varied and
    multiple measurements are taken for each value. The data can then be
    presented with each tab containing all the plots for each parameter value.

    This function (class, in fact) can be called in several forms:

        1. With a single argument, a list of figures. This is more flexible
           as the figures can be anything.

        2. With the data to be plotted as vectors or matrices. In this case the
           matrices are iterated, the vectors not.

           In these examples, v is a 1xN vector and M is a 3xN matrix. The
           window will contain 3 tabs.

               tabbed(v, M, 'r')     y values change across tabs, x same

               tabbed(v, M,  ['r', 'g', 'b'])   y values change across tabs,
                                                x same; color changes also

               tabbed(M1, M2, 'b')   both x, y iterated across tabs

           Keyword arguments (for the axes, not the data) can be used as well,
           either as scalars or lists of length equal to the number of tabs.
           In the latter case they are iterated.

           Example:

               tabbed(v, M, 'o-', title=['a', 'b', 'c'], xlabel='x')

           If nothing is iterated (no list or matrix among the positional
           or keyword arguments), a single tab is created.

    Raises RuntimeError if the Matplotlib backend is not Qt5Agg, and
    ValueError if there is nothing to put into a tab.
    """
    def __init__(self, *args, **kwargs):

        # Compared case-insensitively so that differently-cased spellings of
        # the same backend name are accepted.
        if plt.get_backend().lower() != 'qt5agg':
            raise RuntimeError('Must use Qt5 backend for tabbed figures.')

        # See:
        # http://stackoverflow.com/questions/29681099/python-matplotlib-and-pyqt-for-multi-tab-plotting-navigation
        from PyQt5 import QtWidgets as qt
        from matplotlib.backends.backend_qt5agg import FigureCanvasQTAgg as FigureCanvas
        from matplotlib.backends.backend_qt5agg import NavigationToolbar2QT as NavigationToolbar

        self.w = qt.QMainWindow(None)

        main_frame = qt.QWidget()
        self.tabWidget = qt.QTabWidget(main_frame)

        if self._is_figure_list(args):
            figures = list(args[0])
        else:
            figures = self._figures_from_data(args, kwargs)

        if not figures:
            # The toolbar below needs at least one canvas.
            raise ValueError('tabbed: nothing to plot.')

        self.canvas_list = [FigureCanvas(fig) for fig in figures]
        self.names = ['Tab %i'%i for i in range(len(self.canvas_list))]

        for canvas, name in zip(self.canvas_list, self.names):
            self.tabWidget.addTab(canvas, name)

        # A single toolbar is shared by all tabs; it starts out attached to
        # the first canvas and is re-targeted in func() on every tab change.
        self.toolbar = NavigationToolbar(self.canvas_list[0], self.w)

        self.vbox = qt.QVBoxLayout()
        self.vbox.addWidget(self.toolbar)
        self.vbox.addWidget(self.tabWidget)

        self.tabWidget.currentChanged.connect(self.func)

        main_frame.setLayout(self.vbox)
        self.w.setCentralWidget(main_frame)
        self.w.show()

    @staticmethod
    def _is_figure_list(args):
        """Return True if args is a single sequence starting with a Figure."""
        if len(args) != 1:
            return False
        try:
            first = args[0][0]
        except (IndexError, KeyError, TypeError):
            return False
        return isinstance(first, plt.Figure)

    @staticmethod
    def _figures_from_data(args, kwargs):
        """Create one figure per tab from plot arguments and axes keywords."""
        from itertools import islice, repeat

        # Lists and 2D arrays are iterated across tabs, anything else is
        # reused unchanged in every tab.
        iter_args = [isinstance(a, list)
                     or (isinstance(a, np.ndarray) and a.ndim == 2)
                     for a in args]
        iter_kw = [isinstance(v, list) for v in kwargs.values()]

        arg_sets = myzip(iter_args, *args)
        # myzip yields nothing when given no arguments at all, so without
        # keywords an endless supply of empty value tuples is used instead.
        kw_sets = myzip(iter_kw, *kwargs.values()) if kwargs else repeat(())
        per_tab = zip(arg_sets, kw_sets)

        if not (any(iter_args) or any(iter_kw)):
            # Nothing limits the number of tabs, so the iteration would
            # never end: build exactly one tab.
            per_tab = islice(per_tab, 1)

        keys = list(kwargs)
        figures = []
        for plot_args, axes_values in per_tab:
            fig = plt.Figure()
            axes = fig.add_subplot(111, **dict(zip(keys, axes_values)))
            axes.plot(*plot_args)
            figures.append(fig)

        return figures

    def func(self):
        """Attach the shared toolbar to the canvas of the selected tab."""
        current = self.canvas_list[self.tabWidget.currentIndex()]
        self.toolbar.canvas = current
        current.toolbar = self.toolbar



def to_clipboard(fig):
    """
    Copy figure fig to clipboard

    fig - a figure, or an integer figure number (looked up with plt.figure)

    The figure's face color is set to white and its canvas is redrawn before
    the rendered RGBA buffer is copied.

    Returns the QImage that was placed on the clipboard.
    """

    from PyQt5.QtWidgets import QApplication
    from PyQt5.QtGui import QImage

    if isinstance(fig, (int, np.integer)):
        fig = plt.figure(fig)

    fig.set_facecolor('white')
    fig.canvas.draw()

    buf = fig.canvas.buffer_rgba()

    # Take the image size from the rendered buffer itself: the widget size
    # can differ from it (e.g. when the display scale factor is not 1).
    pixels = np.asarray(buf)
    if pixels.ndim == 3:
        height, width = pixels.shape[:2]
    else:
        # Buffer without shape information: use the widget size instead.
        size = fig.canvas.size()
        width, height = size.width(), size.height()

    # A QImage built on an external buffer does not own the pixel data;
    # copy() detaches it so the image stays valid after the canvas changes.
    im = QImage(buf, width, height, QImage.Format_RGBA8888).copy()

    QApplication.clipboard().setImage(im)

    return im



def to_pdf(fname, figs):
    """
    Save several figures into one PDF file.

    fname - target of the PDF, passed to PdfPages
    figs  - iterable of figures; each one is passed to PdfPages.savefig
            in turn
    """

    from matplotlib.backends.backend_pdf import PdfPages

    # The context manager closes the file even if saving a figure fails.
    with PdfPages(fname) as pp:
        for f in figs:
            pp.savefig(f)



def inset(ax, pos, **kwargs):
    """
    Create an inset to a given axes, using normalized units

    ax  - the parent axes
    pos - [x0, y0, w, h], in axes coordinates of ax

    Remaining keyword arguments are passed on when the inset axes is
    created. The inset is added to the figure that contains ax and is
    returned.

    Note: as of Matplotlib 3.0, there is also the experimental Axes.inset_axes
    """

    from matplotlib.transforms import Bbox

    fig = ax.get_figure()

    # Convert the rectangle from axes coordinates to figure coordinates,
    # which is what add_axes expects.
    b = Bbox.from_bounds(*pos).transformed(ax.transAxes).transformed(fig.transFigure.inverted()).bounds

    # Add to the figure owning ax rather than to whichever figure happens
    # to be current in pyplot.
    axi = fig.add_axes(b, **kwargs)

    return axi