import pandas as _pd
import numpy as _np
import os as _os
from scipy import fft as _fft
import warnings as _warnings

from pathlib import Path as _Path

from ._core import (multiprocessing as _multiprocessing, 
                    append as _append, 
                    dtypes as _dtypes, 
                    fitfiles as _fitfiles,
                    save as _save,
                    affix as _affix)

from ._utils import _get_distribution, _get_parameter, _yield_prefix, _get_unique_value

from ._frequencyspectrum import Frequencyspectrum as _Frequencyspectrum
from ._signal import Signal as _Signal


def load(folders):
    """Collect the paths of all ``fitfile.parquet`` files below the given folders.

    Every existing folder is walked recursively with ``os.walk``. A folder
    that does not exist is skipped and reported with a warning.

    Parameters
    ----------
    folders : str, path-like or sequence of these
        One folder or several folders to search.

    Returns
    -------
    numpy.ndarray
        One-dimensional string array with the unique paths found, in sorted
        order. The array is empty if no file was found.
    """
    filename = "fitfile.parquet"
    found = set()

    for folder in _np.array(folders, ndmin = 1):
        if not _os.path.exists(folder):
            _warnings.warn(f"`{folder}` does not exist.", Warning)
            continue

        for dirpath, _, filenames in _os.walk(_Path(folder)):
            if filename in filenames:
                found.add(_os.path.join(dirpath, filename))

    # Sorting makes the result reproducible (set order depends on string
    # hashing). The explicit dtype keeps an empty result a string array
    # instead of NumPy's default float64.
    return _np.array(sorted(found), dtype = str)


def random(name: str, *args, **kwargs):
    """Return a zero-argument sampler for the distribution called `name`.

    Supported names are ``'uniform'`` (bounds ``args[0]``, ``args[1]``),
    ``'gauss'`` and ``'lorentz'`` (both ``args[0] + args[1] * draw``, with a
    standard normal or standard Cauchy draw). The positional arguments are
    only read when the returned sampler is called; `kwargs` are not used.

    Raises
    ------
    ValueError
        If `name` is none of the supported distributions.
    """
    def uniform():
        return _np.random.uniform(args[0], args[1])

    def gauss():
        return args[0] + args[1] * _np.random.standard_normal()

    def lorentz():
        return args[0] + args[1] * _np.random.standard_cauchy()

    for label, sampler in (('uniform', uniform), ('gauss', gauss), ('lorentz', lorentz)):
        if name == label:
            return sampler

    raise ValueError(f'Random function {name} is not implemented.')


def measurement(folder: str = 'data',
                measurement: str = 'measurement',
                signals: dict = {'ADC1': {'POSP': {'signals': 1e3,
                                                    'offset': 0.0,
                                                    'pulse0_amplitude': 0.5,
                                                    'pulse0_jitter': 0.0,
                                                    'pulse0_rise': 1e-6,
                                                    'pulse0_decay': 1e-3},
                                           'BASE': {'signals': 1e3,
                                                    'offset': 0.0}}},
                noise: dict = {'ADC1': {'noise0_white': (0.3e-6)**2,
                                         'noise0_pink': (5.0e-6)**2,
                                         'noise0_exponent': 0.8}},
                settings: dict = {'samples': 2**14,
                                   'oversampling': 32,
                                   'subsampling': 8,
                                   'samplingrate': 1e8,
                                   'pretrigger': 4000,
                                   'coderange': 2**16,
                                   'flux_conversion': 1.8e-5,
                                   'bandwidth': 1e6},
                multiprocessing: bool = False,
                processes: int = _os.cpu_count(),
                **kwargs: dict) -> _pd.DataFrame:
    """Generate a simulated measurement and return its fitfile.

    For every channel and polarity in `signals` a parameter table is built
    with `_fitfile`; the tables are combined, the traces are generated with
    `_generate`, and the combined table is passed to `_save` under
    ``<folder>/<measurement>/fitfile.parquet``.

    Parameters
    ----------
    folder, measurement : str
        Location of the measurement (``<folder>/<measurement>``).
    signals : dict
        ``{channel: {polarity: parameters}}``; the parameters are forwarded
        to `_fitfile` as keyword arguments.
    noise : dict
        ``{channel: noise parameters}``; channels without an entry get none.
    settings : dict
        Settings forwarded to `_fitfile` for every pixel. The keys
        ``'subsampling'`` and ``'oversampling'`` are required here.
    multiprocessing, processes
        Forwarded to `_generate`.
    **kwargs
        Additional keyword arguments forwarded to `_fitfile`.

    Raises
    ------
    FileExistsError
        If the fitfile of this measurement already exists.
    ValueError
        If ``subsampling`` exceeds half of ``oversampling``.
    """
    path_fitfile = _os.path.join(folder, measurement, 'fitfile.parquet')

    if _os.path.exists(path_fitfile):
        existing = _os.path.join(folder, measurement)
        raise FileExistsError(f'Measurement {existing} already exists.')

    if settings['subsampling'] > settings['oversampling'] / 2:
        raise ValueError('Subsampling cannot be larger than half the oversampling.')

    # One parameter table per (channel, polarity), in the order given by `signals`.
    tables = [
        _fitfile(folder, measurement, channel, polarity,
                 **parameters, **noise.get(channel, {}), **settings, **kwargs)
        for channel, pixels in signals.items()
        for polarity, parameters in pixels.items()
    ]

    table = _append(tables)

    _generate(table,
              noisespectra = _noisespectra(table),
              multiprocessing = multiprocessing,
              processes = processes)

    table = _dtypes(table, category = True)
    table = _affix(table, affix = 'generator', method = 'prefix')

    _save(table, path_fitfile)

    return table


