import pandas as _pd
import numpy as _np
import os as _os
import matplotlib.pyplot as _plt
import warnings as _warnings
import lmfit as _lmfit

from ._core import (_load_file, 
                    save as _save, 
                    signals as _signals, 
                    multiprocessing as _multiprocessing, 
                    polarities as _polarities, 
                    fitfiles as _fitfiles)

from ._utils import _get_unique_value

from ._frequencyspectrum import Frequencyspectrum as _Frequencyspectrum


def generate(fitfile : _pd.DataFrame,
             calibration : _pd.DataFrame = None,
             pretrigger : int = 4000, 
             multiprocessing : bool = False,
             processes : int = _os.cpu_count()) -> _pd.DataFrame:
    """Build a table of baseline noise spectra, one per channel of ``fitfile``.

    Parameters
    ----------
    fitfile : pandas.DataFrame
        Fitfile whose signals are turned into noise spectra.
    calibration : pandas.DataFrame, optional
        Passed through to the signal loader for each channel.
    pretrigger : int, default 4000
        Assigned to every signal's ``pretrigger`` attribute before its
        frequency spectrum is taken.
    multiprocessing : bool, default False
        Whether `_core.multiprocessing` should fan the per-channel work out.
    processes : int
        Worker count forwarded to `_core.multiprocessing`. The default is
        ``os.cpu_count()``, evaluated once when this module is imported.

    Returns
    -------
    pandas.DataFrame
        Columns are channels and the single row is ``'BASE'``. Each cell holds
        that channel's `Frequencyspectrum`.

    Warns
    -----
    Warning
        If the fitfile contains a ``POSP`` or ``NEGP`` polarity.
    """
    # Short-circuits exactly like ``'POSP' in ... or 'NEGP' in ...``:
    # `_polarities` is re-evaluated for the second flag only when the first misses.
    flagged = ('POSP', 'NEGP')
    if any(flag in _polarities(fitfile) for flag in flagged):
        _warnings.warn("The fitfile contains `POSP` or `NEGP` polarity.", Warning)

    channel_args = _noisespectrum_iterable(fitfile, pretrigger = pretrigger, calibration = calibration)
    spectra = _multiprocessing(function = _noisespectrum_channel, iterables = channel_args,
                               multiprocessing = multiprocessing, processes = processes,
                               desc = 'noisespectra', unit = 'channel')

    # Nested dict {channel: {'BASE': spectrum}} -> DataFrame with channels as columns.
    by_channel = {}
    for spectrum in spectra:
        by_channel.setdefault(spectrum.channel, {})['BASE'] = spectrum

    return _pd.DataFrame(by_channel)


def plot(noisespectra : _pd.DataFrame) -> None:
    """Show one matplotlib figure per noise spectrum stored in ``noisespectra``.

    Every (channel, polarity) cell of the table is visited. A cell that holds
    a `Frequencyspectrum` is plotted in a new figure titled
    ``"<channel> <polarity>"``. Any other cell value (e.g. a missing entry)
    triggers a warning instead.

    Parameters
    ----------
    noisespectra : pandas.DataFrame
        Table of noise spectra, such as the one returned by `generate`.

    Warns
    -----
    Warning
        For each cell that does not contain a `Frequencyspectrum`.
    """
    for channel, entry in noisespectra.items():
        for polarity, noisespectrum in entry.items():
            if not isinstance(noisespectrum, _Frequencyspectrum):
                _warnings.warn(f"No noisespectra for `{channel}`", Warning)
                continue

            _plt.figure()
            noisespectrum.plot()
            # f-string instead of `+` so non-str labels (e.g. int channels) don't raise TypeError.
            _plt.title(f"{channel} {polarity}")
            _plt.tight_layout()
            _plt.show()

    return None


