import pandas as _pd
import numpy as _np
import os as _os

import warnings as _warnings
from lmfit import Model as _Model
from scipy import fft as _fft, signal as _signal
import re as _re

from ._signal import Signal as _Signal
from ._frequencyspectrum import Frequencyspectrum as _Frequencyspectrum

from ._core import (extend as _extend, 
                    append as _append, 
                    multiprocessing as _multiprocessing, 
                    batch as _batch, 
                    signals as _signals, 
                    fitfiles as _fitfiles)

from .generator import _pulse


def fit(fitfile: _pd.DataFrame, 
        mask: _pd.Series = None, 
        methods: list = ['general', 'matched'], 
        prefix: str = None, 
        suffix: str = None, 
        templates: _pd.DataFrame = None, 
        noisespectra: _pd.DataFrame = None, 
        calibration: _pd.DataFrame = None, 
        window: str = ('tukey', 0.1), 
        pretrigger: int = 4000, 
        posttrigger: int = 4384, 
        max_jitter: int = 64,
        risetimes: int = 1,
        decaytimes: int = 1,
        batch_size: int = 10000,
        multiprocessing: bool = False, 
        processes: int = _os.cpu_count()) -> _pd.DataFrame:
    """Run the requested fit methods on the selected rows of `fitfile`.

    The selected rows are grouped by channel and polarity and split into
    batches (see `_fit_iterable`). Each batch is fitted by `_fit_batch` via
    `_multiprocessing`. The per-batch results are concatenated with `_append`
    and merged back onto `fitfile` with `_extend`.

    Parameters
    ----------
    fitfile : DataFrame
        Table describing the signals to fit.
    mask : Series, optional
        Row selection applied as ``fitfile[mask]``; all rows when omitted.
    methods : list or str
        Fit method names understood by `_fit_batch` (e.g. 'general',
        'matched'); converted to an at-least-1-d numpy array.
    prefix, suffix : str, optional
        Passed to `_extend` for the new (fit result) columns only.
    templates, noisespectra : DataFrame, optional
        Looked up per channel (and polarity / 'BASE') in `_fit_iterable`.
    calibration : DataFrame, optional
        Forwarded to `_signals` when the signals are loaded.
    window : tuple or str
        Window specification forwarded to the filter-based fits.
    pretrigger, posttrigger, max_jitter, risetimes, decaytimes : int
        Settings forwarded to the individual fit methods.
    batch_size : int
        Forwarded to `_batch`.
    multiprocessing : bool
        Forwarded to `_multiprocessing` (presumably toggles parallel execution).
    processes : int
        Forwarded to `_multiprocessing`. NOTE: the default ``os.cpu_count()``
        is evaluated once, when the module is imported.

    Returns
    -------
    DataFrame
        `fitfile` extended with the fit result columns.
    """
    method_names = _np.array(methods, ndmin = 1)
    selection = _pd.Series(True, index = fitfile.index) if mask is None else mask

    jobs = _fit_iterable(fitfile[selection], method_names, templates, noisespectra, calibration, window,
                         pretrigger, posttrigger, max_jitter, risetimes, decaytimes, batch_size)
    fitted_batches = _multiprocessing(_fit_batch, jobs, multiprocessing, processes,
                                      desc = 'fit', unit = 'batch')

    combined = _append(fitted_batches)

    return _extend([fitfile, combined], prefix = [None, prefix], suffix = [None, suffix])


def _fit_iterable(fitfile: _pd.DataFrame, 
                 methods: list = ['general', 'matched'], 
                 templates: _pd.DataFrame = None,
                 noisespectra: _pd.DataFrame = None,
                 calibration: _pd.DataFrame = None,
                 window: str = ('tukey', 0.1),
                 pretrigger: int = 4000, 
                 posttrigger: int = 4384, 
                 max_jitter: int = 64,
                 risetimes: int = 1,
                 decaytimes: int = 1,
                 batch_size: int = 10000):
    """Yield one argument tuple for `_fit_batch` per batch.

    `fitfile` is split into (channel, polarity) groups by `_fitfiles`. Each
    group is paired with ``templates[channel][polarity]`` and with
    ``noisespectra[channel]['BASE']``. Either lookup yields ``None`` when the
    table, column or row is missing. The group is then split with `_batch`
    using `batch_size`.

    Yields
    ------
    tuple
        ``(batch, methods, template, noisespectrum, calibration, window,
        pretrigger, posttrigger, max_jitter, risetimes, decaytimes)`` in the
        positional order expected by `_fit_batch`.
    """

    def lookup(table, column, row):
        # A missing table (None -> TypeError), column or row means "not provided".
        # `_fit_batch` warns about missing templates / noise spectra itself.
        # Only lookup failures are caught, so KeyboardInterrupt and unrelated
        # errors still propagate.
        try:
            return table[column][row]
        except (KeyError, IndexError, TypeError):
            return None

    for channel, polarity, fitfile_group in _fitfiles(fitfile, iterables = ['channel', 'polarity']):
        template = lookup(templates, channel, polarity)
        noisespectrum = lookup(noisespectra, channel, 'BASE')

        for batch in _batch(fitfile_group, batch_size = batch_size):
            yield (batch, methods, template, noisespectrum, calibration, window,
                   pretrigger, posttrigger, max_jitter, risetimes, decaytimes)