def noisespectrum(samples: int,
                  samplingrate: float,
                  oversampling: int,
                  flux_conversion: float,
                  folder: str = None,
                  channel: str = None,
                  polarity: str = 'BASE',
                  signal: int = None,
                  plot: bool = False,
                  **kwargs) -> _Frequencyspectrum:
    """Build a noise spectrum as the sum of all ``noise*`` components in `kwargs`.

    The frequency axis comes from ``rfftfreq(samples, oversampling / samplingrate)``.
    Each component is evaluated by `_noise_component` on all frequencies
    except the first, gets a zero prepended for that first bin, and is divided
    by ``flux_conversion**2`` before it is added to the total.

    Parameters
    ----------
    samples, samplingrate, oversampling
        Define the frequency axis.
    flux_conversion : float
        Each component is divided by its square.
    folder, channel, polarity, signal
        Stored in the header of the spectrum.
    plot : bool
        Call ``plot()`` on the finished spectrum.
    **kwargs
        Noise parameters grouped by prefix, and optionally ``bandwidth``:
        if it is neither None nor NaN, a second-order low-pass is applied.
    """
    header = {'samplingrate' : samplingrate,
              'oversampling' : oversampling,
              'folder' : folder,
              'channel' : channel,
              'polarity' : polarity,
              'signal' : signal}

    rate = samplingrate / oversampling
    frequencies = _fft.rfftfreq(samples, 1 / rate)

    total = _Frequencyspectrum(data = _np.zeros_like(frequencies), header = header)

    for prefix in _yield_prefix(kwargs, 'noise'):
        component = _noise_component(frequencies[1:], **_get_parameter(kwargs, prefix))
        # The first bin was left out of the evaluation and is set to zero.
        component = _np.insert(component, 0, 0)
        component /= flux_conversion**2

        total += _Frequencyspectrum(data = component, header = header)

    bandwidth = kwargs.get('bandwidth')
    if not (bandwidth is None or _np.isnan(bandwidth)):
        total.lowpass(bandwidth, order = 2, inplace = True)

    if plot:
        total.plot()

    return total


