import pandas as _pd
import numpy as _np
import matplotlib.pyplot as _plt
import lmfit as _lmfit
import datetime as _datetime
import multiprocessing as _mp

from scipy import spatial as _spatial, odr as _odr

from .paqs import (summary as _summary,
                   settings as _settings)
from ._core import (paths as _paths,
                    masks as _masks,
                    channels as _channels, 
                    polarities as _polarities)
from ._utils import _get_unique_value


# sigma


def assign_sigma_peak(fitfile: _pd.DataFrame,
                      mask: _pd.Series = None,
                      column: str = 'general_std',
                      column_sigma: str = None,
                      method: str = 'pixel',
                      sigma: float = 10.0,
                      bins: int = 1000,
                      plot: bool = False) -> _pd.DataFrame:
    """
    Express `column` as a distance from its main peak, in units of the peak width.

    For every group yielded by `_masks_method` the peak is located in two
    stages: three rounds of median / standard-deviation clipping, followed by
    three Gaussian fits to a histogram of the remaining values. Afterwards
    `|value - center| / width` is written to `column_sigma` for all rows of
    the group (not only the rows selected by `mask`).

    Parameters
    ----------
    fitfile : table that is modified in place and returned.
    mask : rows allowed to contribute to the peak search (default: all rows).
    column : column that holds the values to analyse.
    column_sigma : output column (default: `column + '_sigma'`).
    method : grouping passed on to `_masks_method`.
    sigma : half-width of the selection window in units of the current width estimate.
    bins : number of histogram bins used for the Gaussian fits.
    plot : show the histogram and the last fit for every group.

    Returns
    -------
    The same `fitfile` object.
    """
    if column_sigma is None:
        column_sigma = column + '_sigma'

    if mask is None:
        mask = _pd.Series(True, index = fitfile.index)

    has_value = ~_pd.isna(fitfile[column])

    def _window(center, width):
        # Selection window of +- `sigma` widths around the center.
        return [center - sigma * width, center + sigma * width]

    def _inside(bounds):
        # Rows whose value lies strictly inside the window.
        return (fitfile[column] > bounds[0]) & (fitfile[column] < bounds[1])

    for mask_method in _masks_method(fitfile, method = method):
        mask_peak = mask & mask_method & has_value

        if not mask_peak.any():
            continue

        # Stage 1: coarse estimate from median and standard deviation.
        for _ in range(3):
            data = fitfile.loc[mask_peak, column]

            fit_center = _np.median(data)
            fit_sigma = _np.std(data)

            peak_range = _window(fit_center, fit_sigma)
            mask_peak = mask_peak & _inside(peak_range)

        # Stage 2: refine center and width with Gaussian fits to a histogram.
        for _ in range(3):
            peak_range = _window(fit_center, fit_sigma)
            mask_peak = mask_peak & _inside(peak_range)

            data = fitfile.loc[mask_peak, column]

            hist, bin_edges = _np.histogram(data, bins = _np.linspace(*peak_range, bins + 1))
            bin_centers = (bin_edges[:-1] + bin_edges[1:]) / 2

            model = _lmfit.models.GaussianModel()
            result = model.fit(hist, x = bin_centers, params = model.guess(hist, x = bin_centers))

            fit_center = result.best_values['center']
            fit_sigma = result.best_values['sigma']

        fitfile.loc[mask_method, column_sigma] = _np.abs(fitfile.loc[mask_method, column] - fit_center) / fit_sigma

        if not plot:
            continue

        def _within(k):
            # Bins closer than k fitted widths to the fitted center.
            return (bin_centers > fit_center - k * fit_sigma) & (bin_centers < fit_center + k * fit_sigma)

        # Colour the bins by their distance to the fitted center (1, 2, 3 widths, beyond).
        colors = _np.where(_within(1), 'C1',
                 _np.where(_within(2), 'C2',
                 _np.where(_within(3), 'C3', 'C4')))
        yfit = result.eval(x = bin_centers)
        yerr = result.eval_uncertainty(x = bin_centers, sigma = 1)

        _plt.figure()
        _plt.scatter(bin_centers, hist, c = colors, s = 3)
        _plt.plot(bin_centers, yfit, color = 'C0')
        _plt.fill_between(bin_centers, yfit - yerr, yfit + yerr, color = 'C0', alpha = 0.5)
        _plt.xlim(peak_range)
        _plt.ylim([0, 1.1 * _np.max(hist)])
        _plt.xlabel(column)
        _plt.title(_title_method(fitfile[mask_peak]))
        _plt.tight_layout()
        _plt.show()

    return fitfile


def _title_method(fitfile) -> str:
    """
    Build a plot title from the channel and polarity of `fitfile`.

    Returns `'<channel> <polarity>'`, only `'<channel>'` when the polarity
    lookup raises, and an empty string when the channel lookup raises.
    Only `Exception` subclasses are treated as "value not available", so
    `KeyboardInterrupt` and `SystemExit` are no longer swallowed.
    """
    try:
        channel = _get_unique_value(fitfile, 'channel')
    except Exception:
        return ''

    try:
        polarity = _get_unique_value(fitfile, 'polarity')
    except Exception:
        return f'{channel}'

    return f'{channel} {polarity}'