def _fit_batch(fitfile: _pd.DataFrame,
               methods: list = ['general', 'matched'],
               template: _pd.DataFrame = None,
               noisespectrum: _pd.DataFrame = None,
               calibration: _pd.DataFrame = None,
               window: str = ('tukey', 0.1),
               pretrigger: int = 4000,
               posttrigger: int = 4384,
               max_jitter: int = 64,
               risetimes: int = 1,
               decaytimes: int = 1) -> _pd.DataFrame:
    """Fit every signal of one batch with the requested methods.

    The index fields (`_index`) are always included. For every other method
    named in `methods`, a fit closure is built once per batch. Methods that need
    a template are skipped with a warning when `template` is not a Signal.
    'opti' and 'wmatched' also need `noisespectrum` to be a Frequencyspectrum.
    Result keys of all methods except the index are prefixed with the method
    name, e.g. ``matched_amplitude``.

    Returns
    -------
    DataFrame
        One row per signal yielded by ``_signals(fitfile, calibration=...)``.
    """
    has_template = isinstance(template, _Signal)
    has_template_and_noise = has_template and isinstance(noisespectrum, _Frequencyspectrum)

    # Each entry is (method, prerequisites met, factory for the fit closure,
    # warning when the prerequisites are not met). The order here is the
    # order of the result columns.
    catalogue = (
        ('header', True, _header, None),
        ('general', True,
         lambda: _general(pretrigger = pretrigger),
         None),
        ('template', has_template,
         lambda: _template(template = template, pretrigger = pretrigger),
         'No template provided for `template` fit.'),
        ('jtemplate', has_template,
         lambda: _jtemplate(template = template, pretrigger = pretrigger, posttrigger = posttrigger, max_jitter = max_jitter),
         'No template provided for `jtemplate` fit.'),
        ('opti', has_template_and_noise,
         lambda: _opti(template = template, noisespectrum = noisespectrum, window = window, pretrigger = pretrigger, posttrigger = posttrigger, max_jitter = max_jitter),
         'No template or noise spectrum provided for `opti` fit.'),
        ('matched', has_template,
         lambda: _matched(template = template, window = window, pretrigger = pretrigger, posttrigger = posttrigger, max_jitter = max_jitter),
         'No template provided for `matched` fit.'),
        ('wmatched', has_template_and_noise,
         lambda: _matched(template = template, noisespectrum = noisespectrum, window = window, pretrigger = pretrigger, posttrigger = posttrigger, max_jitter = max_jitter),
         'No template or noise spectrum provided for `wmatched` fit.'),
        ('shape', has_template,
         lambda: _shape(template = template, risetimes = risetimes, decaytimes = decaytimes, pretrigger = pretrigger, max_jitter = max_jitter),
         'No template provided for `shape` fit.'),
    )

    fits = [('index', _index())]
    for name, ready, build, warning in catalogue:
        if name not in methods:
            continue
        if ready:
            fits.append((name, build()))
        else:
            _warnings.warn(warning)

    rows = []
    for signal in _signals(fitfile, calibration = calibration):
        row = {}
        for name, fit_signal in fits:
            values = fit_signal(signal)
            if name != 'index':
                values = {name + '_' + str(key): value for key, value in values.items()}
            row.update(values)
        rows.append(row)

    return _pd.DataFrame(rows)


