import numpy as _np
import matplotlib.pyplot as _plt
import copy as _copy
import warnings as _warnings
from scipy import fft as _fft, signal as _signal
import re as _re
from lmfit import Model as _Model

class Signal:
    """A sampled trace (``data``) together with a metadata ``header`` dict.

    The header always contains the acquisition attributes ``samplingrate``,
    ``oversampling``, ``coderange`` and ``pretrigger`` and the index keys
    ``folder``, ``measurement``, ``channel``, ``polarity`` and ``signal``
    (missing keys are filled with ``None`` on construction). The effective
    sample rate used throughout is ``samplingrate / oversampling``, and
    ``pretrigger`` is the sample index that ``time()`` maps to ``t = 0``.
    """

    def __init__(self, data: _np.ndarray = None, header: dict = None):
        """Create a signal.

        Args:
            data: sample values; stored as a float array (``None`` -> empty).
            header: metadata dict. It is used and completed in place;
                ``None`` -> a new empty dict.
        """
        self.data = data
        self.header = header

        # Metadata that should agree when signals are combined (mismatch warns).
        self._attributes = ['samplingrate',
                            'oversampling',
                            'coderange',
                            'pretrigger']

        # Bookkeeping keys identifying where a signal came from.
        self._index = ['folder',
                       'measurement',
                       'channel',
                       'polarity',
                       'signal']

        for attribute in self._attributes + self._index:
            self.header.setdefault(attribute, None)


    def plot(self, *args, **kwargs) -> None:
        """Plot the signal against ``time()`` on a matplotlib axis.

        Positional and keyword arguments are forwarded to ``ax.plot``. The
        optional ``ax`` keyword selects the axis (default: the current axis).
        Passing ``label`` also draws a legend.
        """
        time = self.time()

        # Resolve the default lazily: ``kwargs.pop("ax", _plt.gca())`` would
        # call gca() (and possibly create a figure) even when ``ax`` is given.
        ax = kwargs.pop("ax", None)
        if ax is None:
            ax = _plt.gca()

        ax.plot(time, self.data, *args, **kwargs)
        ax.set_xlim(time[0], time[-1])
        ax.set_xlabel(r"t / s")
        ax.set_ylabel(r"code / bit")

        if "label" in kwargs:
            ax.legend()

        return None


    def window(self, window: "str | tuple" = ("tukey", 0.1), inplace: bool = False):
        """Multiply the data by a ``scipy.signal.get_window`` window.

        Args:
            window: window specification accepted by ``get_window``.
            inplace: modify ``self`` instead of a copy.

        Returns:
            The windowed signal (``self`` when ``inplace``).
        """
        signal = self if inplace else self.copy()

        window_list = _signal.get_window(window, self.samples())
        signal.data *= window_list

        return signal


    def frequencyspectrum(self, window: "str | tuple" = ("tukey", 0.1)):
        """Return the windowed one-sided power spectrum as a ``Frequencyspectrum``.

        Computes ``2 * |rfft(data * w)|**2 / (fs * sum(w**2))`` with
        ``fs = samplingrate / oversampling``.

        Raises:
            ValueError: if ``oversampling`` or ``samplingrate`` is ``None``.
        """
        from ._frequencyspectrum import Frequencyspectrum as _Frequencyspectrum

        if self.oversampling is None or self.samplingrate is None:
            raise ValueError("oversampling and samplingrate must not be None")

        window_list = _signal.get_window(window, self.samples())
        sampling = self.samplingrate / self.oversampling
        data = (2 * _np.abs(_fft.rfft(self.data * window_list)) ** 2 / (sampling * _np.sum(window_list ** 2)))
        header = {"samplingrate": self.samplingrate,
                  "oversampling": self.oversampling,
                  "folder": self.folder,
                  "measurement": self.measurement,
                  "channel": self.channel,
                  "polarity": self.polarity,
                  "signal": self.signal}

        return _Frequencyspectrum(data=data, header=header)


    def filter_matched(self,
                       noisespectrum = None,
                       window: "str | tuple" = ('tukey', 0.1),
                       inplace: bool = False):
        """Build a matched filter from this signal used as template.

        Without ``noisespectrum`` the filter is the time-reversed windowed
        template. With a noise spectrum the filter is
        ``irfft(conj(rfft(template)) / noise)``; the spectrum is resampled
        onto this signal's ``rfftfreq`` grid when its ``samples()`` differs
        from the signal length. In both cases the result is divided by
        ``sum(template * filter)``.

        Args:
            noisespectrum: optional spectrum object providing ``samples()``,
                ``copy()``, ``interpolate(freq, inplace=True)`` and ``data``.
            window: window applied to the template.
            inplace: write the filter into ``self``.
        """
        template = self.copy()
        template.window(window, inplace = True)

        filter_matched = self if inplace else self.copy()
        length = len(template.data)

        if noisespectrum is None:
            filter_matched.data = _np.flip(template.data, axis = 0)
        else:
            # NOTE(review): a spectrum built from a signal of the same length
            # has ``length // 2 + 1`` points, so this comparison presumably
            # always triggers a (harmless) re-interpolation — confirm against
            # Frequencyspectrum.samples().
            if length != noisespectrum.samples():
                sampling = filter_matched.samplingrate / filter_matched.oversampling
                frequency = _fft.rfftfreq(length, d = 1 / sampling)
                noisespectrum = noisespectrum.copy()
                noisespectrum.interpolate(frequency, inplace = True)

            # Use the windowed template (as the white-noise branch does) and
            # pass ``n`` so odd-length signals keep their length; irfft's
            # default would return ``length - 1`` samples.
            spectrum = _np.conj(_fft.rfft(template.data)) / noisespectrum.data
            filter_matched.data = _fft.irfft(spectrum, n = length)

        # NOTE(review): normalisation by sum(template * filter); confirm this
        # matches the intended unit-amplitude convention.
        filter_matched.data /= _np.sum(template.data * filter_matched.data)

        return filter_matched


    def _butterworth(self, cutoff: float, order: int, btype: str, inplace: bool):
        """Apply a Butterworth filter of type ``btype`` using ``sosfilt``."""
        signal = self if inplace else self.copy()

        nyquist = signal.samplingrate / signal.oversampling / 2
        sos = _signal.butter(order, cutoff / nyquist, btype = btype, output = 'sos')
        signal.data = _signal.sosfilt(sos, signal.data)

        return signal


    def lowpass(self, cutoff : float, order: int = 1, inplace : bool = False):
        """Low-pass filter with a Butterworth filter at ``cutoff`` (same unit as the sample rate)."""
        return self._butterworth(cutoff, order, 'lowpass', inplace)


    def highpass(self, cutoff: float, order: int = 1, inplace : bool = False):
        """High-pass filter with a Butterworth filter at ``cutoff`` (same unit as the sample rate)."""
        return self._butterworth(cutoff, order, 'highpass', inplace)


    def normalize(self, method: str = 'pretrigger', inplace : bool = False):
        """Normalise the data.

        Args:
            method: ``'pretrigger'`` subtracts the NaN-aware pretrigger mean,
                ``'height'`` divides by ``height()``, ``'area'`` divides by
                ``area()``.
            inplace: modify ``self`` instead of a copy.

        Raises:
            ValueError: for an unknown ``method``.
        """
        signal = self if inplace else self.copy()

        if method == 'pretrigger':
            signal.data -= signal.nanmean_pretrigger()
        elif method == 'height':
            signal.data /= signal.height()
        elif method == 'area':
            signal.data /= signal.area()
        else:
            raise ValueError("Method must be either `pretrigger`, `height` or `area`.")

        return signal


    def interpolate(self, time : _np.ndarray, inplace : bool = False):
        """Linearly resample the data at ``time`` (same axis as ``time()``).

        Only ``data`` changes; header values such as ``pretrigger`` are left
        untouched.
        """
        signal = self if inplace else self.copy()

        signal.data = _np.interp(time, self.time(), self.data)

        return signal


    def jitter(self, jitter: float, interpolate: bool = False, inplace : bool = False):
        """Shift the data by ``jitter`` samples (positive shifts to later samples).

        With ``interpolate`` the shift is fractional (linear interpolation),
        otherwise it is ``np.roll`` by ``round(jitter)``. The ``round(jitter)``
        samples that were wrapped around or extrapolated are set to NaN.
        """
        signal = self if inplace else self.copy()

        jitter_int = int(_np.round(jitter))

        if jitter != 0:
            if interpolate:
                n = len(signal.data)
                xp = _np.arange(n)
                x = xp - jitter
                data = _np.interp(x, xp, signal.data)
            else:
                data = _np.roll(signal.data, jitter_int)

            if jitter_int != 0:
                if jitter > 0:
                    data[:jitter_int] = _np.nan
                else:
                    data[jitter_int:] = _np.nan

            signal.data = data

        return signal


    def trim(self, value: int , inplace: bool = False):
        """Drop ``value`` samples at both ends and shift ``pretrigger`` accordingly."""
        signal = self if inplace else self.copy()

        signal.data = signal.data[value : len(signal.data) - value]
        signal.pretrigger = signal.pretrigger - value

        return signal


    def trim_nan(self, value: int, inplace: bool = False):
        """Set ``value`` samples at both ends to NaN, keeping the length."""
        signal = self if inplace else self.copy()

        # Slice bounds must be integers; accept integral floats such as 2.0.
        value = int(value)
        signal.data[:value] = _np.nan
        signal.data[len(signal.data) - value:] = _np.nan

        return signal


    def downsample(self, ratio : int, inplace : bool = False):
        """Average consecutive blocks of ``ratio`` samples.

        The stored sample rate ``samplingrate / oversampling`` drops by
        ``ratio``, so ``oversampling`` is multiplied by ``ratio``. The
        ``pretrigger`` index is divided by it, which keeps ``time()``
        consistent with the original signal.

        Raises:
            ValueError: if the length is not a multiple of ``ratio``.
        """
        if len(self.data) % ratio != 0:
            raise ValueError("The length of the signal data must be a multiple of the downsampling ratio.")

        signal = self if inplace else self.copy()

        signal.data = signal.data.reshape(-1, ratio).mean(axis = 1)

        if signal.oversampling is not None:
            signal.oversampling = signal.oversampling * ratio
        if signal.pretrigger is not None:
            signal.pretrigger = signal.pretrigger // ratio

        return signal


    def digitize(self, inplace : bool = False):
        """Round to whole codes and clip them to ``[0, coderange - 1]``.

        NOTE: the ``data`` setter stores values as float, so the result holds
        integral floats rather than an integer dtype.
        """
        signal = self if inplace else self.copy()

        rounded = _np.round(signal.data, decimals = 0)
        signal.data = _np.clip(rounded, 0, self.coderange - 1).astype(int)

        return signal


    def shape(self, risetimes: int = 1, decaytimes: int = 1) -> dict:
        """Fit the pulse model ``generator._pulse`` to the data with lmfit.

        The fit starts with one rise (``rise0``) and one decay (``decay0``)
        time constant. Further components are added one at a time by
        splitting the currently dominant component, refitting after each
        addition, until ``risetimes`` / ``decaytimes`` components exist.

        Returns:
            Dict of fitted parameter values plus ``'chi2'`` (``chisqr`` divided
            by the number of samples).
        """
        from .generator import _pulse

        def _add_exp(params, method):
            """Split the dominant ``method`` component (rise/decay) into two."""
            # Component indices present, e.g. {0, 1} from decay0/decay1_factor.
            number_set = {int(_re.findall(r'\d+', entry)[0]) for entry in params if entry.startswith(method)}
            number_last = max(number_set)

            if f'{method}{number_last}_factor' not in params:
                params.add(f'{method}{number_last}_factor', value = 1, min = 0, max = 1)

            # The last component may have been expression-constrained; free it.
            params[f'{method}{number_last}'].set(vary=True, min=10e-9)
            params[f'{method}{number_last}_factor'].set(vary=True, min=0, max=1)

            # Pick the component with the largest weight (compare factor
            # values, not component indices).
            number_highest = max(sorted(number_set),
                                 key=lambda number: params[f'{method}{number}_factor'].value)

            value = params[f'{method}{number_highest}'].value
            factor = params[f'{method}{number_highest}_factor'].value

            params[f'{method}{number_highest}'].set(value = value / 2)
            params[f'{method}{number_highest}_factor'].set(value = factor / 2)

            # The new component's weight is constrained so all weights sum to 1.
            expr = '1' + ''.join(f' - {method}{i}_factor' for i in sorted(number_set))

            params.add(f'{method}{number_last + 1}', value = value * 2, vary=True, min=10e-9)
            params.add(f'{method}{number_last + 1}_factor', value = factor / 2, vary=True, min=0, max=1, expr=expr)

            return params

        x = self.time()
        y = self.data

        amplitude = self.height() * self.sign()
        offset = self.mean_pretrigger()

        # Treat an offset within the pretrigger noise as zero.
        if _np.abs(offset) < self.std_pretrigger():
            offset = 0.0

        model = _Model(_pulse)

        params = model.make_params()
        params.add('amplitude', value = amplitude)
        params.add('jitter', value = 0)
        params.add('rise0', value = 10e-6, min = 10e-9)
        params.add('decay0', value = 10e-3, min = 10e-9)
        params.add('offset', value = offset)

        result = model.fit(y, x = x, params=params)

        for _ in range(decaytimes - 1):
            params = _add_exp(result.params, 'decay')
            result = model.fit(y, x = x, params=params)

        for _ in range(risetimes - 1):
            params = _add_exp(result.params, 'rise')
            result = model.fit(y, x = x, params=params)

        params = result.params

        shape = {key: params[key].value for key in params}
        shape['chi2'] = result.chisqr / len(y)

        return shape


    def fom(self) -> float:
        """Figure of merit: ``height()`` divided by the pretrigger standard deviation."""
        return self.height() / self.std_pretrigger()


    def clipped(self) -> bool:
        """True if any sample reaches ``0`` or ``coderange - 1``."""
        return (self.data.min() <= 0) or (self.data.max() >= self.coderange - 1)


    def time(self) -> _np.ndarray:
        """Sample times in seconds, with ``t = 0`` at index ``pretrigger``."""
        return (_np.arange(len(self.data)) - self.pretrigger) / (self.samplingrate / self.oversampling)


    def samples(self) -> int:
        """Number of samples."""
        return len(self.data)


    def traceless(self) -> bool:
        """True if the signal holds no samples."""
        return len(self.data) == 0


    def nan(self) -> bool:
        """True if any sample is NaN."""
        return _np.isnan(self.data).any()


    def data_pretrigger(self) -> _np.ndarray:
        """Samples before ``pretrigger`` (all samples if ``pretrigger`` is None)."""
        return self.data[: self.pretrigger]


    def time_pretrigger(self) -> _np.ndarray:
        """Times of the samples before ``pretrigger``."""
        return self.time()[: self.pretrigger]


    def height(self) -> float:
        """Largest excursion (up or down) from the pretrigger mean."""
        mean_pretrigger = self.mean_pretrigger()
        return _np.max([mean_pretrigger - self.data.min(), self.data.max() - mean_pretrigger])


    def sign(self) -> int:
        """Sign of the median deviation from the NaN-aware pretrigger mean."""
        return _np.sign(_np.nanmedian(self.data - _np.nanmean(self.data_pretrigger())))


    def area(self) -> float:
        """Sum of the data relative to the pretrigger mean."""
        return _np.sum(self.data - self.mean_pretrigger())


    def max_difference(self) -> float:
        """Largest absolute difference between consecutive samples."""
        return _np.max(_np.abs(_np.diff(self.data)))


    def slope(self) -> float:
        """Least-squares slope of the data versus ``time()``."""
        time = self.time()
        return (_np.mean(time * self.data) - _np.mean(time) * self.mean()) / (_np.mean(time**2) - _np.mean(time)**2)


    def slope_pretrigger(self) -> float:
        """Least-squares slope of the pretrigger samples versus their times."""
        time = self.time_pretrigger()
        return (_np.mean(time * self.data_pretrigger()) - _np.mean(time) * self.mean_pretrigger()) / (_np.mean(time**2) - _np.mean(time)**2)


    def copy(self):
        """Deep copy of the signal (data and header)."""
        return _copy.deepcopy(self)


    def max(self) -> float:
        """Maximum sample (NaN propagates)."""
        return _np.max(self.data)


    def nanmax(self) -> float:
        """Maximum sample, ignoring NaN."""
        return _np.nanmax(self.data)


    def min(self) -> float:
        """Minimum sample (NaN propagates)."""
        return _np.min(self.data)


    def nanmin(self) -> float:
        """Minimum sample, ignoring NaN."""
        return _np.nanmin(self.data)


    def argmax(self) -> int:
        """Index of the maximum sample."""
        return _np.argmax(self.data)


    def argmin(self) -> int:
        """Index of the minimum sample."""
        return _np.argmin(self.data)


    def mean(self) -> float:
        """Mean of the data."""
        return _np.mean(self.data)


    def nanmean(self) -> float:
        """Mean of the data, ignoring NaN."""
        return _np.nanmean(self.data)


    def mean_pretrigger(self) -> float:
        """Mean of the pretrigger samples."""
        return _np.mean(self.data_pretrigger())


    def nanmean_pretrigger(self) -> float:
        """Mean of the pretrigger samples, ignoring NaN."""
        return _np.nanmean(self.data_pretrigger())


    def std(self) -> float:
        """Standard deviation of the data."""
        return _np.std(self.data)


    def nanstd(self) -> float:
        """Standard deviation of the data, ignoring NaN."""
        return _np.nanstd(self.data)


    def std_pretrigger(self) -> float:
        """Standard deviation of the pretrigger samples."""
        return _np.std(self.data_pretrigger())


    def nanstd_pretrigger(self) -> float:
        """Standard deviation of the pretrigger samples, ignoring NaN."""
        return _np.nanstd(self.data_pretrigger())


    def var(self) -> float:
        """Variance of the data."""
        return _np.var(self.data)


    def sum(self) -> float:
        """Sum of the data."""
        return _np.sum(self.data)


    def nansum(self) -> float:
        """Sum of the data, treating NaN as zero."""
        return _np.nansum(self.data)


    @property
    def data(self) -> _np.ndarray:
        """Sample values (float array, at least 1-D)."""
        return self._data


    @data.setter
    def data(self, array: _np.ndarray):
        # Always store a float copy; None becomes an empty array.
        if array is None:
            array = []
        self._data = _np.array(array, ndmin=1, dtype=float)


    @property
    def header(self) -> dict:
        """Metadata dict."""
        return self._header


    @header.setter
    def header(self, header: dict):
        if header is None:
            header = {}
        self._header = header


    @property
    def samplingrate(self) -> float:
        return self.header.get("samplingrate", None)


    @samplingrate.setter
    def samplingrate(self, samplingrate: float):
        self.header["samplingrate"] = samplingrate


    @property
    def oversampling(self) -> int:
        return self.header.get("oversampling", None)


    @oversampling.setter
    def oversampling(self, oversampling: int):
        self.header["oversampling"] = oversampling


    @property
    def coderange(self) -> int:
        return self.header.get("coderange", None)


    @coderange.setter
    def coderange(self, coderange : int):
        self.header["coderange"] = coderange


    @property
    def pretrigger(self) -> int:
        return self.header.get("pretrigger", None)


    @pretrigger.setter
    def pretrigger(self, pretrigger : int):
        self.header["pretrigger"] = pretrigger


    @property
    def folder(self) -> str:
        return self.header.get("folder", None)


    @folder.setter
    def folder(self, folder : str):
        self.header["folder"] = folder


    @property
    def measurement(self) -> str:
        return self.header.get("measurement", None)


    @measurement.setter
    def measurement(self, measurement : str):
        self.header["measurement"] = measurement


    @property
    def channel(self) -> str:
        return self.header.get("channel", None)


    @channel.setter
    def channel(self, channel : str):
        self.header["channel"] = channel


    @property
    def polarity(self)  -> str:
        return self.header.get("polarity", None)


    @polarity.setter
    def polarity(self, polarity : str):
        self.header["polarity"] = polarity


    @property
    def signal(self) -> int:
        return self.header.get("signal", None)


    @signal.setter
    def signal(self, signal : int):
        self.header["signal"] = signal


    def __iter__(self):
        return iter(self.data)


    def __len__(self):
        return len(self.data)


    def __getitem__(self, key):
        return self.data.__getitem__(key)


    def __array__(self, dtype=None, copy=None):
        # NumPy 2 passes ``dtype`` and ``copy``; accept both for compatibility.
        data = self.data if dtype is None else self.data.astype(dtype, copy=False)
        return data.copy() if copy else data


    def __str__(self) -> str:
        return "Signal"


    def __repr__(self) -> str:
        return f"Signal(folder={self.folder}, measurement={self.measurement}, channel={self.channel}, polarity={self.polarity}, signal={self.signal})"


    def __pos__(self):
        return self.copy()


    def __neg__(self):
        signal = self.copy()
        signal.data = -signal.data
        signal.header = self._combine_header(None)
        return signal


    def _apply_operator(self, other, operation, reflected: bool = False):
        """Return a new signal holding ``operation`` applied element-wise.

        ``other`` may be a Signal (its data is used) or anything NumPy can
        broadcast. With ``reflected`` the operand order is swapped, i.e.
        ``operation(other, self.data)``, as needed by ``__rsub__`` and friends.
        """
        signal = self.copy()
        operand = other.data if isinstance(other, Signal) else other
        if reflected:
            signal.data = operation(operand, self.data)
        else:
            signal.data = operation(self.data, operand)
        signal.header = self._combine_header(other)
        return signal


    def __add__(self, other):
        return self._apply_operator(other, lambda a, b: a + b)


    def __sub__(self, other):
        return self._apply_operator(other, lambda a, b: a - b)


    def __mul__(self, other):
        return self._apply_operator(other, lambda a, b: a * b)


    def __truediv__(self, other):
        return self._apply_operator(other, lambda a, b: a / b)


    def __floordiv__(self, other):
        return self._apply_operator(other, lambda a, b: _np.floor(a / b))


    def __mod__(self, other):
        return self._apply_operator(other, lambda a, b: a % b)


    def __pow__(self, other):
        return self._apply_operator(other, lambda a, b: a ** b)


    # Reflected operators: ``other <op> self`` must keep the operand order
    # (``10 - s`` is not ``s - 10``).
    def __radd__(self, other):
        return self._apply_operator(other, lambda a, b: a + b, reflected=True)


    def __rsub__(self, other):
        return self._apply_operator(other, lambda a, b: a - b, reflected=True)


    def __rmul__(self, other):
        return self._apply_operator(other, lambda a, b: a * b, reflected=True)


    def __rtruediv__(self, other):
        return self._apply_operator(other, lambda a, b: a / b, reflected=True)


    def __rfloordiv__(self, other):
        return self._apply_operator(other, lambda a, b: _np.floor(a / b), reflected=True)


    def __rmod__(self, other):
        return self._apply_operator(other, lambda a, b: a % b, reflected=True)


    def __rpow__(self, other):
        return self._apply_operator(other, lambda a, b: a ** b, reflected=True)


    # Augmented assignment returns a new Signal (the name is rebound).
    def __iadd__(self, other):
        return self.__add__(other)


    def __isub__(self, other):
        return self.__sub__(other)


    def __imul__(self, other):
        return self.__mul__(other)


    def __itruediv__(self, other):
        return self.__truediv__(other)


    def __ifloordiv__(self, other):
        return self.__floordiv__(other)


    def __imod__(self, other):
        return self.__mod__(other)


    def _combine_header(self, other) -> dict:
        """Header for the result of combining ``self`` with ``other``.

        Starts from this signal's attribute/index keys. If ``other`` is a
        Signal, keys whose values differ become ``None``; a difference in
        an attribute key (``_attributes``) also issues a warning.
        """
        header = {key: self.header.get(key, None) for key in self._attributes + self._index}

        if isinstance(other, Signal):
            for entry in header:
                if header[entry] != other.header.get(entry, None):
                    header[entry] = None
                    if entry in self._attributes:
                        _warnings.warn(f"Property `{entry}` is not equal for both signals.", Warning)

        return header


    @staticmethod
    def template(signals):
        """Average ``signals`` sample-wise, ignoring NaN, into a new Signal.

        Header keys whose values differ between signals become ``None``; a
        difference in an attribute key warns. A warning is also issued when
        the result still contains NaN (positions where every input was NaN).

        Raises:
            ValueError: if an entry is not a Signal or ``signals`` is empty.
        """
        header = None
        data = None
        data_mask = None
        attributes = None
        keys = None

        for signal in signals:
            if not isinstance(signal, Signal):
                raise ValueError("All entries must be `signal` objects.")

            signal_mask = ~_np.isnan(signal.data)

            if header is None:
                attributes = signal._attributes
                keys = signal._attributes + signal._index
                header = {key: signal.header.get(key, None) for key in keys}
                # Copy so the in-place division below never writes into an input.
                data = signal.data.copy()
                data_mask = signal_mask.astype(int)
            else:
                for entry in keys:
                    if header[entry] != signal.header.get(entry, None):
                        header[entry] = None
                        if entry in attributes:
                            _warnings.warn(f"Property `{entry}` is not equal for all signals.", Warning)
                data = _np.nansum([data, signal.data], axis = 0)
                data_mask += signal_mask.astype(int)

        if header is None:
            raise ValueError("At least one signal is required to build a template.")

        # Divide the NaN-aware sum by the number of valid samples per position.
        data /= data_mask

        signal = Signal(data=data, header=header)

        if signal.nan():
            _warnings.warn("Template contains NaN values.", Warning)

        return signal