def _masks_method(fitfile: _pd.DataFrame,
                  mask: _pd.Series = None,
                  method: str = 'pixel'):
    """
    Yield one boolean row mask per group defined by `method`.

    `method` selects the columns handed to `_masks` as `iterables`:
    `'detector'` -> no column, `'channel'` -> `['channel']`,
    `'pixel'` -> `['channel', 'polarity']`. The last element of every item
    produced by `_masks` is combined with `mask`; groups without any
    selected row are skipped.

    Raises
    ------
    ValueError
        If `method` is not one of the three supported names.
    """
    if mask is None:
        mask = _pd.Series(True, index = fitfile.index)

    groupings = {'detector': [],
                 'channel': ['channel'],
                 'pixel': ['channel', 'polarity']}

    if method not in ('detector', 'channel', 'pixel'):
        raise ValueError('Method must be `detector`, `channel` or `pixel`.')

    for result in _masks(fitfile, iterables = groupings[method]):
        group = tuple(result)[-1] & mask
        if group.any():
            yield group

def assign_center_drift(fitfile: _pd.DataFrame,
                        mask: _pd.Series = None,
                        column: str = 'general_std_pretrigger',
                        column_time: str = 'time',
                        column_center: str = None,
                        method: str = 'pixel',
                        width: _pd.Timedelta = _pd.Timedelta(100, unit = 's'),
                        plot: bool = False) -> _pd.DataFrame:
    """
    Track the slowly varying center of `column` over time.

    For every group yielded by `_masks_method` the rows selected by `mask`
    are clipped three times to median +- 5 standard deviations. The surviving
    values are smoothed along `column_time` with `_filter_gaussian`
    (kernel width `width`) and the smoothed curve is evaluated for all rows
    of the group that have a time. The result is stored in `column_center`.

    Parameters
    ----------
    fitfile : table that is modified in place and returned.
    mask : rows allowed to contribute to the drift curve (default: all rows).
    column : column whose drift is tracked.
    column_time : column that holds the time of each row.
    column_center : output column (default: `column + '_center'`).
    method : grouping passed on to `_masks_method`.
    width : width of the Gaussian kernel.
    plot : show the selected points and the smoothed curve for every group.

    Returns
    -------
    The same `fitfile` object.
    """
    if column_center is None:
        column_center = column + '_center'

    if mask is None:
        mask = _pd.Series(True, index = fitfile.index)

    has_time = ~fitfile[column_time].isna()
    to_seconds = _datetime.timedelta.total_seconds

    for mask_method in _masks_method(fitfile, method = method):
        mask_drift = mask & mask_method

        if not mask_drift.any():
            continue

        # Iteratively reject outliers around the median.
        for _ in range(3):
            data = fitfile.loc[mask_drift, column]

            median = _np.median(data)
            std = _np.std(data)

            drift_range = [median - 5 * std, median + 5 * std]
            mask_drift = mask_drift & (fitfile[column] > drift_range[0]) & (fitfile[column] < drift_range[1])

        rows_method = mask_method & has_time
        data_drift = fitfile[mask_drift & has_time].sort_values(column_time)

        x_method = fitfile.loc[rows_method, column_time]
        x_drift = data_drift[column_time]
        y_drift = data_drift[column]
        x0 = _np.min(x_drift)

        # Times are converted to seconds relative to the first selected point.
        x = (x_method - x0).apply(to_seconds)
        xp = (x_drift - x0).apply(to_seconds)
        y = _filter_gaussian(x, xp, y_drift, to_seconds(width))

        fitfile.loc[rows_method, column_center] = _pd.Series(y, index = x_method.index)

        if not plot:
            continue

        x_filt = x_method.to_numpy()
        order = _np.argsort(x_filt)

        _plt.figure()
        _plt.scatter(x_drift, y_drift, c = 'C1', s = 3)
        _plt.plot(x_filt[order], y[order], color = 'C0')
        _plt.xlim([_np.min(x_drift), _np.max(x_drift)])
        _plt.ylim(drift_range)
        _plt.xlabel(column_time)
        _plt.ylabel(column)
        _plt.title(_title_method(fitfile[mask_drift]))
        _plt.tight_layout()
        _plt.show()

    return fitfile


def _filter_gaussian(x, xp, fp, sigma):
    """
    Evaluate a Gaussian-weighted average of the samples `(xp, fp)` at every
    position in `x`.

    A `cKDTree` over `xp` is built once and each position is handed to
    `_filter_gaussian_weights` through a `multiprocessing` pool.

    Returns
    -------
    numpy array with one value per entry of `x`.
    """
    positions = _np.asarray(x)
    xp = _np.asarray(xp)
    fp = _np.asarray(fp)

    # The tree expects 2-D data, hence the extra axis.
    tree = _spatial.cKDTree(xp[:, None])

    tasks = [(position, xp, fp, sigma, tree) for position in positions]

    with _mp.Pool() as pool:
        smoothed = pool.map(_filter_gaussian_weights, tasks)

    return _np.asarray(smoothed)


def _filter_gaussian_weights(args):
    """
    Gaussian-weighted average of the samples around a single position.

    Parameters
    ----------
    args : tuple `(x_i, xp, fp, sigma, tree)`
        `x_i` is the position to evaluate, `xp` / `fp` are the sample
        positions and values, `sigma` is the kernel width and `tree` is a
        `cKDTree` built from `xp[:, None]`.

    Returns
    -------
    The weighted average of `fp` over all samples within `5 * sigma` of
    `x_i`, or NaN when there is no such sample.
    """
    x_i, xp, fp, sigma, tree = args

    # The tree is one-dimensional, so the query point must have shape (1,);
    # a bare scalar is outside the shape documented for query_ball_point.
    idx = tree.query_ball_point(_np.atleast_1d(x_i), 5 * sigma)

    if len(idx) == 0:
        return _np.nan

    distances = _np.abs(xp[idx] - x_i)
    weights = _np.exp(-0.5 * (distances / sigma) ** 2)
    weights /= weights.sum()

    return _np.sum(weights * fp[idx])