def _index():
    """Build a fit closure that returns the identifying fields of a signal.

    The closure returns a dict with the keys 'folder', 'measurement',
    'channel', 'polarity' and 'signal'. Each value is read from the attribute
    of the same name on a copy of the signal.
    """
    fields = ('folder', 'measurement', 'channel', 'polarity', 'signal')

    def fit(signal : _Signal) -> dict:
        duplicate = signal.copy()
        return {field: getattr(duplicate, field) for field in fields}

    return fit


def _header():
    """Build a fit closure that returns the header fields of a signal.

    Header keys that coincide with the identifying fields emitted by `_index`
    ('folder', 'measurement', 'channel', 'polarity', 'signal') are left out,
    so they are not repeated in the result.
    """
    index_columns = ('folder', 'measurement', 'channel', 'polarity', 'signal')

    def fit(signal : _Signal) -> dict:
        # Build a new dict rather than popping keys from `signal.header`.
        # Popping mutates the header object itself, which may be shared with
        # the caller's signal if `copy()` does not deep-copy it.
        return {key: value for key, value in signal.header.items() if key not in index_columns}

    return fit


def _general(pretrigger : int = 4000):
    """Build a fit closure that computes model-independent signal statistics.

    The signal is copied and its `pretrigger` set. Each statistic is then
    obtained by calling the signal method of the same name, e.g.
    ``signal.height()``. The result keys equal the method names.
    """
    statistics = ("mean", "mean_pretrigger",
                  "std", "std_pretrigger",
                  "slope", "slope_pretrigger",
                  "min", "max", "max_difference",
                  "height", "area",
                  "clipped", "traceless")

    def fit(signal : _Signal) -> dict:
        trace = signal.copy()
        trace.pretrigger = pretrigger
        return {statistic: getattr(trace, statistic)() for statistic in statistics}

    return fit


def _template(template,
              pretrigger: int = 4000):
    """Build a fit closure for a linear least-squares template fit.

    Each signal ``s`` is modelled as ``amplitude * t + offset``, where ``t`` is
    the template after setting its pretrigger and normalizing it. The closed
    form is used:

        amplitude = (mean(s * t) - mean(s) * mean(t)) / var(t)
        offset    = mean(s) - amplitude * mean(t)

    The closure returns {'amplitude', 'offset', 'chi2'}. 'chi2' is the
    NaN-ignoring mean of the squared residual. The amplitude itself uses
    plain means, so NaN samples make it NaN.
    """
    reference = template.copy()
    reference.pretrigger = pretrigger
    reference.normalize(inplace = True)

    reference_mean = _np.mean(reference.data)
    reference_var = _np.mean((reference.data - reference_mean)**2)

    def fit(signal : _Signal) -> dict:
        trace = signal.copy()
        samples = trace.data
        model = reference.data

        samples_mean = _np.mean(samples)
        cross = _np.average(samples * model)

        amplitude = (cross - samples_mean * reference_mean) / reference_var
        offset = samples_mean - amplitude * reference_mean
        residual = samples - (amplitude * model) - offset

        return {'amplitude': amplitude,
                'offset': offset,
                "chi2": _np.nanmean(residual**2)}

    return fit


def _jtemplate(template,
           pretrigger: int = 4000,
           posttrigger: int = 4384,
           max_jitter: int = 64):
    """Build a fit closure for a template fit scanned over integer shifts.

    For every shift in [-max_jitter, max_jitter] the signal is shifted and
    trimmed by `max_jitter`. It is then fitted with a `_template` fit built
    from the template, which is also trimmed by `max_jitter`.
    `_jtemplate_evaluate` picks the shift with the lowest chi2 and refines it.
    For signals with polarity 'BASE', the amplitude of the unshifted fit and
    a jitter of 0.0 are reported instead.

    The closure returns 'amplitude', 'offset', 'jitter' and the chi2 entries
    produced by `_parameter_chi2`.
    """
    reference = template.copy()
    fit_shift = _template(reference.trim(max_jitter), pretrigger = pretrigger - max_jitter)

    def fit(signal : _Signal) -> dict:
        trace = signal.copy()

        shifted = [trace.jitter(-shift) for shift in range(-max_jitter, max_jitter + 1)]
        trimmed = [candidate.trim(max_jitter) for candidate in shifted]
        scan = [fit_shift(candidate) for candidate in trimmed]

        amplitude, amplitude_base, offset, jitter = _jtemplate_evaluate(scan)

        is_base = trace.polarity == 'BASE'
        parameter = {'amplitude': amplitude_base if is_base else amplitude,
                     'offset': offset,
                     'jitter': 0.0 if is_base else jitter}

        # NOTE(review): chi2 is evaluated with the best-fit amplitude/jitter,
        # even for 'BASE' signals whose reported values differ. It also uses
        # the template as passed in: it is not normalized here, unlike in
        # `_opti`/`_matched`. Confirm both are intended.
        parameter.update(_parameter_chi2(trace, reference, amplitude, jitter = jitter,
                                         pretrigger = pretrigger, posttrigger = posttrigger))

        return parameter

    return fit


