import pandas as _pd
import numpy as _np
import os as _os
import matplotlib.pyplot as _plt
import warnings as _warnings

from ._core import (signals as _signals, 
                    multiprocessing as _multiprocessing, 
                    fitfiles as _fitfiles, 
                    polarities as _polarities, 
                    save as _save,
                    _load_file)
from ._signal import Signal as _Signal


def generate(fitfile : _pd.DataFrame, 
             calibration : _pd.DataFrame = None, 
             jitter : str = None, 
             pretrigger : int = None, 
             multiprocessing : bool = False, 
             processes : int = _os.cpu_count()) -> _pd.DataFrame:
    """Build one template per pixel of the fit file.

    Parameters
    ----------
    fitfile
        Fit file table; it is split into per-pixel chunks by
        ``_template_iterable``.
    calibration
        Forwarded unchanged to the per-pixel worker.
    jitter
        Forwarded unchanged to the per-pixel worker.
    pretrigger
        Forwarded unchanged to the per-pixel worker.
    multiprocessing
        Forwarded to the ``_multiprocessing`` helper.
    processes
        Forwarded to the ``_multiprocessing`` helper.

    Returns
    -------
    pandas.DataFrame
        Templates with one column per ``template.channel`` and one row per
        ``template.polarity``.

    Warns
    -----
    UserWarning
        If ``BASE`` is among the polarities reported for the fit file.
    """
    if 'BASE' in _polarities(fitfile):
        _warnings.warn(f"The fitfile contains `BASE` polarity.", UserWarning)

    pixel_templates = _multiprocessing(
        function = _template_pixel,
        iterables = _template_iterable(fitfile, calibration, jitter, pretrigger),
        multiprocessing = multiprocessing,
        processes = processes,
        desc = 'templates',
        unit = 'pixel',
    )

    # Nested mapping: outer key -> DataFrame column, inner key -> index label.
    by_channel = {}
    for pixel_template in pixel_templates:
        by_channel.setdefault(pixel_template.channel, {})[pixel_template.polarity] = pixel_template

    return _pd.DataFrame(by_channel)


def _template_iterable(fitfile : _pd.DataFrame,
                       calibration : _pd.DataFrame = None,
                       jitter : str = None,
                       pretrigger : int = None):
    """Yield one argument tuple per chunk produced by ``_fitfiles``.

    The fit file is iterated over ``['channel', 'polarity']``; each yielded
    tuple is ``(fitfile_chunk, calibration, jitter, pretrigger)``, matching
    the positional parameters of ``_template_pixel``.
    """
    # These three values are identical for every yielded tuple.
    shared_args = (calibration, jitter, pretrigger)

    for _first_key, _second_key, pixel_rows in _fitfiles(fitfile, iterables = ['channel', 'polarity']):
        yield (pixel_rows, *shared_args)


def _template_pixel(fitfile : _pd.DataFrame,
                    calibration : _pd.DataFrame = None,
                    jitter : str = None,
                    pretrigger : int = None) -> _Signal:
    """Create the normalized template for one chunk of the fit file.

    The pulses prepared by ``_template_signals`` are combined with
    ``Signal.template``; the result gets ``pretrigger`` assigned and is
    normalized in place before being returned.

    Warns
    -----
    Warning
        If ``template.nan()`` is truthy for the resulting template.
    """
    prepared_pulses = _template_signals(
        fitfile,
        calibration = calibration,
        jitter = jitter,
        pretrigger = pretrigger,
    )
    template = _Signal.template(prepared_pulses)

    # NOTE(review): assigned even when `pretrigger` is None, whereas
    # `_template_signals` only assigns it when it is not None — confirm
    # that this difference is intended.
    template.pretrigger = pretrigger
    template.normalize(inplace = True)

    has_nan = template.nan()
    if has_nan:
        _warnings.warn(f"Template for `{template.channel}`, `{template.polarity}` contains NaN values.", Warning)

    return template


def _template_signals(fitfile : _pd.DataFrame,
                      calibration : _pd.DataFrame = None,
                      jitter : str = None,
                      pretrigger : int = None):
    """Yield the pulses of ``fitfile`` prepared for template building.

    Parameters
    ----------
    fitfile
        Table passed to ``_signals``; must contain the ``jitter`` column when
        ``jitter`` is given.
    calibration
        Forwarded to ``_signals``.
    jitter
        Name used both as a ``fitfile`` column and as a ``pulse.header`` key.
        When given, each pulse is shifted by the difference between the
        NaN-ignoring median of that column and the pulse's own header value.
    pretrigger
        When given, assigned to ``pulse.pretrigger`` of every pulse.

    Yields
    ------
    Signal
        Each pulse after the optional adjustments and an in-place normalize.
    """
    align_jitter = jitter is not None
    override_pretrigger = pretrigger is not None

    # Common reference value that every pulse is shifted towards.
    target_jitter = _np.nanmedian(fitfile[jitter]) if align_jitter else None

    for pulse in _signals(fitfile, calibration = calibration):
        if align_jitter:
            shift = target_jitter - pulse.header[jitter]
            pulse.jitter(shift, interpolate = True, inplace = True)

        if override_pretrigger:
            pulse.pretrigger = pretrigger

        pulse.normalize(inplace = True)
        yield pulse