def signal(samples: int, 
          oversampling: int, 
          samplingrate: float, 
          coderange: int, 
          pretrigger: int, 
          flux_conversion: float,
          folder: str = None, 
          measurement: str = None, 
          channel: str = None, 
          polarity: str = None, 
          signal: int = None,
          plot: bool = False,
          **kwargs) -> _Signal:
    """Build a noise-free signal as the sum of all ``pulse*`` components in `kwargs`.

    A zero trace of length `samples` is created; every pulse is evaluated by
    `_pulse` on the time axis of that trace, divided by `flux_conversion`
    and added.

    Parameters
    ----------
    samples : int
        Length of the trace.
    oversampling, samplingrate, coderange, pretrigger
        Stored in the header of the trace.
    flux_conversion : float
        Each pulse is divided by this value.
    folder, measurement, channel, polarity, signal
        Stored in the header of the trace.
    plot : bool
        Call ``plot()`` on the finished trace.
    **kwargs
        Pulse parameters grouped by prefix.
    """
    header = {'folder' : folder,
              'measurement' : measurement,
              'channel' : channel,
              'polarity' : polarity,
              'signal' : signal,
              'samplingrate' : samplingrate,
              'oversampling' : oversampling,
              'coderange' : coderange,
              'pretrigger' : pretrigger}

    trace = _Signal(data = _np.zeros(samples), header = header)
    time = trace.time()

    for prefix in _yield_prefix(kwargs, 'pulse'):
        shape = _pulse(time, **_get_parameter(kwargs, prefix))
        shape /= flux_conversion

        trace += _Signal(data = shape, header = header)

    if plot:
        trace.plot()

    return trace


def _fitfile(folder: str, measurement: str, channel: str, polarity: str, signals: int, **kwargs) -> _pd.DataFrame:
    """Build the parameter table of one pixel (channel and polarity).

    `signals` rows are drawn by `_dataframe_parameter`, indexed and labelled
    by `_index_parameter`, and passed through `_dtypes`.
    """
    drawn = _dataframe_parameter(signals, channel = channel, polarity = polarity, **kwargs)
    indexed = _index_parameter(drawn, folder, measurement)

    return _dtypes(indexed)


def _dataframe_parameter(signals: int, **kwargs) -> _pd.DataFrame:
    """Draw `signals` rows of parameters, one column per keyword argument.

    For every row and every keyword, `_get_distribution` is called on the
    keyword's value and the callable it returns is invoked once to obtain
    the cell value.
    """
    rows = [
        {name: _get_distribution(spec)() for name, spec in kwargs.items()}
        for _ in range(int(signals))
    ]

    return _pd.DataFrame(rows)


def _index_parameter(parameters: _pd.DataFrame, folder: str, measurement: str) -> _pd.DataFrame:
    """Number the rows of every pixel and label all rows with their origin.

    Within each (channel, polarity) combination the rows receive a running
    ``'signal'`` number starting at 0, in row order. The columns ``'folder'``
    and ``'measurement'`` are set to the given values. The frame is modified
    in place and returned.
    """
    for channel in set(parameters['channel']):
        in_channel = parameters['channel'] == channel

        for polarity in set(parameters.loc[in_channel, 'polarity']):
            in_pixel = in_channel & (parameters['polarity'] == polarity)
            parameters.loc[in_pixel, 'signal'] = range(int(in_pixel.sum()))

    parameters['folder'] = folder
    parameters['measurement'] = measurement

    return parameters


def _noisespectra(fitfile: _pd.DataFrame, plot: bool = False) -> _pd.DataFrame:
    """Build one noise spectrum per channel of `fitfile`.

    Returns a DataFrame with one column per channel and the single row
    ``'BASE'`` holding the spectrum returned by `_noise_channel`.
    """
    spectra = {
        channel: {'BASE' : _noise_channel(rows, channel = channel, plot = plot)}
        for channel, rows in _fitfiles(fitfile, iterables = ['channel'])
    }

    return _pd.DataFrame(spectra)


def _noise_channel(fitfile_channel: _pd.DataFrame, channel: str = None, plot: bool = False) -> _Frequencyspectrum:
    """Build the noise spectrum of one channel from its fitfile rows.

    The acquisition settings are read with `_get_unique_value`. The spectrum
    is requested with ``samples * int(oversampling / subsampling)`` samples
    and with ``subsampling`` passed as the oversampling, and with
    ``bandwidth = None`` so that `noisespectrum` applies no low-pass.
    """
    samplingrate, oversampling, samples, flux_conversion, subsampling = (
        _get_unique_value(fitfile_channel, column)
        for column in ('samplingrate', 'oversampling', 'samples', 'flux_conversion', 'subsampling')
    )

    settings = {'samples': samples * int(oversampling / subsampling),
                'samplingrate': samplingrate,
                'oversampling': subsampling,
                'channel' : channel,
                'polarity' : 'BASE'}

    # Restore the '<prefix>_<key>' names that noisespectrum groups by prefix.
    for prefix in _yield_prefix(fitfile_channel, 'noise'):
        for key, value in _get_parameter(fitfile_channel, prefix).items():
            settings[f'{prefix}_{key}'] = value

    return noisespectrum(flux_conversion = flux_conversion,
                         plot = plot,
                         bandwidth = None,
                         **settings)

    