def assign_median_shift(fitfile: _pd.DataFrame,
                        mask: _pd.Series = None,
                        column: str = 'chi2',
                        column_shift: str = 'amplitude',
                        column_median: str = None,
                        method: str = 'pixel',
                        width = None,
                        plot: bool = False) -> _pd.DataFrame:
    """
    Compute a running median of `column` as a function of `column_shift`.

    For every group yielded by `_masks_method` the rows selected by `mask`
    (with a non-missing `column_shift`) serve as samples for `_filter_median`,
    which is evaluated for all rows of the group that have a `column_shift`
    value. The result is stored in `column_median`.

    Parameters
    ----------
    fitfile : table that is modified in place and returned.
    mask : rows allowed to contribute to the median (default: all rows).
    column : column whose median is computed.
    column_shift : column along which the median window slides.
    column_median : output column (default: `column + '_median'`).
    method : grouping passed on to `_masks_method`.
    width : window radius, forwarded unchanged to `_filter_median`
        (no fallback is applied here for the default `None`).
    plot : show the samples and the running median for every group.

    Returns
    -------
    The same `fitfile` object.
    """
    if column_median is None:
        column_median = column + '_median'

    if mask is None:
        mask = _pd.Series(True, index = fitfile.index)

    has_shift = ~fitfile[column_shift].isna()

    for mask_method in _masks_method(fitfile, method = method):
        mask_drift = mask & mask_method

        if not mask_drift.any():
            continue

        rows_method = mask_method & has_shift
        samples = fitfile[mask_drift & has_shift].sort_values(column_shift)

        x = fitfile.loc[rows_method, column_shift]
        xp = samples[column_shift]
        fp = samples[column]
        y = _filter_median(x, xp, fp, width)

        fitfile.loc[rows_method, column_median] = _pd.Series(y, index = x.index)

        if not plot:
            continue

        x_filt = x.to_numpy()
        order = _np.argsort(x_filt)

        _plt.figure()
        _plt.scatter(xp, fp, c = 'C1', s = 3)
        _plt.plot(x_filt[order], y[order], color = 'C0')
        _plt.xlim([_np.min(xp), _np.max(xp)])
        _plt.xlabel(column_shift)
        _plt.ylabel(column)
        _plt.title(_title_method(fitfile[mask_drift]))
        _plt.tight_layout()
        _plt.show()

    return fitfile


def _filter_median(x, xp, fp, width):
    """
    Evaluate a windowed median of the samples `(xp, fp)` at every position
    in `x`.

    A `cKDTree` over `xp` is built once and each position is handed to
    `_filter_median_weights` through a `multiprocessing` pool.

    Returns
    -------
    numpy array with one value per entry of `x`.
    """
    positions = _np.asarray(x)
    xp = _np.asarray(xp)
    fp = _np.asarray(fp)

    # The tree expects 2-D data, hence the extra axis.
    tree = _spatial.cKDTree(xp[:, None])

    tasks = [(position, xp, fp, width, tree) for position in positions]

    with _mp.Pool() as pool:
        medians = pool.map(_filter_median_weights, tasks)

    return _np.asarray(medians)


def _filter_median_weights(args):
    """
    Median of the samples around a single position.

    Parameters
    ----------
    args : tuple `(x_i, xp, fp, width, tree)`
        `x_i` is the position to evaluate, `xp` / `fp` are the sample
        positions and values, `width` is the search radius and `tree` is a
        `cKDTree` built from `xp[:, None]`. `xp` is unpacked but not used.

    Returns
    -------
    The median of `fp` over all samples within `width` of `x_i`, or NaN
    when there is no such sample.
    """
    x_i, xp, fp, width, tree = args

    # The tree is one-dimensional, so the query point must have shape (1,);
    # a bare scalar is outside the shape documented for query_ball_point.
    idx = tree.query_ball_point(_np.atleast_1d(x_i), width)

    if len(idx) == 0:
        return _np.nan

    return _np.median(fp[idx])


def assign_time_paqs(fitfile: _pd.DataFrame, 
                     column_timestamp: str = 'timestamp',
                     column_jitter: str = None,
                     column_time: str = 'time', 
                     column_timedelta: str = 'timedelta') -> _pd.DataFrame:
    """
    Convert sample timestamps into absolute times and times since the first start.

    The `start` of every (folder, measurement) pair is read from `_summary`,
    `samplingrate` and `oversampling` from `_settings`. For each pair yielded
    by `_masks` the time is `start + timestamp / samplingrate` seconds; when
    `column_jitter` is given, `jitter / (samplingrate / oversampling)` seconds
    are added. `column_timedelta` is the time relative to the earliest start.

    Parameters
    ----------
    fitfile : table that is modified in place and returned.
    column_timestamp : column holding the timestamp in samples.
    column_jitter : optional column holding a jitter correction.
    column_time : output column for the absolute time.
    column_timedelta : output column for the time since the earliest start.

    Returns
    -------
    The same `fitfile` object.
    """
    paths = _paths(fitfile)

    start_dict = _summary(paths).set_index(['folder', 'measurement'])['start'].to_dict()

    settings = _settings(paths).set_index(['folder', 'measurement'])
    samplingrate_dict = settings['samplingrate'].to_dict()
    oversampling_dict = settings['oversampling'].to_dict()

    for folder, measurement, mask in _masks(fitfile, iterables = ['folder', 'measurement']):
        key = (folder, measurement)
        start = start_dict[key]
        samplingrate = samplingrate_dict[key]

        seconds = fitfile.loc[mask, column_timestamp] / samplingrate
        fitfile.loc[mask, column_time] = start + _pd.to_timedelta(seconds, unit = 's')

        if column_jitter is None:
            continue

        sampling = samplingrate / oversampling_dict[key]
        fitfile.loc[mask, column_time] += _pd.to_timedelta(fitfile.loc[mask, column_jitter] / sampling, unit = 's')

    # All time differences refer to the earliest measurement start.
    fitfile[column_timedelta] = fitfile[column_time] - min(start_dict.values())

    return fitfile


