#!/usr/bin/python
# -*- encoding: utf-8 -*-

import os.path

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

class MonoSF:
    """
    A class for storing and working with monochromator results in the format
    of the single_file interface

    Since pandas handling of nan values works as we want it, we'll go
    with using pandas.

    The data is stored in self.df as a pandas.DataFrame. Its index holds
    the wavelength of each row (first column of the file); every sweep
    over the wavelengths contributes one row per wavelength.
    The distinct wavelengths, in order of first appearance in the file,
    are stored in self.wls.
    """

    def __init__(self, filename):
        """
        Load a whitespace-separated numeric file from disk and fill in
        this class. Also perform outlier removal.

        The first column of the file is used as wavelength index, all
        remaining columns are stored as data.
        """

        self._consolidated_buffer = None

        self._filename = filename
        self.name = os.path.splitext(os.path.basename(self._filename))[0]

        # ndmin=2 keeps a single-row file two-dimensional, otherwise the
        # column slicing below would fail on the 1-D result.
        raw = np.loadtxt(filename, ndmin=2)
        self.df = pd.DataFrame(
            raw[:, 1:],
            index=pd.Index(raw[:, 0], name="Wavelength"),
        )

        self.wls = self.df.index.drop_duplicates()

        self.remove_outliers()

    def remove_outliers(self, iqr_factor=1.5):
        """
        For each wavelength, remove higher outliers using the
        Interquartile Distance Method. Designed to remove cosmic particle
        PMT bursts.

        Every value above `Q3 + iqr_factor * (Q3 - Q1)` of its wavelength
        is replaced by NaN in self.df. Values that are already NaN are
        ignored when computing the quartiles, so the method can be
        called repeatedly. The cached result of consolidate_data() is
        dropped because the underlying data may have changed.
        """
        wls = self.wls

        # One (Q1, Q3) pair per wavelength, computed over all sweeps.
        # reshape keeps the array two-dimensional when there is no data.
        quartiles = np.array([
            np.nanquantile(
                self.df.loc[wl, :].to_numpy().ravel(),
                [0.25, 0.75],
            ) for wl in wls
        ]).reshape(-1, 2)
        iqr = quartiles[:, 1] - quartiles[:, 0]
        upper_bound = pd.Series(
            quartiles[:, 1] + iqr * iqr_factor,
            index=wls,
        )

        # Expand the per-wavelength bound to one entry per data row.
        row_bounds = upper_bound.reindex(self.df.index).to_numpy()

        # "<=" on purpose: a value equal to the bound is not an outlier.
        # With a strict "<" a wavelength whose IQR is zero would lose
        # every value lying on its upper quartile.
        keep = self.df.to_numpy() <= row_bounds[:, np.newaxis]
        self.df = self.df.where(keep)

        self._consolidated_buffer = None

    def consolidate_data(self):
        """
        Return data in a consolidated form as a pd.DataFrame
        with wavelengths as index.

        Returns a pd.DataFrame. The index are the wavelengths, the
        first column ("mean") contains some sort of quantification and
        the second column ("std") some sort of error measure.

        In this implementation, mean and std are used as measures; NaN
        values (removed outliers) are skipped. The result is cached
        until remove_outliers() is called again.
        """
        if self._consolidated_buffer is None:
            per_wavelength = self._unravel()
            self._consolidated_buffer = pd.DataFrame({
                "mean": per_wavelength.mean(axis=1),
                "std": per_wavelength.std(axis=1),
            })

        return self._consolidated_buffer

    def plot_spectrum(self, savefig=False, show=True,
        errbars=True, ax=False):
        """
        Plot wavelength spectrum of the data.

        If savefig is set to True, saves the figure next to the data
        file as "<filename without extension>.png"; if it is set to a
        path, saves the figure there.
        If show is set to False, don't show it in the end.
        If errbars is set to False, don't plot the errorbars
        If ax is set to an Axes object, use that one for plotting,
        otherwise the current pyplot axes are used.
        """
        consolidated = self.consolidate_data()
        x = consolidated.index
        y = consolidated.iloc[:, 0]
        yerr = consolidated.iloc[:, 1] if errbars else None

        if not ax:
            ax = plt.gca()

        ax.errorbar(x, y, yerr,
            label=self.name,
            capsize=3, elinewidth=0.5, capthick=0.5)

        ax.set_xlabel("Wavelength [nm]")
        ax.set_ylabel("Mean Counts per Measurement")
        ax.grid()

        if show or savefig:
            ax.legend()

        if savefig:
            # Save the figure the axes belong to; it need not be the
            # current pyplot figure when the caller supplied `ax`.
            ax.figure.savefig(
                self._resolve_savefig_path(savefig, ".png"), dpi=300)

        if show:
            plt.show()

    def calc_transmission(self, mydata, refdata):
        """
        Return the transmission of mydata compared to refdata, as well
        as the linearly approximated error.

        mydata and refdata are DataFrames as returned by
        consolidate_data(): first column the value, second column its
        error. The value in the first row of mydata is subtracted from
        all values of mydata before dividing by refdata.
        NOTE(review): this looks like an offset/background correction
        taken from the first wavelength - confirm that this is intended.

        The error is approximated by the root of the squared sum of the
        relative errors (of the uncorrected values), transformed back
        to an absolute error.

        Returns a pd.DataFrame with the columns "transmission" and "std".
        """
        my_values = mydata.iloc[:, 0]
        ref_values = refdata.iloc[:, 0]

        transmission = (my_values - mydata.iloc[0, 0]) / ref_values

        relative_error = np.sqrt(
            (mydata.iloc[:, 1] / my_values)**2 +
            (refdata.iloc[:, 1] / ref_values)**2
        )
        error = relative_error * transmission

        return pd.DataFrame({"transmission": transmission, "std": error})

    def compare_to_reference(self, ref,
        errbars=True, savefig=False, show=True):
        """
        Compare a spectrum to a reference spectrum `ref`, which needs to
        be another compatible spectrum instance and plot the result.

        Both spectra are drawn on the current pyplot axes, the
        transmission on a twin y axis of it.

        If errbars is set to False, don't plot the errorbars.
        If savefig is True, save the figure next to the data file as
        "<filename without extension>_transmission.png"; if it is a
        path, save the figure there.
        If show is set to False, don't show the figure in the end.
        """
        refdata = ref.consolidate_data()
        mydata = self.consolidate_data()

        t = self.calc_transmission(mydata, refdata)

        spectrum_ax = plt.gca()
        transmission_ax = spectrum_ax.twinx()

        ref.plot_spectrum(show=False, ax=spectrum_ax, errbars=errbars)
        self.plot_spectrum(show=False, ax=spectrum_ax, errbars=errbars)

        transmission_ax.errorbar(t.index, t.iloc[:, 0],
            t.iloc[:, 1] if errbars else None,
            label="Transmission", c="green",
            capsize=3, elinewidth=0.5, capthick=0.5)

        transmission_ax.set_ylim(0, 1.1)
        transmission_ax.set_ylabel("Transmission", c="green")
        transmission_ax.set_xlabel("Wavelength [nm]")

        if show or savefig:
            transmission_ax.set_title("Transmission plot of " + self.name)
            transmission_ax.legend(loc="upper right")
            spectrum_ax.legend(loc="upper left")
            transmission_ax.grid(axis="y")
            spectrum_ax.grid(axis="x")
            plt.tight_layout()

        if savefig:
            spectrum_ax.figure.savefig(
                self._resolve_savefig_path(savefig, "_transmission.png"),
                dpi=300)

        if show:
            plt.show()

    def plot_time_drift(self, wl=None,
        savefig=False, show=True, ax=False
        ):
        """
        Plot time-drift over a certain wavelength. If no wavelength is
        given, the first one is used.

        If savefig is True, save the figure next to the data file as
        "<filename without extension>_timedrift.png"; if it is a path,
        save the figure there.
        If show is set to False, don't show the figure in the end.
        If ax is set to an Axes object, use that one for plotting,
        otherwise the current pyplot axes are used.
        """
        data = self._time_unravel()

        x = range(data.shape[1])
        y = data.loc[wl] if wl is not None else data.iloc[0]

        if not ax:
            ax = plt.gca()

        ax.plot(x, y, ".", ms=1, label=str(y.name) + " nm")
        ax.set_xlabel("Measurement intervals")
        ax.set_ylabel("Counts per Measurement interval")

        if show or savefig:
            ax.grid()
            ax.legend()
            ax.set_title("Time Drift Plot for " + self.name)
            plt.tight_layout()

        if savefig:
            # Save the figure the axes belong to; it need not be the
            # current pyplot figure when the caller supplied `ax`.
            ax.figure.savefig(
                self._resolve_savefig_path(savefig, "_timedrift.png"))

        if show:
            plt.show()

    # ##### HELPER FUNCTIONS ###########################################

    def _resolve_savefig_path(self, savefig, suffix):
        """
        Translate a `savefig` argument into the path to save to.

        A bool selects the default location: the path of the data file
        with its extension replaced by `suffix`. Anything else is taken
        as the target itself and returned unchanged.
        """
        if isinstance(savefig, bool):
            return os.path.splitext(self._filename)[0] + suffix
        return savefig

    def _get_sweeps(self):
        """
        Chops the dataframe into sweeps to be fed to concat.

        A sweep is a run of len(self.wls) consecutive rows. Rows left
        over after the last complete sweep are not returned.
        """
        sweep_len = len(self.wls)
        sweep_count = len(self.df) // sweep_len

        return [
            self.df.iloc[i * sweep_len:(i + 1) * sweep_len]
            for i in range(sweep_count)
        ]

    def _unravel(self):
        """
        Bring the DataFrame into a shape suitable for wavelength
        analysis - that is, concat all columns with the same wavelength

        Returns the new DataFrame. Doesn't modify self.df
        """
        return pd.concat(self._get_sweeps(), axis=1)

    def _time_unravel(self):
        """
        Fill gaps between the sweeps with NaNs to account for "lost"
        points in time so that the resulting data more or less
        represents a timed measurement (not taking into account the
        grating rotation duration of a couple of seconds each time)

        Every sweep is followed by a NaN block as wide as the data
        columns of all other wavelengths of one sweep together.
        Returns the new DataFrame. Doesn't modify self.df
        """
        sweeps = self._get_sweeps()
        first = sweeps[0]

        gap_shape = (
            int(first.shape[0]),
            int((self.wls.shape[0] - 1) * first.shape[1]),
        )
        gap = pd.DataFrame(
            np.full(gap_shape, np.nan),
            index=first.index,
        )

        stuffed = []
        for sweep in sweeps:
            stuffed.extend((sweep, gap))

        return pd.concat(stuffed, axis=1)