def _noise_component(frequency: _np.ndarray, **kwargs) -> _np.ndarray:
    """Evaluate one noise component on the given frequencies.

    The component is ``white + pink / frequency**exponent``, multiplied by
    ``1 / (1 + (frequency / cutoff)**2)`` if a cutoff is given. A parameter
    that is missing, None or NaN is treated as not set and skipped, matching
    how pulse parameters are handled elsewhere in this module.

    Parameters
    ----------
    frequency : array-like
        Frequencies to evaluate on. Must not contain 0 when ``pink`` is set.
    **kwargs
        Optional ``white``, ``pink`` (requires ``exponent``) and ``cutoff``.

    Returns
    -------
    numpy.ndarray
        Array with one value per frequency.

    Raises
    ------
    KeyError
        If ``pink`` is set but ``exponent`` is missing.
    """
    def is_set(value):
        # None and NaN both mean "no value"; arrays count as set unless all-NaN.
        return value is not None and not _np.all(_np.isnan(value))

    frequency = _np.asarray(frequency)
    noisecomponent = _np.zeros(len(frequency))

    white = kwargs.get('white')
    if is_set(white):
        noisecomponent += white

    pink = kwargs.get('pink')
    if is_set(pink):
        noisecomponent += pink / frequency**kwargs['exponent']

    cutoff = kwargs.get('cutoff')
    if is_set(cutoff):
        noisecomponent *= 1 / (1 + (frequency / cutoff)**2)

    return noisecomponent


def _generate(fitfile: _pd.DataFrame, noisespectra: _pd.DataFrame = None, multiprocessing: bool = False, processes: int = _os.cpu_count()) -> None:
    """Generate and store the traces of all batches in `fitfile`.

    The work items produced by `_generate_iterables` are handed to
    `_generate_batch` through the `_multiprocessing` helper; its result is
    discarded.
    """
    _multiprocessing(function = _generate_batch,
                     iterables = _generate_iterables(fitfile, noisespectra),
                     multiprocessing = multiprocessing,
                     processes = processes,
                     desc = 'generator',
                     unit = 'batch')


def _generate_iterables(fitfile: _pd.DataFrame, noisespectra: _pd.DataFrame):
    """Yield one work item per batch of every pixel in `fitfile`.

    For each (folder, measurement, channel, polarity) group the output
    folder ``<folder>/<measurement>/<channel>/<polarity>`` is created and the
    group is split into batches by `_batch_pixel`.

    Parameters
    ----------
    fitfile : pandas.DataFrame
        Parameter table of all signals.
    noisespectra : pandas.DataFrame or None
        Spectra indexed as ``noisespectra[channel]['BASE']``. A channel
        without an entry, or ``None``, results in no noise spectrum.

    Yields
    ------
    tuple
        ``(batch, noisespectrum, filefolder, filename)``, the arguments of
        `_generate_batch`.
    """
    for folder, measurement, channel, polarity, fitfile_pixel in _fitfiles(fitfile, iterables = ['folder', 'measurement', 'channel', 'polarity']):

        try:
            noisespectrum = noisespectra[channel]['BASE']
        except (KeyError, IndexError, TypeError):
            # Missing channel/row (KeyError, IndexError) or no spectra at all
            # (None is not subscriptable -> TypeError). Anything else, such as
            # KeyboardInterrupt, must not be swallowed.
            noisespectrum = None

        filefolder = _os.path.join(folder, measurement, channel, polarity)

        # exist_ok avoids the race between an existence check and the creation
        # when several generators write below the same folder.
        _os.makedirs(filefolder, exist_ok = True)

        for batch_pixel, filename in _batch_pixel(fitfile_pixel):
            yield (batch_pixel, noisespectrum, filefolder, filename)