def assign_timedelta(fitfile: _pd.DataFrame,
                     mask: _pd.Series = None,
                     mask_timedelta: _pd.Series = None,
                     suffix: str = None,
                     methods: tuple = ('pixel', 'channel', 'detector'),
                     column_timedelta: str = 'timedelta') -> _pd.DataFrame:
    """
    Store, for several groupings, the time since the previous reference event.

    For every entry of `methods` and every group yielded by `_masks_method`
    (already restricted to `mask`), `_timedelta_mask` is called; rows of the
    group that are also selected by `mask_timedelta` act as reference events.
    The result goes to `'<column_timedelta>_<method>'`, extended by
    `'_<suffix>'` when a suffix is given.

    Parameters
    ----------
    fitfile : table that is modified in place and returned.
    mask : rows that receive a value (default: all rows).
    mask_timedelta : rows used as reference events (default: all rows).
    suffix : optional suffix for the output column names.
    methods : groupings to evaluate; any iterable of method names accepted
        by `_masks_method` (a list still works). The default is a tuple so
        that no mutable object is shared between calls.
    column_timedelta : column holding the time of each row.

    Returns
    -------
    The same `fitfile` object.
    """
    if mask is None:
        mask = _pd.Series(True, index = fitfile.index)

    if mask_timedelta is None:
        mask_timedelta = _pd.Series(True, index = fitfile.index)

    for method in methods:
        column = f'{column_timedelta}_{method}'

        if suffix is not None:
            column += f'_{suffix}'

        # `_masks_method` already combines every group with `mask`.
        for mask_method in _masks_method(fitfile, mask, method):
            fitfile = _timedelta_mask(fitfile,
                                      mask = mask_method,
                                      mask_timedelta = mask_method & mask_timedelta,
                                      column_timedelta_mask = column,
                                      column_timedelta = column_timedelta)

    return fitfile


def _timedelta_mask(fitfile: _pd.DataFrame,
                    mask: _pd.Series = None,
                    mask_timedelta: _pd.Series = None,
                    column_timedelta_mask: str = 'timedelta_mask',
                    column_timedelta: str = 'timedelta') -> _pd.DataFrame:
    """
    Compute the time elapsed since the previous reference event.

    Rows selected by `mask_timedelta` are the reference events. For every row
    selected by `mask` the latest reference time strictly before the row's own
    time is subtracted from it; when there is no such reference, a zero
    `Timedelta` is subtracted instead. Rows without a time are ignored.
    The result is written to `column_timedelta_mask` unless it is empty.

    Returns
    -------
    The same `fitfile` object.
    """
    if mask is None:
        mask = _pd.Series(True, index = fitfile.index)

    if mask_timedelta is None:
        mask_timedelta = _pd.Series(True, index = fitfile.index)

    has_time = ~_pd.isna(fitfile[column_timedelta])

    timedelta = fitfile.loc[mask & has_time, column_timedelta].sort_values()
    timedelta_signals = fitfile.loc[mask_timedelta & has_time, column_timedelta].sort_values()

    n_signals = len(timedelta_signals)
    pointer = 0
    last_signal = _pd.Timedelta(0, unit = 'us')

    elapsed = {}

    # Both series are sorted, so the pointer into the references only moves forward.
    for label, moment in timedelta.items():
        while pointer < n_signals:
            signal = timedelta_signals.iloc[pointer]
            if moment <= signal:
                break
            last_signal = signal
            pointer += 1

        elapsed[label] = moment - last_signal

    time_series = _pd.Series(elapsed)

    if len(time_series) > 0:
        fitfile.loc[mask, column_timedelta_mask] = time_series

    return fitfile


def assign_coincidence(fitfile: _pd.DataFrame, 
                       mask: _pd.Series = None, 
                       mask_coincidence: _pd.Series = None, 
                       suffix: str = None, 
                       threshold = _pd.Timedelta(10, unit = 'us'), 
                       column_timedelta: str = 'timedelta',
                       column_coincidence: str = 'coincidence') -> _pd.DataFrame:
    """
    Flag rows that coincide in time with a row of another channel.

    The output column (`column_coincidence`, extended by `'_<suffix>'` when a
    suffix is given) is first set to False everywhere. For every channel
    yielded by `_masks`, `_find_coincidence` compares the rows of that channel
    (restricted to `mask`) with the rows of all other channels (restricted to
    `mask` and `mask_coincidence`); matching rows of the channel become True.

    Returns
    -------
    The same `fitfile` object.
    """
    if mask is None:
        mask = _pd.Series(True, index = fitfile.index)

    if mask_coincidence is None:
        mask_coincidence = _pd.Series(True, index = fitfile.index)

    column = column_coincidence if suffix is None else column_coincidence + f'_{suffix}'

    fitfile[column] = False

    for _, mask_channel in _masks(fitfile, iterables = ['channel']):
        others = mask & mask_coincidence & ~mask_channel
        own = mask & mask_channel

        coincidence = _find_coincidence(fitfile,
                                        mask1 = others,
                                        mask2 = own,
                                        threshold = threshold,
                                        column_timedelta = column_timedelta)
        fitfile.loc[coincidence.index, column] = _pd.Series(True, index = coincidence.index)

    return fitfile