def save(noisespectra : _pd.DataFrame, file : str) -> None:
    """Write a noise-spectra table to a ``.pkl`` file.

    The actual writing is delegated to the shared `_core.save` helper.

    Parameters
    ----------
    noisespectra : pandas.DataFrame
        Table of noise spectra, as returned by `generate`. The annotation
        matches `generate`/`plot`; it was previously `Frequencyspectrum`.
    file : str
        Destination path. Its extension must be exactly ``.pkl``.

    Raises
    ------
    NotImplementedError
        If ``file`` does not end in ``.pkl``.
    """
    _, ext = _os.path.splitext(file)

    if ext != '.pkl':
        raise NotImplementedError("Only `*.pkl` files are supported.")

    _save(noisespectra, file)

    return None


def load(file : str) -> _pd.DataFrame:
    """Read a noise-spectra table from a ``.pkl`` file.

    The actual reading is delegated to `_core._load_file`.

    Parameters
    ----------
    file : str
        Source path. Its extension must be exactly ``.pkl``.

    Returns
    -------
    pandas.DataFrame
        Presumably the table previously written by `save`. The object is
        returned exactly as `_load_file` produces it.

    Raises
    ------
    NotImplementedError
        If ``file`` does not end in ``.pkl``.
    """
    _, ext = _os.path.splitext(file)

    if ext != '.pkl':
        raise NotImplementedError("Only `*.pkl` files are supported.")

    # NOTE(review): if `_load_file` unpickles `.pkl` files, loading one from an
    # untrusted source can execute arbitrary code -- only load trusted files.
    noisespectra = _load_file(file)

    return noisespectra


def _noisespectrum_iterable(fitfile: _pd.DataFrame, pretrigger: int = 4000, calibration: _pd.DataFrame = None):
    """Lazily yield ``(channel_fitfile, pretrigger, calibration)`` argument tuples.

    One tuple is produced per item of ``_fitfiles(fitfile, iterables=['channel'])``,
    ready to be passed to `_noisespectrum_channel`. `_fitfiles` is only called
    once iteration begins.
    """
    channel_groups = _fitfiles(fitfile, iterables = ['channel'])
    for _key, channel_fitfile in channel_groups:
        yield channel_fitfile, pretrigger, calibration


def _noisespectrum_channel(fitfile : _pd.DataFrame, pretrigger : int = 4000, calibration : _pd.DataFrame = None) -> _Frequencyspectrum:
    """Build the noise spectrum for the signals of a single channel.

    The per-signal frequency spectra from `_noisespectrum_frequencyspectra`
    are combined by `Frequencyspectrum.noisespectrum`.

    Parameters
    ----------
    fitfile : pandas.DataFrame
        Fitfile rows for one channel (as produced by `_noisespectrum_iterable`).
    pretrigger : int, default 4000
        Assigned to each signal's ``pretrigger`` attribute.
    calibration : pandas.DataFrame, optional
        Passed through to the signal loader.

    Returns
    -------
    Frequencyspectrum
        The resulting noise spectrum.

    Warns
    -----
    Warning
        If ``noisespectrum.nan()`` is truthy (presumably: the spectrum
        contains NaN values).
    """
    frequencyspectra = _noisespectrum_frequencyspectra(fitfile, pretrigger = pretrigger, calibration = calibration)

    noisespectrum = _Frequencyspectrum.noisespectrum(frequencyspectra)

    if noisespectrum.nan():
        # Message previously said "Template", a copy-paste slip -- this is a noise spectrum.
        _warnings.warn(f"Noisespectrum for `{noisespectrum.channel}`, `{noisespectrum.polarity}` contains NaN values.", Warning)

    return noisespectrum


def _noisespectrum_frequencyspectra(fitfile : _pd.DataFrame, pretrigger = 4000, calibration : _pd.DataFrame = None):
    """Lazily yield the frequency spectrum of every signal in ``fitfile``.

    Each signal from `_core.signals` is mutated in place before the spectrum
    is taken: its ``pretrigger`` attribute is set and it is normalized
    (``normalize(inplace=True)``).
    """
    for sig in _signals(fitfile, calibration = calibration):
        sig.pretrigger = pretrigger
        sig.normalize(inplace = True)
        yield sig.frequencyspectrum()