def _generate_batch(batch: _pd.DataFrame, noisespectrum: _Frequencyspectrum, filefolder: str, filename: str) -> None:
    """Calculate the traces of one batch and write them to a parquet file.

    The result of `_calculate_batch` is written to ``<filefolder>/<filename>``
    with brotli compression.
    """
    traces = _calculate_batch(batch, noisespectrum)
    traces.to_parquet(_os.path.join(filefolder, filename), compression = 'brotli')


def _batch_pixel(fitfile_pixel: _pd.DataFrame, batchsize: int = 1000):
    """Split the rows of one pixel into batches by their ``'signal'`` number.

    Batch ``k`` covers the signal numbers ``k * batchsize`` up to
    ``(k + 1) * batchsize - 1``; rows below the upper bound that are still
    unassigned go into the batch. Ranges without any row are skipped.

    Parameters
    ----------
    fitfile_pixel : pandas.DataFrame
        Rows of one pixel with a ``'signal'`` column.
    batchsize : int
        Width of the signal-number range per batch; must be positive.

    Yields
    ------
    tuple
        ``(batch, filename)`` with the batch rows and the file name
        ``'<first>_<last>.parquet'`` of the covered range.

    Raises
    ------
    ValueError
        If `batchsize` is not positive or a ``'signal'`` value is missing.
        Either case would otherwise keep the loop running forever.
    """
    if len(fitfile_pixel) == 0:
        return

    if batchsize <= 0:
        raise ValueError(f'Batchsize must be positive, got {batchsize}.')

    if fitfile_pixel['signal'].isna().any():
        # A missing signal number never satisfies `signal < bound`.
        raise ValueError('Column `signal` contains missing values.')

    first = 0
    remaining = fitfile_pixel

    while len(remaining) > 0:
        mask = remaining['signal'] < first + batchsize

        if mask.any():
            filename = f'{first}_{first + batchsize - 1}.parquet'
            yield remaining[mask], filename
            remaining = remaining[~mask]

        first += batchsize


def _calculate_batch(batch: _pd.DataFrame, noisespectrum: _Frequencyspectrum) -> _pd.DataFrame:
    """Calculate the traces of all rows in `batch`.

    Every row is passed to `_calculate_signal`; the results become the
    columns of the returned frame, named after the row's ``'signal'`` value.
    """
    rows = (batch.loc[label] for label in batch.index)
    columns = [_pd.Series(_calculate_signal(row, noisespectrum), name = row['signal'])
               for row in rows]

    return _pd.concat(columns, axis = 1)


