import copy as _copy
import operator as _operator
import warnings as _warnings

import matplotlib.pyplot as _plt
import numpy as _np
from scipy import fft as _fft, signal as _signal


class Frequencyspectrum:
    """One-sided frequency spectrum with an attached metadata header.

    ``data`` holds one value per non-negative frequency bin, laid out like the
    output of ``scipy.fft.rfft``. The bin frequencies are derived from the
    header entries ``samplingrate`` and ``oversampling`` (see
    :meth:`frequency`). Arithmetic operators act element-wise on ``data`` and
    return new objects. Header entries on which two operands disagree are
    reset to ``None``.
    """

    def __init__(self, data: _np.ndarray = None, header: dict = None):
        """Create a spectrum.

        Args:
            data: Spectral values; ``None`` gives an empty spectrum. The values
                are stored as a float array with at least one dimension.
            header: Metadata dict, stored by reference (not copied). Any key of
                ``_attributes``/``_index`` that is missing is added with
                value ``None``.
        """
        self.data = data
        self.header = header

        # Header entries whose mismatch between combined spectra triggers a
        # warning (see `_combine_header` / `noisespectrum`).
        self._attributes = ['samplingrate',
                            'oversampling']

        # Header entries that are silently reset to None on mismatch.
        self._index = ['folder',
                       'measurement',
                       'channel',
                       'polarity',
                       'signal']

        for attribute in self._attributes + self._index:
            if attribute not in self.header:
                self.header[attribute] = None

    def plot(self, *args, **kwargs) -> None:
        """Plot the spectrum on log-log axes.

        Positional and keyword arguments are forwarded to ``Axes.loglog``.
        Pass ``ax=`` to draw on a specific axes; otherwise the current axes
        (``plt.gca()``) is used. A legend is drawn when ``label`` is given.
        """
        frequency = self.frequency()

        # Resolve gca() lazily. Used as a pop() default, it would create or
        # activate a figure even when the caller supplied an axes.
        ax = kwargs.pop("ax", None)
        if ax is None:
            ax = _plt.gca()

        ax.loglog(frequency, self.data, *args, **kwargs)
        # Bin 0 is f = 0 and cannot be shown on a log axis, so start at bin 1.
        # Guard against spectra too short to give a non-degenerate range.
        if frequency.size > 2:
            ax.set_xlim(frequency[1], frequency[-1])
        ax.set_xlabel(r'f / Hz')
        ax.set_ylabel(r'S / Bit$^{2}$ Hz$^{-1}$')

        if "label" in kwargs:
            ax.legend()

        return None

    def _butterworth(self, cutoff: float, order: int, btype: str, inplace: bool):
        """Weight ``data`` by |H(f)|^2 of an analog Butterworth filter.

        ``cutoff`` and the evaluation grid (:meth:`frequency`) are passed to
        SciPy in the same unit. The 2*pi factor between Hz and rad/s therefore
        cancels, and the -3 dB point lies at ``frequency == cutoff``.
        """
        frequencyspectrum = self if inplace else self.copy()

        b, a = _signal.butter(order, cutoff, btype=btype, analog=True)
        _, h = _signal.freqs(b, a, worN=frequencyspectrum.frequency())

        # Spectral density scales with the squared magnitude response.
        frequencyspectrum.data *= _np.abs(h) ** 2

        return frequencyspectrum

    def lowpass(self, cutoff: float, order: int = 1, inplace: bool = False):
        """Apply a Butterworth low-pass of the given order; returns the result."""
        return self._butterworth(cutoff, order, 'lowpass', inplace)

    def highpass(self, cutoff: float, order: int = 1, inplace: bool = False):
        """Apply a Butterworth high-pass of the given order; returns the result."""
        return self._butterworth(cutoff, order, 'highpass', inplace)

    def frequency(self) -> _np.ndarray:
        """Return the bin frequencies matching ``data``.

        Raises:
            ValueError: if ``oversampling`` or ``samplingrate`` is ``None``.
        """
        if self.oversampling is None or self.samplingrate is None:
            raise ValueError("oversampling and samplingrate must not be None")

        # Effective sample spacing is oversampling / samplingrate.
        frequency = _fft.rfftfreq(self.samples(), d=self.oversampling / self.samplingrate)

        return frequency

    def interpolate(self, frequency: _np.array, inplace: bool = False):
        """Resample ``data`` linearly onto the given frequencies.

        The header is left unchanged, so :meth:`frequency` of the result is
        still computed from ``samplingrate``/``oversampling`` and the new
        length. It generally does not equal ``frequency``.
        """
        frequencyspectrum = self if inplace else self.copy()

        frequencyspectrum.data = _np.interp(frequency, frequencyspectrum.frequency(), frequencyspectrum.data)

        return frequencyspectrum

    def noise(self):
        """Synthesize a random time-domain trace from this spectrum.

        Each bin gets amplitude sqrt(S * samplingrate / oversampling * len(data))
        and a uniformly random phase, and the result is inverse-rfft'd into a
        ``Signal``. NOTE(review): the normalization uses the number of
        frequency bins rather than ``samples()``. Confirm this matches the
        convention used when the spectrum was computed.
        """
        from ._signal import Signal as _Signal

        length = len(self.data)
        phase = _np.random.random_sample(length) * 2 * _np.pi
        data = _fft.irfft(_np.sqrt(self.data * self.samplingrate / self.oversampling * length) * _np.exp(1j * phase))
        header = {"samplingrate": self.samplingrate,
                  "oversampling": self.oversampling,
                  "pretrigger": 0,
                  "folder": self.folder,
                  "measurement": self.measurement,
                  "channel": self.channel,
                  "polarity": self.polarity,
                  "signal": self.signal}

        pulse = _Signal(data=data, header=header)

        return pulse

    def samples(self) -> int:
        """Return the time-domain length implied by the rfft layout: 2 * (bins - 1).

        An empty spectrum gives 0.
        """
        length = len(self.data)

        if length != 0:
            samples = (length - 1) * 2
        else:
            samples = 0

        return samples

    def copy(self):
        """Return a deep copy (data and header are independent)."""
        copy = _copy.deepcopy(self)

        return copy

    def nan(self) -> bool:
        """Return True if any entry of ``data`` is NaN."""
        return bool(_np.isnan(self.data).any())

    @property
    def data(self) -> _np.ndarray:
        return self._data

    @data.setter
    def data(self, array: _np.ndarray):
        # None becomes an empty spectrum; everything is stored as a float array.
        if array is None:
            array = []
        self._data = _np.array(array, ndmin=1, dtype=float)

    @property
    def header(self) -> dict:
        return self._header

    @header.setter
    def header(self, header: dict):
        if header is None:
            header = {}
        self._header = header

    @property
    def samplingrate(self) -> float:
        return self.header.get("samplingrate", None)

    @samplingrate.setter
    def samplingrate(self, samplingrate: float):
        self.header["samplingrate"] = samplingrate

    @property
    def oversampling(self) -> int:
        return self.header.get("oversampling", None)

    @oversampling.setter
    def oversampling(self, oversampling: int):
        self.header["oversampling"] = oversampling

    @property
    def folder(self) -> str:
        return self.header.get("folder", None)

    @folder.setter
    def folder(self, folder: str):
        self.header["folder"] = folder

    @property
    def measurement(self) -> str:
        return self.header.get("measurement", None)

    @measurement.setter
    def measurement(self, measurement: str):
        self.header["measurement"] = measurement

    @property
    def channel(self) -> str:
        return self.header.get("channel", None)

    @channel.setter
    def channel(self, channel: str):
        self.header["channel"] = channel

    @property
    def polarity(self) -> str:
        return self.header.get("polarity", None)

    @polarity.setter
    def polarity(self, polarity: str):
        self.header["polarity"] = polarity

    @property
    def signal(self) -> int:
        return self.header.get("signal", None)

    @signal.setter
    def signal(self, signal: int):
        self.header["signal"] = signal

    def __iter__(self):
        return iter(self.data)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, key):
        return self.data.__getitem__(key)

    def __array__(self, dtype=None, copy=None):
        """Implement the NumPy array protocol and return ``data``.

        ``data`` is copied only when NumPy asks for a copy. The method accepts
        the ``dtype``/``copy`` arguments that NumPy passes; NumPy 2 deprecates
        ``__array__`` implementations without them.
        """
        if copy:
            return _np.array(self.data, dtype=dtype, copy=True)
        return _np.asarray(self.data, dtype=dtype)

    def __str__(self) -> str:
        return "Frequencyspectrum"

    def __repr__(self) -> str:
        return f"Frequencyspectrum(folder={self.folder}, measurement={self.measurement}, channel={self.channel}, polarity={self.polarity}, signal={self.signal})"

    def __pos__(self):
        return self.copy()

    def __neg__(self):
        frequencyspectrum = self.copy()
        frequencyspectrum.data = -frequencyspectrum.data
        frequencyspectrum.header = self._combine_header(None)
        return frequencyspectrum

    def _operate(self, other, operation, reflected: bool = False):
        """Return a copy of ``self`` with ``operation`` applied element-wise.

        Args:
            other: Another Frequencyspectrum, whose ``data`` is used, or any
                scalar/array that NumPy can broadcast against ``self.data``.
            operation: Binary callable, e.g. ``operator.add``.
            reflected: If true, compute ``operation(other, self)`` instead of
                ``operation(self, other)``. The ``__r*__`` methods use this.

        Returns:
            A new Frequencyspectrum carrying the combined header.
        """
        frequencyspectrum = self.copy()
        other_data = other.data if isinstance(other, Frequencyspectrum) else other
        if reflected:
            frequencyspectrum.data = operation(other_data, self.data)
        else:
            frequencyspectrum.data = operation(self.data, other_data)
        frequencyspectrum.header = self._combine_header(other)
        return frequencyspectrum

    @staticmethod
    def _floor_divide(numerator, denominator):
        # Defined as floor(a / b), not a // b. The two can differ for floats.
        return _np.floor(numerator / denominator)

    def __add__(self, other):
        return self._operate(other, _operator.add)

    def __sub__(self, other):
        return self._operate(other, _operator.sub)

    def __mul__(self, other):
        return self._operate(other, _operator.mul)

    def __truediv__(self, other):
        return self._operate(other, _operator.truediv)

    def __floordiv__(self, other):
        return self._operate(other, self._floor_divide)

    def __mod__(self, other):
        return self._operate(other, _operator.mod)

    # Augmented assignment returns a NEW object, so `fs += x` rebinds `fs`
    # and leaves other references to the old object untouched.
    def __iadd__(self, other):
        return self.__add__(other)

    def __isub__(self, other):
        return self.__sub__(other)

    def __imul__(self, other):
        return self.__mul__(other)

    def __itruediv__(self, other):
        return self.__truediv__(other)

    def __ifloordiv__(self, other):
        return self.__floordiv__(other)

    def __imod__(self, other):
        return self.__mod__(other)

    # Reflected operators are evaluated for `other <op> self`, so the operand
    # order must be swapped. Plain aliases of the forward methods were wrong for
    # the non-commutative -, /, // and %.
    def __radd__(self, other):
        return self._operate(other, _operator.add, reflected=True)

    def __rsub__(self, other):
        return self._operate(other, _operator.sub, reflected=True)

    def __rmul__(self, other):
        return self._operate(other, _operator.mul, reflected=True)

    def __rtruediv__(self, other):
        return self._operate(other, _operator.truediv, reflected=True)

    def __rfloordiv__(self, other):
        return self._operate(other, self._floor_divide, reflected=True)

    def __rmod__(self, other):
        return self._operate(other, _operator.mod, reflected=True)

    def _combine_header(self, other) -> dict:
        """Build the header for the result of combining ``self`` with ``other``.

        The result starts from this spectrum's ``_attributes`` + ``_index``
        entries. If ``other`` is a Frequencyspectrum, every entry that differs
        is set to ``None``, and a mismatch in ``_attributes`` also issues a
        warning. Keys missing from ``other.header`` are treated as ``None``.
        """
        header = {key: self.header.get(key, None) for key in self._attributes + self._index}

        if isinstance(other, Frequencyspectrum):
            for entry in header:
                if header[entry] != other.header.get(entry, None):
                    header[entry] = None
                    if entry in self._attributes:
                        _warnings.warn(f"Property `{entry}` is not equal for both frequencyspectra.", Warning)

        return header

    @staticmethod
    def noisespectrum(frequencyspectra):
        """Average several spectra bin by bin, ignoring NaN entries.

        Args:
            frequencyspectra: Iterable of Frequencyspectrum objects of equal length.

        Returns:
            A new Frequencyspectrum whose bins are the mean of the non-NaN
            inputs. Bins that are NaN in every input stay NaN, and a warning is
            issued. Header entries that differ between inputs are set to
            ``None``; a mismatch in ``samplingrate``/``oversampling`` also warns.

        Raises:
            ValueError: if an entry is not a Frequencyspectrum, or if the
                iterable is empty.
        """
        header = None
        data = None
        counts = None
        attributes = None
        keys = None

        for frequencyspectrum in frequencyspectra:
            if not isinstance(frequencyspectrum, Frequencyspectrum):
                raise ValueError("All entries must be `Frequencyspectrum` objects.")

            valid = ~_np.isnan(frequencyspectrum.data)

            if header is None:
                attributes = frequencyspectrum._attributes
                keys = frequencyspectrum._attributes + frequencyspectrum._index
                header = {key: frequencyspectrum.header.get(key, None) for key in keys}
                # Copy so the final division can never alias the caller's array.
                data = frequencyspectrum.data.copy()
                counts = valid.astype(int)
            else:
                for entry in keys:
                    if header[entry] != frequencyspectrum.header.get(entry, None):
                        header[entry] = None
                        if entry in attributes:
                            _warnings.warn(f"Property `{entry}` is not equal for all frequencyspectra.", Warning)
                # nansum treats NaN as 0; `counts` tracks how many real values
                # contributed to each bin.
                data = _np.nansum([data, frequencyspectrum.data], axis=0)
                counts += valid.astype(int)

        if header is None:
            raise ValueError("At least one `Frequencyspectrum` is required.")

        # Bins with no valid contribution evaluate to 0/0 or NaN/0 and become
        # NaN. That is intended, so silence NumPy's RuntimeWarning here.
        with _np.errstate(divide="ignore", invalid="ignore"):
            data = data / counts

        frequencyspectrum = Frequencyspectrum(data=data, header=header)

        if frequencyspectrum.nan():
            _warnings.warn("Noisespectrum contains NaN values.", Warning)

        return frequencyspectrum