def _find_coincidence(fitfile: _pd.DataFrame,
                      mask1: _pd.Series = None,
                      mask2: _pd.Series = None,
                      threshold: _pd.Timedelta = _pd.Timedelta(10, unit = 'us'),
                      column_timedelta: str = 'timedelta'):
    """
    Match rows of `mask2` with rows of `mask1` that are close in time.

    Both selections are sorted by `column_timedelta` and walked once. A row
    of `mask2` with time `t2` is matched to a row of `mask1` with time `t1`
    when `t2 - threshold < t1 <= t2 + threshold`.

    Returns
    -------
    Series indexed by the labels of the matched `mask2` rows; the values are
    the labels of the `mask1` rows they were matched with.
    """
    if mask1 is None:
        mask1 = _pd.Series(True, index = fitfile.index)

    if mask2 is None:
        mask2 = _pd.Series(True, index = fitfile.index)

    series1 = fitfile.loc[mask1, column_timedelta].sort_values()
    series2 = fitfile.loc[mask2, column_timedelta].sort_values()

    count2 = len(series2)
    j = 0

    coincidence_dict = {}

    for label1, time1 in series1.items():
        while j < count2:
            time2 = series2.iloc[j]

            # The candidate is too late for this reference; try the next reference.
            if time2 - threshold >= time1:
                break

            if time2 + threshold >= time1:
                coincidence_dict[series2.index[j]] = label1

            j += 1

    return _pd.Series(coincidence_dict)


def information_generate(fitfile, channels, column_time = 'time', threshold = _pd.Timedelta(10, unit = 'us')):
    """
    Group rows of different channels that are close in time.

    The rows of every channel are sorted by `column_time`. Repeatedly, the
    earliest not yet consumed time over all channels is taken; every channel
    whose next time is less than `threshold` later contributes its next row
    to the current group.

    Returns
    -------
    DataFrame with one row per group and one column `'<channel>_index'` per
    channel that holds the `fitfile` label of the contributing row.
    """
    information = {f'{channel}_index': {} for channel in channels}

    series_by_channel = {}
    for channel in channels:
        rows = fitfile[fitfile['channel'] == channel]
        series_by_channel[channel] = rows[column_time].sort_values()

    position = dict.fromkeys(channels, 0)
    # Next unconsumed time per channel; None marks an exhausted channel.
    pending = {channel: (None if series.empty else series.iloc[0])
               for channel, series in series_by_channel.items()}

    group = 0

    while True:
        waiting = [time for time in pending.values() if time is not None]
        if not waiting:
            break

        limit = min(waiting) + threshold

        for channel in channels:
            time = pending[channel]
            if time is None or not time < limit:
                continue

            series = series_by_channel[channel]
            information[f'{channel}_index'][group] = series.index[position[channel]]
            position[channel] += 1

            if position[channel] < len(series):
                pending[channel] = series.iloc[position[channel]]
            else:
                pending[channel] = None

        group += 1

    return _pd.DataFrame(information)


def information_apply(dataframe, fitfile, mask = None, column: str = 'temperature', column_time: str = 'time', width: _pd.Timedelta = _pd.Timedelta(1, unit = 'ms')):
    """
    Transfer `column` from `dataframe` to `fitfile` by smoothing along time.

    The values `dataframe[column]` at times `dataframe[column_time]` are
    smoothed with `_filter_gaussian` (kernel width `width`) and evaluated at
    the times of the `fitfile` rows selected by `mask`. The result is written
    to `fitfile[column]` for those rows.

    Returns
    -------
    The same `fitfile` object.
    """
    if mask is None:
        mask = _pd.Series(True, index = fitfile.index)

    to_seconds = _datetime.timedelta.total_seconds

    times_fitfile = fitfile[mask][column_time]
    times_dataframe = dataframe[column_time]
    values_dataframe = dataframe[column]

    # Times are converted to seconds relative to the earliest entry of `dataframe`.
    origin = _np.min(times_dataframe)
    x = (times_fitfile - origin).apply(to_seconds)
    xp = (times_dataframe - origin).apply(to_seconds)

    y = _filter_gaussian(x, xp, values_dataframe, to_seconds(width))

    fitfile.loc[mask, column] = _pd.Series(y, index = times_fitfile.index)

    return fitfile


def information_mean(dataframe, column):
    """
    Average the per-channel columns `'<channel>_<column>'` into `column`.

    Channels are the prefixes (text before the first underscore) of all
    columns whose name starts with `'ADC'`. They are processed in sorted
    order so that the result does not depend on the hash-randomised
    iteration order of a set.

    Returns
    -------
    The same `dataframe` object with the new column `column`.
    """
    channels = sorted({name.split('_')[0] for name in dataframe.columns if name.startswith('ADC')})
    dataframe[column] = dataframe[[f'{channel}_{column}' for channel in channels]].mean(axis = 1)

    return dataframe


def information_add(dataframe, fitfile, columns):
    """
    Copy columns of `fitfile` into `dataframe`, once per channel.

    Channels are the prefixes (text before the first underscore) of all
    `dataframe` columns whose name starts with `'ADC'`. For every channel and
    every entry of `columns`, the labels in `'<channel>_index'` are looked up
    in `fitfile` and the values stored in `'<channel>_<column>'`; labels that
    are not in `fitfile.index` give NaN.

    Channels are processed in sorted order, so the order of the created
    columns is the same on every run (a plain set iterates in hash order).

    Parameters
    ----------
    dataframe : table that is modified in place and returned.
    fitfile : table the values are read from.
    columns : one column name or a sequence of column names.

    Returns
    -------
    The same `dataframe` object.
    """
    columns = _np.array(columns, ndmin = 1)

    channels = sorted({name.split('_')[0] for name in dataframe.columns if name.startswith('ADC')})

    for channel in channels:
        labels = dataframe[f'{channel}_index']

        for column in columns:
            dataframe[f'{channel}_{column}'] = labels.apply(lambda x: fitfile.loc[x][column] if x in fitfile.index else _np.nan)

    return dataframe