def _calculate_signal(entry: _pd.Series, noisespectrum: _Frequencyspectrum) -> _np.array:
    """Build the digitized trace described by one fitfile row.

    The trace is built at ``int(oversampling / subsampling)`` times the final
    number of samples, and is downsampled by that factor at the end. The
    steps run in this order: pulses, sign flip for polarity ``'NEGP'``, noise,
    low-pass, offset, shift by half the code range, digitize, downsample.

    Parameters
    ----------
    entry : pandas.Series
        One row with the keys ``samplingrate``, ``oversampling``,
        ``subsampling``, ``coderange``, ``flux_conversion``, ``samples``,
        ``pretrigger`` and ``polarity``; optionally ``bandwidth``, ``offset``
        and pulse parameters.
    noisespectrum : _Frequencyspectrum or None
        Spectrum whose ``noise()`` result is added to the trace; skipped if
        None.

    Returns
    -------
    The ``trace`` object with its ``data`` cast to int.
    NOTE(review): the annotation says ``_np.array`` but the trace object
    itself (not ``trace.data``) is returned; confirm which is intended.
    """

    
    samplingrate = entry['samplingrate']
    oversampling = entry['oversampling']
    subsampling = entry['subsampling']
    coderange = entry['coderange']
    flux_conversion = entry['flux_conversion']
    # Factor between the internal resolution and the final sample count.
    sampling_factor = int(oversampling / subsampling)
    samples = entry['samples'] * sampling_factor
    pretrigger = entry['pretrigger'] * sampling_factor

    data = _np.zeros(samples)
    # The header's 'oversampling' is set to the row's `subsampling` value.
    header = {'samplingrate': samplingrate,
              'oversampling': subsampling,
              'coderange': coderange,
              'pretrigger': pretrigger}
    trace = _Signal(data = data, header = header)
    
    keys = entry.keys()

    # Keyword arguments for signal(): the header values plus samples and
    # flux conversion, then the pulse parameters collected below.
    settings = header.copy()
    settings.update({'samples': samples, 'flux_conversion': flux_conversion})

    for prefix in _yield_prefix(keys, 'pulse'):
        parameter = _get_parameter(entry, prefix)
        # Force the per-pulse offset to None; `_pulse` skips a None offset.
        # The row's 'offset' is added once further below, after the low-pass.
        parameter['offset'] = None
        parameter_prefix = {f'{prefix}_{key}' : value for key, value in parameter.items()}
        settings.update(parameter_prefix)

    trace += signal(**settings)

    # Pulses of negative polarity are flipped before noise is added.
    if entry['polarity'] == 'NEGP':
        trace = -trace

    if noisespectrum is not None:
        # Presumably a noise realization drawn from the spectrum; verify
        # against _Frequencyspectrum.noise().
        noise = noisespectrum.noise()
        noise.coderange = entry['coderange']
        trace += noise

    # The low-pass acts on pulses and noise together; None/NaN disables it.
    if 'bandwidth' in keys:
        bandwidth = entry['bandwidth']
        if bandwidth is not None and not _np.isnan(bandwidth):
            trace.lowpass(bandwidth, order = 2, inplace = True)

    # The offset is divided by the flux conversion, like the pulses in signal().
    if 'offset' in keys:
        offset = entry['offset']
        if offset is not None and not _np.isnan(offset):
            trace += offset / entry['flux_conversion']

    # Shift the trace by half the code range before digitizing.
    trace += entry['coderange'] / 2

    trace.digitize(inplace = True)

    # Reduce the trace by the sampling factor used to build it.
    trace.downsample(sampling_factor, inplace = True) 

    trace.data = trace.data.astype(int)

    return trace


def _pulse(x: _np.array, **kwargs) -> _np.array:
    """Evaluate one pulse on the time axis `x`.

    If ``amplitude`` is set, the pulse ``amplitude * (decay - rise)`` is added
    for all times greater than zero, where `rise` and `decay` come from
    `_pulse_exp`; a set ``jitter`` is subtracted from the time axis first.
    A set ``offset`` is added to every sample. Parameters that are missing,
    None or NaN count as not set.
    """
    y = _np.zeros(len(x))

    amplitude = kwargs.get('amplitude')
    if not (amplitude is None or _np.isnan(amplitude)):

        jitter = kwargs.get('jitter')
        if not (jitter is None or _np.isnan(jitter)):
            x = x - jitter

        # Only samples after the (shifted) time zero carry the pulse.
        started = x > 0
        elapsed = x[started]

        rise = _pulse_exp(elapsed, 'rise', **kwargs)
        decay = _pulse_exp(elapsed, 'decay', **kwargs)

        y[started] += amplitude * (decay - rise)

    offset = kwargs.get('offset')
    if not (offset is None or _np.isnan(offset)):
        y += offset

    return y


def _pulse_exp(x: _np.array, method: str, **kwargs) -> _np.array:
    """Sum the exponential terms of the rising or decaying part of a pulse.

    For every prefix yielded by `_yield_prefix` for `method` whose rate is
    set, the term ``factor * exp(-x / rate)`` is formed; ``factor`` is read
    from ``<prefix>_factor`` and defaults to 1 when missing, None or NaN.

    * ``'rise'`` starts at 0 and adds each term.
    * ``'decay'`` starts at 1 and adds each term minus its factor.

    Raises
    ------
    ValueError
        If `method` is neither ``'rise'`` nor ``'decay'``.
    """
    size = len(x)

    if method not in ('rise', 'decay'):
        raise ValueError(f'Method {method} is not valid.')

    rising = method == 'rise'
    y = _np.zeros(size) if rising else _np.ones(size)

    names = kwargs.keys()
    for prefix in _yield_prefix(names, method):

        rate = kwargs[prefix]
        if rate is None or _np.isnan(rate):
            continue

        factor = kwargs.get(prefix + '_factor')
        if factor is None or _np.isnan(factor):
            factor = 1

        term = _np.exp(-x / rate) * factor
        y += term if rising else term - factor

    return y