def _jtemplate_evaluate(parameters : list) -> tuple:
    arg = _np.argmin([parameter['chi2'] for parameter in parameters])
    max_jitter = len(parameters) // 2
    
    amplitude = parameters[arg]['amplitude']
    amplitude_base = parameters[max_jitter]['amplitude']
    offset = parameters[arg]['offset']
    jitter = arg - max_jitter
    jitter_max = (abs(jitter) == max_jitter)

    if not jitter_max:
        p1, p2, p3 = [(jitter + offset, parameters[arg + offset]['chi2']) for offset in (-1, 0, 1)]
        a, b, c = _solve_quadratic(p1, p2, p3)
        q1, q2, q3 = [(jitter + offset, parameters[arg + offset]['amplitude']) for offset in (-1, 0, 1)]
        d, e, f = _solve_quadratic(q1, q2, q3)
        jitter = -b / (2 * a)
        amplitude = d * jitter**2 + e * jitter + f

    return amplitude, amplitude_base, offset, jitter


def _opti(template: _Signal, 
          noisespectrum: _Frequencyspectrum, 
          window: str = ('tukey', 0.1),
          pretrigger: int = 4000,
          posttrigger: int = 4384,
          max_jitter: int = 64):
    """Build a fit closure for a noise-weighted (optimum) filter fit.

    A matched filter is derived from the normalized template, trimmed by
    `max_jitter`, together with `noisespectrum`. `_opti_convolve` evaluates
    the filter response for every integer shift in [-max_jitter, max_jitter].
    The template's own response serves as the reference: amplitude is the
    ratio of the peaks and jitter is the difference of the peak positions.
    'BASE' signals report the zero-shift ratio and a jitter of 0.0.

    The closure returns 'amplitude', 'jitter' and the chi2 entries of
    `_parameter_chi2`.
    """
    reference = template.copy()
    reference.pretrigger = pretrigger
    reference.normalize(inplace = True)

    matched_filter = reference.trim(max_jitter).filter_matched(noisespectrum = noisespectrum, window = window)

    ref_response = _opti_convolve(reference, matched_filter, window = window, max_jitter = max_jitter)
    ref_peak, ref_peak_base, ref_jitter = _evaluate_convolve(ref_response)

    def fit(signal: _Signal) -> dict:
        trace = signal.copy()
        trace.pretrigger = pretrigger
        trace.normalize(inplace = True)

        response = _opti_convolve(trace, matched_filter, window = window, max_jitter = max_jitter)
        peak, peak_base, shift = _evaluate_convolve(response)

        amplitude = peak / ref_peak
        amplitude_base = peak_base / ref_peak_base
        jitter = shift - ref_jitter

        is_base = trace.polarity == 'BASE'
        parameter = {"amplitude": amplitude_base if is_base else amplitude,
                     "jitter": 0.0 if is_base else jitter}

        parameter.update(_parameter_chi2(trace, reference, amplitude, jitter = jitter,
                                         pretrigger = pretrigger, posttrigger = posttrigger))

        return parameter

    return fit


def _opti_convolve(signal: _Signal, filter_matched: _Signal, window: str = ('tukey', 0.1), max_jitter: int = 64):
    """Return the filter response of `signal` for every integer shift.

    For each shift in [-max_jitter, max_jitter], taken in ascending order:

    1. The signal is shifted, trimmed by `max_jitter` and windowed.
    2. ``scipy.signal.convolve(filter_matched.data, data, mode='valid')`` is
       computed.
    3. That convolution is summed to a single value.

    Index ``max_jitter`` of the result therefore corresponds to zero shift.
    """

    def response(shift):
        candidate = signal.jitter(-shift).trim(max_jitter).window(window)
        return _signal.convolve(filter_matched.data, candidate.data, mode='valid')

    responses = [response(shift) for shift in range(-max_jitter, max_jitter + 1)]

    return _np.sum(responses, axis = 1)