def plot(templates : _pd.DataFrame) -> None:
    """Show one figure per template stored in ``templates``.

    Parameters
    ----------
    templates
        Table whose columns are iterated as channels and whose row labels
        are iterated as polarities; cells are expected to hold ``Signal``
        objects.

    Warns
    -----
    UserWarning
        For every cell that does not hold a ``Signal`` instance.
    """
    for channel, entry in templates.items():
        for polarity, pulse in entry.items():

            if not isinstance(pulse, _Signal):
                _warnings.warn(f"No template for pixel `{channel}`, `{polarity}`", UserWarning)
                continue

            _plt.figure()
            pulse.plot()
            # f-string so that non-string labels do not raise on concatenation.
            _plt.title(f"{channel} {polarity}")
            # Must be called: a bare attribute reference is a no-op.
            _plt.tight_layout()
            _plt.show()

    return None
        

def base(templates : _pd.DataFrame, polarity : str = 'POSP') -> _pd.DataFrame:
    """Write the ``polarity`` row of ``templates`` into a row labelled ``BASE``.

    The table is modified in place and the same object is returned.

    Parameters
    ----------
    templates
        Table that has a row labelled ``polarity``.
    polarity
        Label of the row whose values are assigned to ``BASE``.

    Returns
    -------
    pandas.DataFrame
        The ``templates`` object that was passed in.
    """
    source_row = templates.loc[polarity]
    templates.loc['BASE'] = source_row
    return templates


def save(templates : _pd.DataFrame, file : str) -> None:
    """Write ``templates`` to ``file`` using the ``_save`` helper.

    Parameters
    ----------
    templates
        Object handed to ``_save``.
    file
        Destination path; its extension must be exactly ``.pkl``.

    Raises
    ------
    NotImplementedError
        If the extension of ``file`` is anything other than ``.pkl``.
    """
    _, ext = _os.path.splitext(file)

    if ext != '.pkl':
        # Name the rejected extension so the caller can see what went wrong.
        raise NotImplementedError(
            f"Unsupported file extension `{ext}` for `{file}`; only `.pkl` is supported."
        )

    _save(templates, file)

    return None


def load(file : str) -> _pd.DataFrame:
    """Read templates from ``file`` using the ``_load_file`` helper.

    Parameters
    ----------
    file
        Source path; its extension must be exactly ``.pkl``.

    Returns
    -------
    pandas.DataFrame
        Whatever ``_load_file`` returns for ``file``.

    Raises
    ------
    NotImplementedError
        If the extension of ``file`` is anything other than ``.pkl``.
    """
    _, ext = _os.path.splitext(file)

    if ext != '.pkl':
        # Name the rejected extension so the caller can see what went wrong.
        raise NotImplementedError(
            f"Unsupported file extension `{ext}` for `{file}`; only `.pkl` is supported."
        )

    # NOTE(review): the `.pkl` extension suggests pickle-based loading, which
    # must only be used on trusted files — confirm against `_load_file`.
    templates = _load_file(file)

    return templates


def height(templates : _pd.DataFrame) -> _pd.DataFrame:
    """Collect ``pulse.height()`` for every ``Signal`` in ``templates``.

    Returns
    -------
    pandas.DataFrame
        Table with the same column labels as ``templates``; cells that do
        not hold a ``Signal`` contribute no value.
    """
    # Every column gets an entry, even when none of its cells is a Signal.
    per_channel = {
        channel: {
            polarity: pulse.height()
            for polarity, pulse in entry.items()
            if isinstance(pulse, _Signal)
        }
        for channel, entry in templates.items()
    }

    return _pd.DataFrame(per_channel)


def fom(templates : _pd.DataFrame) -> _pd.DataFrame:
    """Collect ``pulse.fom()`` for every ``Signal`` in ``templates``.

    Returns
    -------
    pandas.DataFrame
        Table with the same column labels as ``templates``; cells that do
        not hold a ``Signal`` contribute no value.
    """
    # Every column gets an entry, even when none of its cells is a Signal.
    per_channel = {
        channel: {
            polarity: pulse.fom()
            for polarity, pulse in entry.items()
            if isinstance(pulse, _Signal)
        }
        for channel, entry in templates.items()
    }

    return _pd.DataFrame(per_channel)


def shape(templates : _pd.DataFrame, risetimes: int = 1, decaytimes: int = 1) -> _pd.DataFrame:
    """Collect ``pulse.shape(...)`` for every ``Signal`` in ``templates``.

    Parameters
    ----------
    templates
        Table whose cells are inspected for ``Signal`` instances.
    risetimes
        Forwarded to ``pulse.shape`` as ``risetimes``.
    decaytimes
        Forwarded to ``pulse.shape`` as ``decaytimes``.

    Returns
    -------
    pandas.DataFrame
        Table with the same column labels as ``templates``; cells that do
        not hold a ``Signal`` contribute no value.
    """
    # Every column gets an entry, even when none of its cells is a Signal.
    per_channel = {
        channel: {
            polarity: pulse.shape(risetimes = risetimes, decaytimes = decaytimes)
            for polarity, pulse in entry.items()
            if isinstance(pulse, _Signal)
        }
        for channel, entry in templates.items()
    }

    return _pd.DataFrame(per_channel)