def temperature_correction_calculate(fitfile, column_amplitude = 'amplitude', column_temperature = 'temperature', model = _lmfit.models.PolynomialModel(degree = 1), plot = True):
    """
    Fit the amplitude as a function of temperature for every pixel.

    For every combination of channel and polarity that has rows with a
    temperature, `model` is fitted to `column_amplitude` versus
    `column_temperature`.

    Parameters
    ----------
    fitfile : table with the columns `channel`, `polarity`, `column_amplitude`
        and `column_temperature`.
    column_amplitude : column used as dependent variable.
    column_temperature : column used as independent variable.
    model : lmfit model providing `guess` and `fit`.
    plot : show the data and the fit for every pixel.

    Returns
    -------
    DataFrame built from `{channel: {polarity: fit result}}`.
    """
    has_temperature = ~fitfile[column_temperature].isna()

    channels = _channels(fitfile)
    polarities = _polarities(fitfile)

    correction_dict = {}

    for channel in channels:
        in_channel = fitfile['channel'] == channel

        correction_dict[channel] = {}

        for polarity in polarities:
            mask = in_channel & (fitfile['polarity'] == polarity) & has_temperature

            if not mask.any():
                continue

            x = fitfile.loc[mask, column_temperature]
            y = fitfile.loc[mask, column_amplitude]

            result = model.fit(y, x = x, params = model.guess(y, x = x))

            if plot:
                xlim = [_np.min(x), _np.max(x)]
                xfit = _np.linspace(*xlim, 1000)
                yfit = result.eval(x = xfit)
                yerr = result.eval_uncertainty(x = xfit, sigma = 1)

                _plt.figure()
                _plt.scatter(x, y, color = 'C0', s = 3)
                _plt.plot(xfit, yfit, color = 'C1')
                _plt.fill_between(xfit, yfit - yerr, yfit + yerr, color = 'C1', alpha = 0.5)
                _plt.xlim(xlim)
                _plt.xlabel(column_temperature)
                _plt.ylabel(column_amplitude)
                _plt.title(f'{channel} {polarity}')
                _plt.tight_layout()
                _plt.show()

            correction_dict[channel][polarity] = result

    return _pd.DataFrame(correction_dict)


def temperature_correction_apply(fitfile, correction, mask = None, column_amplitude = 'amplitude', column_amplitude_corrected = 'amplitude_corrected', column_temperature = 'temperature'):
    """
    Divide the amplitude by the fitted temperature dependence of its pixel.

    For every channel / polarity combination that has an entry in
    `correction`, `column_amplitude / result.eval(x = temperature)` is written
    to `column_amplitude_corrected` for the rows of that pixel that are
    selected by `mask` and have a temperature.

    Parameters
    ----------
    fitfile : table that is modified in place and returned.
    correction : object indexed as `correction[channel][polarity]`, e.g. the
        DataFrame returned by `temperature_correction_calculate`.
    mask : rows to correct (default: all rows).
    column_amplitude : input column.
    column_amplitude_corrected : output column.
    column_temperature : column holding the temperature.

    Returns
    -------
    The same `fitfile` object.
    """
    if mask is None:
        mask = _pd.Series(True, index = fitfile.index)

    mask_temperature = fitfile[column_temperature].notna()

    channels = _channels(fitfile)
    polarities = _polarities(fitfile)

    for channel in channels:
        mask_channel = (fitfile['channel'] == channel)

        for polarity in polarities:
            if not (channel in correction and polarity in correction[channel]):
                continue

            result = correction[channel][polarity]

            # Float entries are the NaN placeholders of pixels without a fit.
            # isinstance also covers numpy.float64, which `type(...) != float` missed.
            if isinstance(result, float):
                continue

            mask_pixel = mask_channel & (fitfile['polarity'] == polarity)
            mask_correction = mask_pixel & mask_temperature

            fitfile.loc[mask_correction & mask, column_amplitude_corrected] = fitfile[column_amplitude] / result.eval(x = fitfile[column_temperature])

    return fitfile


def _odr_poly(B, x):
    """
    Polynomial without constant term: `B[0] * x + B[1] * x**2 + ...`.

    The argument order `(B, x)` is the one used when the function is passed
    to `scipy.odr.Model`.
    """
    return sum(coefficient * x**power for power, coefficient in enumerate(B, start = 1))


def _odr_sigma(x, output):
    """
    Propagate the parameter covariance of a fit to the fitted curve.

    The Jacobian is built for the polynomial without constant term used by
    `_odr_poly` (column `i` holds `x**(i + 1)`); the variance at every point
    is `J cov_beta J^T` evaluated row by row.

    Parameters
    ----------
    x : array of positions.
    output : object with the attributes `beta` and `cov_beta`.

    Returns
    -------
    Array with the standard deviation of the curve at every position.
    """
    covariance = output.cov_beta
    n_coefficients = len(output.beta)

    jacobian = _np.zeros((len(x), n_coefficients))
    for column, power in enumerate(range(1, n_coefficients + 1)):
        jacobian[:, column] = x ** power

    # Only the diagonal of J cov J^T is needed.
    variance = _np.einsum("ij,jk,ik->i", jacobian, covariance, jacobian)

    return _np.sqrt(variance)