def _matched(template: _Signal,
             noisespectrum: _Frequencyspectrum = None,
             window: str = ('tukey', 0.1),
             pretrigger: int = 4000,
             posttrigger: int = 4384,
             max_jitter: int = 64):
    """Build a fit closure for a matched-filter fit.

    `_fit_batch` uses this without `noisespectrum` for 'matched' and with it
    for 'wmatched'. The filter is derived from the normalized template,
    trimmed by `max_jitter`. The template and each signal are convolved with
    the filter in 'same' mode. The signal's peak height and position are
    reported relative to the template's. 'BASE' signals report the ratio at
    the centre index and a jitter of 0.0.

    The closure returns 'amplitude', 'jitter' and the chi2 entries of
    `_parameter_chi2`.
    """
    reference = template.copy()
    reference.pretrigger = pretrigger
    reference.normalize(inplace = True)

    matched_filter = reference.trim(max_jitter).filter_matched(noisespectrum = noisespectrum, window = window)

    # NOTE(review): unlike `_opti_convolve`, the Signal objects themselves
    # (not `.data`) are passed to scipy. This relies on Signal being
    # convertible to an array — confirm.
    ref_peak, ref_peak_base, ref_jitter = _evaluate_convolve(
        _signal.convolve(reference, matched_filter, mode = 'same'))

    def fit(signal : _Signal) -> dict:
        trace = signal.copy()
        trace.pretrigger = pretrigger
        trace.normalize(inplace = True)

        peak, peak_base, shift = _evaluate_convolve(
            _signal.convolve(trace, matched_filter, mode = 'same'))

        amplitude = peak / ref_peak
        amplitude_base = peak_base / ref_peak_base
        jitter = shift - ref_jitter

        is_base = trace.polarity == 'BASE'
        parameter = {"amplitude": amplitude_base if is_base else amplitude,
                     "jitter": 0.0 if is_base else jitter}

        parameter.update(_parameter_chi2(trace, reference, amplitude, jitter = jitter,
                                         pretrigger = pretrigger, posttrigger = posttrigger))

        return parameter

    return fit


def _shape(template: _Signal, 
           risetimes: int = 1, 
           decaytimes: int = 1, 
           pretrigger: int = 4000,
           max_jitter: int = 64):
    """Build a fit closure that fits the `_pulse` model to each signal with lmfit.

    Start values come from ``template.shape(risetimes=..., decaytimes=...)``,
    excluding its 'chi2' entry. Bounds and constraints:

    - Parameters ending in '_factor' are bounded to [0, 1].
    - Other 'rise*'/'decay*' parameters get a lower bound of 10e-9.
    - When there are several rise (or decay) components, the last
      component's factor is tied to ``1 - sum(other factors)``.

    Per signal, amplitude, offset and jitter are seeded from a `_jtemplate`
    fit.

    The closure returns lmfit's best values plus 'chi2' (chisqr divided by the
    number of samples). It returns an empty dict if the fit raises.
    """
    template = template.copy()
    template.pretrigger = pretrigger

    template_shape = template.shape(risetimes = risetimes, decaytimes = decaytimes)
    fit_jtemplate = _jtemplate(template, pretrigger = pretrigger, max_jitter = max_jitter)
    
    model = _Model(_pulse)

    template_params = model.make_params()
    for key in template_shape:
        if key == 'chi2':
            continue
        template_params.add(key, value = template_shape[key])
        if key.endswith('_factor'):
            template_params[key].set(min = 0, max = 1)
        elif key.startswith('rise') or key.startswith('decay'):
            # NOTE(review): 10e-9 == 1e-8; confirm whether 1e-9 was intended.
            template_params[key].set(min = 10e-9)
            
    for method in ['rise', 'decay']:
        number_set = {int(_re.findall(r'\d+', entry)[0]) for entry in template_params if entry.startswith(method)}
        if len(number_set) > 1:
            number_last = max(number_set)

            expr = '1'
            for i in number_set:
                if i == number_last:
                    continue
                expr += f' - {method}{i}_factor'

            # The expression makes the weight factors sum to 1. It belongs on
            # the last component's `_factor` parameter, not on its time
            # constant `{method}{number_last}`. If that factor parameter does
            # not exist, no constraint is applied.
            factor_last = f'{method}{number_last}_factor'
            if factor_last in template_params:
                template_params[factor_last].set(expr = expr)

    sampling = template.samplingrate / template.oversampling

    def fit(signal: _Signal) -> dict:

        signal = signal.copy()
        signal.pretrigger = pretrigger

        fit_params = fit_jtemplate(signal)

        params = template_params.copy()
        params['amplitude'].set(value = fit_params['amplitude'] * template_params['amplitude'].value)
        params['offset'].set(value = fit_params['offset'])
        # `_jtemplate` jitter is in samples; dividing by `sampling`
        # presumably converts it to the units of `signal.time()`.
        params['jitter'].set(value = template_params['jitter'].value + fit_params['jitter'] / sampling)

        x = signal.time()
        y = signal.data
        
        try:
            result = model.fit(y, x = x, params = params)

            parameter = result.best_values
            parameter['chi2'] = result.chisqr / len(y)

        except Exception:
            # Best effort: a failed fit leaves the 'shape' columns empty for
            # this signal instead of aborting the batch. Catching Exception
            # (not BaseException) lets KeyboardInterrupt/SystemExit propagate.
            parameter = {}

        return parameter
    
    return fit


def _parameter_chi2(signal: _Signal,
                       template: _Signal,
                       amplitude: float,
                       jitter: float = None,
                       trim: int = None,
                       pretrigger: int = 4000,
                       posttrigger: int = 4384):
    """Compute mean squared residuals of ``signal - amplitude * template``.

    Processing steps, all on copies of the inputs:

    1. When `jitter` is given, the signal is shifted by ``-round(jitter)``
       (integer). The template is shifted with interpolation by
       ``jitter - round(jitter)``.
    2. When `trim` is given, ``trim_nan(trim)`` is applied to both.
    3. The NaN-ignoring mean of the residual is subtracted before squaring.

    Returns
    -------
    dict
        NaN-ignoring means of the squared residual over the whole trace
        ('chi2') and over the sample ranges ``[:pretrigger]``
        ('chi2_pretrigger'), ``[pretrigger:posttrigger]`` ('chi2_trigger')
        and ``[posttrigger:]`` ('chi2_posttrigger').
    """
    trace = signal.copy()
    model = template.copy()

    if jitter is not None:
        whole = _np.round(jitter)
        fraction = whole - jitter
        trace.jitter(-whole, inplace = True)
        model.jitter(-fraction, interpolate = True, inplace = True)

    if trim is not None:
        trace = trace.trim_nan(trim)
        model = model.trim_nan(trim)

    residual = trace.data - (amplitude * model.data)
    residual = residual - _np.nanmean(residual)
    squared = residual**2

    sections = {"chi2": slice(None),
                "chi2_pretrigger": slice(None, pretrigger),
                "chi2_trigger": slice(pretrigger, posttrigger),
                "chi2_posttrigger": slice(posttrigger, None)}

    return {name: _np.nanmean(squared[section]) for name, section in sections.items()}


def _evaluate_convolve(convolve: _np.array) -> tuple:
    length = len(convolve)
    center = length // 2
    peak = _np.argmax(convolve)
    
    if (peak == 0) or (peak == length - 1):
        amplitude = convolve[peak]
        jitter = peak
    else:
        p1, p2, p3 = [(peak + offset, convolve[peak + offset]) for offset in (-1, 0, 1)]
        a, b, c = _solve_quadratic(p1, p2, p3)

        amplitude = -(b**2) / (4 * a) + c
        jitter = -b / (2 * a) - center

    amplitude_base = convolve[center]

    return amplitude, amplitude_base, jitter


def _solve_quadratic(p1: tuple, p2: tuple, p3: tuple) -> tuple:
    (x1, y1), (x2, y2), (x3, y3) = p1, p2, p3

    denom = (x1 - x2) * (x1 - x3) * (x2 - x3)

    a = (x3 * (y2 - y1) + x2 * (y1 - y3) + x1 * (y3 - y2)) / denom
    b = (x3 * x3 * (y1 - y2) + x2 * x2 * (y3 - y1) + x1 * x1 * (y2 - y3)) / denom
    c  = (x2 * x3 * (x2 - x3) * y1 + x3 * x1 * (x3 - x1) * y2 + x1 * x2 * (x1 - x2) * y3) / denom

    return (a, b, c)