def energy_calibration_fit(fitfile, column_amplitude = 'amplitude', model = _lmfit.models.VoigtModel(), bins = 100, plot = True):
    """
    Fit a line shape to the amplitude histogram of every pixel.

    For every combination of channel and polarity that has rows, the values
    of `column_amplitude` are histogrammed into `bins` bins between their
    minimum and maximum and `model` is fitted to the histogram. A `gamma`
    parameter, if the model has one, is set to vary.

    Parameters
    ----------
    fitfile : table with the columns `channel`, `polarity` and `column_amplitude`.
    column_amplitude : column to histogram.
    model : lmfit model providing `guess` and `fit`.
    bins : number of histogram bins.
    plot : show the histogram and the fit for every pixel.

    Returns
    -------
    DataFrame built from `{channel: {polarity: fit result}}`.
    """
    channels = _channels(fitfile)
    polarities = _polarities(fitfile)

    results_dict = {}

    for channel in channels:
        in_channel = fitfile['channel'] == channel

        results_dict[channel] = {}

        for polarity in polarities:
            mask = in_channel & (fitfile['polarity'] == polarity)

            if not mask.any():
                continue

            data = fitfile.loc[mask, column_amplitude]

            peak_range = [_np.min(data), _np.max(data)]
            hist, bin_edges = _np.histogram(data, bins = _np.linspace(*peak_range, bins + 1))
            bin_centers = (bin_edges[:-1] + bin_edges[1:]) / 2

            params = model.guess(hist, x = bin_centers)

            if 'gamma' in params:
                params['gamma'].set(vary = True)

            result = model.fit(hist, x = bin_centers, params = params)

            if plot:
                yfit = result.eval(x = bin_centers)
                yerr = result.eval_uncertainty(x = bin_centers, sigma = 1)

                _plt.figure()
                _plt.scatter(bin_centers, hist, s = 3)
                _plt.plot(bin_centers, yfit, color = 'C0')
                _plt.fill_between(bin_centers, yfit - yerr, yfit + yerr, color = 'C0', alpha = 0.5)
                _plt.xlim(peak_range)
                _plt.ylim([0, 1.1 * _np.max(hist)])
                _plt.xlabel(column_amplitude)
                _plt.title(f'{channel} {polarity}')
                _plt.tight_layout()
                _plt.show()

            results_dict[channel][polarity] = result

    return _pd.DataFrame(results_dict)


def energy_calibration_calculate(fit_results, calibration_lines, func = _odr_poly, order = 2, pprint = False, plot = True):
    """
    Fit an amplitude-to-energy calibration curve for every pixel.

    For every channel / polarity combination the fitted peak centers of all
    calibration lines are collected and `func` is fitted to
    (center, line energy) by orthogonal distance regression, taking the
    uncertainties of both into account.

    Parameters
    ----------
    fit_results : mapping `line -> DataFrame`, each indexed as
        `fit_results[line][channel][polarity]` and holding fit results with a
        `center` parameter (e.g. the output of `energy_calibration_fit`).
        Float entries are treated as missing.
    calibration_lines : mapping `line -> {'value': ..., 'stderr': ...}`.
    func : model function `func(beta, x)`; it is used both for the fit and
        for the plots. The confidence band drawn by `_odr_sigma` and the
        linear reference `beta[0] * x` assume the polynomial form of
        `_odr_poly`.
    order : number of fit parameters.
    pprint : print the ODR report of every fit.
    plot : show the calibration curve and its residuals for every pixel.

    Returns
    -------
    DataFrame built from `{channel: {polarity: ODR output}}`.
    """
    channels = set()
    polarities = set()

    for line in fit_results:
        result = fit_results[line]
        channels.update(result.columns)
        polarities.update(result.index)

    calibration_dict = {}

    for channel in channels:
        calibration_dict[channel] = {}

        for polarity in polarities:

            x = []
            xerr = []
            y = []
            yerr = []

            for line in fit_results:
                result = fit_results[line][channel][polarity]

                # Float entries are the NaN placeholders of pixels without a fit.
                # isinstance also covers numpy.float64, which `type(...) != float` missed.
                if isinstance(result, float):
                    continue

                center = result.params['center']

                x.append(center.value)
                # NOTE(review): lmfit leaves `stderr` as None when it cannot
                # estimate uncertainties - confirm how that case should be handled.
                xerr.append(center.stderr)
                y.append(calibration_lines[line]['value'])
                yerr.append(calibration_lines[line]['stderr'])

            if len(x) == 0:
                continue

            x = _np.array(x)
            xerr = _np.array(xerr)
            y = _np.array(y)
            yerr = _np.array(yerr)

            # Fit the function that is also used for evaluation (previously
            # the model was hard-wired to `_odr_poly` and ignored `func`).
            model = _odr.Model(func)
            data = _odr.RealData(x, y, sx = xerr, sy = yerr)
            beta0 = _np.zeros(order)
            beta0[0] = 1

            odr = _odr.ODR(data, model, beta0=beta0)
            output = odr.run()

            if pprint:
                output.pprint()

            if plot:
                xlim = [0, _np.max(x) * 1.1]
                ylim = [0, _np.max(y) * 1.1]
                xfit = _np.linspace(*xlim, 1000)
                yfit = func(output.beta, xfit)
                ylin = output.beta[0] * xfit
                ysigma = _odr_sigma(xfit, output)
                title = f'{channel} {polarity}'

                _, ax = _plt.subplots(3, 1)

                # Top panel: calibration curve.
                ax[0].errorbar(x, y, xerr=xerr, yerr=yerr, fmt='.', capsize=3, capthick=1, elinewidth=1, color = 'C0', label = "Data")
                ax[0].plot(xfit, yfit, color = 'C2', label = "Fit")
                ax[0].fill_between(xfit, yfit - ysigma, yfit + ysigma, color = 'C2', alpha = 0.5)
                ax[0].set_xlim(xlim)
                ax[0].set_ylim(ylim)
                ax[0].set_ylabel("$E$ / keV")
                ax[0].tick_params(axis = 'x', labelbottom = False)
                ax[0].set_title(title)
                ax[0].legend(loc = 2)

                # Middle panel: deviation from the linear term, scaled by 1e3.
                ax[1].errorbar(x, (y - output.beta[0] * x) * 1e3, xerr = xerr, yerr = yerr * 1e3, fmt = '.', capsize = 3, capthick = 1, elinewidth = 1, color = 'C0', label = "Data")
                ax[1].plot(xfit, (yfit - ylin) * 1e3, color = 'C2', label = "Fit")
                ax[1].fill_between(xfit, (yfit - ylin - ysigma) * 1e3, (yfit - ylin + ysigma)  * 1e3, color = 'C2', alpha = 0.5)
                ax[1].set_xlim(xlim)
                ax[1].set_ylabel(r"$\Delta E$ / eV")
                ax[1].tick_params(axis = 'x', labelbottom = False)

                # Bottom panel: residuals with respect to the full fit, scaled by 1e3.
                ax[2].errorbar(x, (y - func(output.beta, x)) * 1e3, xerr = xerr, yerr = yerr * 1e3, fmt = '.', capsize = 3, capthick = 1, elinewidth = 1, color = 'C0', label = "Data")
                ax[2].plot(xfit, (yfit - yfit) * 1e3, color = 'C2', label = "Fit")
                ax[2].fill_between(xfit, (yfit - yfit - ysigma) * 1e3, (yfit - yfit + ysigma)  * 1e3, color = 'C2', alpha = 0.5)
                ax[2].set_xlim(xlim)
                ax[2].set_ylabel(r"$\Delta E$ / eV")
                ax[2].set_xlabel("$A$ / a.u.")

                _plt.tight_layout()
                _plt.show()

            calibration_dict[channel][polarity] = output

    calibration = _pd.DataFrame(calibration_dict)

    return calibration


def energy_calibration_apply(fitfile, calibration, column_amplitude = 'amplitude', column_energy = 'energy', func = _odr_poly):
    """
    Convert amplitudes to energies with the calibration of each pixel.

    For every channel / polarity combination that has an entry in
    `calibration`, `func(output.beta, amplitude)` is written to
    `column_energy` for the rows of that pixel.

    Parameters
    ----------
    fitfile : table that is modified in place and returned.
    calibration : object indexed as `calibration[channel][polarity]`, e.g.
        the DataFrame returned by `energy_calibration_calculate`.
    column_amplitude : input column.
    column_energy : output column.
    func : model function `func(beta, x)`.

    Returns
    -------
    The same `fitfile` object.
    """
    channels = _channels(fitfile)
    polarities = _polarities(fitfile)

    for channel in channels:
        mask_channel = (fitfile['channel'] == channel)

        for polarity in polarities:
            if not (channel in calibration and polarity in calibration[channel]):
                continue

            output = calibration[channel][polarity]

            # Float entries are the NaN placeholders of pixels without a calibration.
            # isinstance also covers numpy.float64, which `type(...) != float` missed.
            if isinstance(output, float):
                continue

            mask_pixel = mask_channel & (fitfile['polarity'] == polarity)
            fitfile.loc[mask_pixel, column_energy] = func(output.beta, fitfile[column_amplitude])

    return fitfile


def energy_resolution(fitfile, column_amplitude = 'amplitude', model = _lmfit.models.GaussianModel(), bins = 100, plot = True):
    """
    Determine the fitted peak width of the amplitude histogram of every pixel.

    For every combination of channel and polarity that has rows, the values
    of `column_amplitude` are histogrammed into `bins` bins between their
    minimum and maximum, `model` is fitted to the histogram and the best-fit
    value of its `sigma` parameter is stored.

    Parameters
    ----------
    fitfile : table with the columns `channel`, `polarity` and `column_amplitude`.
    column_amplitude : column to histogram.
    model : lmfit model providing `guess` and `fit` and a `sigma` parameter.
    bins : number of histogram bins.
    plot : show the histogram and the fit for every pixel.

    Returns
    -------
    DataFrame built from `{channel: {polarity: sigma}}`.
    """
    channels = _channels(fitfile)
    polarities = _polarities(fitfile)

    results_dict = {}

    for channel in channels:
        in_channel = fitfile['channel'] == channel

        results_dict[channel] = {}

        for polarity in polarities:
            mask = in_channel & (fitfile['polarity'] == polarity)

            if not mask.any():
                continue

            data = fitfile.loc[mask, column_amplitude]

            peak_range = [_np.min(data), _np.max(data)]
            hist, bin_edges = _np.histogram(data, bins = _np.linspace(*peak_range, bins + 1))
            bin_centers = (bin_edges[:-1] + bin_edges[1:]) / 2

            result = model.fit(hist, x = bin_centers, params = model.guess(hist, x = bin_centers))

            if plot:
                yfit = result.eval(x = bin_centers)
                yerr = result.eval_uncertainty(x = bin_centers, sigma = 1)

                _plt.figure()
                _plt.scatter(bin_centers, hist, s = 3)
                _plt.plot(bin_centers, yfit, color = 'C0')
                _plt.fill_between(bin_centers, yfit - yerr, yfit + yerr, color = 'C0', alpha = 0.5)
                _plt.xlim(peak_range)
                _plt.ylim([0, 1.1 * _np.max(hist)])
                _plt.xlabel(column_amplitude)
                _plt.title(f'{channel} {polarity}')
                _plt.tight_layout()
                _plt.show()

            results_dict[channel][polarity] = result.best_values['sigma']

    return _pd.DataFrame(results_dict)