import os as _os
import pandas as _pd
import warnings as _warnings

try:
    from darq.pulseprocessing import PulseReader as _PulseReaderPulseprocessingDARQ
except ImportError:
    _warnings.warn("DARQ is not installed. Please install the `darq` package to read PAQS pulses.", Warning)

from ._signal import Signal as _Signal


class PulseReader:
    """Read single pulses, choosing the backend per measurement directory.

    A measurement directory ``<folder>/<measurement>`` that contains a
    ``Summary.txt`` is read with the DARQ backend; any other directory is
    read with the batched-parquet (fitfiles) backend. The backend instance is
    kept and reused until ``folder`` or ``measurement`` changes.

    Parameters
    ----------
    calibration : pandas.DataFrame, optional
        If given, ``calibration[entry['channel']]`` is indexed with the raw
        pulse data and the result replaces ``pulse.data``.
        NOTE(review): presumably maps raw ADC codes to calibrated values per
        channel — confirm against the calibration producer.
    """

    def __init__(self, calibration: _pd.DataFrame = None):
        self.calibration = calibration
        self._folder = None
        self._measurement = None
        self._pr = None  # active backend reader; chosen lazily by _set()

    def read(self, entry: _pd.Series) -> _Signal:
        """Return the pulse described by *entry*, calibrated if configured.

        *entry* must provide at least ``'folder'`` and ``'measurement'``
        (plus ``'channel'`` when a calibration is set, and whatever the
        selected backend reads).
        """
        self._set(entry)
        result = self._pr.read(entry)

        table = self.calibration
        if table is None:
            return result

        result.data = table[entry['channel']][result.data]
        return result

    def _set(self, entry: _pd.Series) -> None:
        """Pick a fresh backend when the measurement directory changes."""
        folder = entry['folder']
        measurement = entry['measurement']

        stale = self._pr is None or self._folder != folder or self._measurement != measurement
        if not stale:
            return None

        self._folder = folder
        self._measurement = measurement

        # A DARQ acquisition is recognised by its Summary.txt file.
        summary = _os.path.join(folder, measurement, 'Summary.txt')
        backend = _PulseReaderDARQ if _os.path.exists(summary) else _PulseReaderFitfiles
        self._pr = backend()

        return None
    

class _PulseReaderFitfiles:
    """Read pulses stored as batched parquet files.

    A pulse with index ``signal`` is looked up in
    ``<folder>/<measurement>/<channel>/<polarity>/<first>_<last>.parquet``,
    where ``first``/``last`` delimit the batch of ``batchsize`` consecutive
    signal indices that contains ``signal``. The signal is taken from the
    parquet column named ``signal``. The last loaded batch is cached and
    reused while the requested signal is one of its columns and the
    folder/measurement/channel/polarity stay the same.
    """

    def __init__(self, batchsize: int = 1000):
        self._folder = None
        self._measurement = None
        self._channel = None
        self._polarity = None
        self._batchsize = batchsize
        self._batch = None  # cached DataFrame of the most recently read batch

    def read(self, entry: _pd.Series) -> _Signal:
        """Return the pulse for *entry* as a ``Signal``.

        Any ``generator_<name>`` field of *entry* (for ``samplingrate``,
        ``oversampling``, ``coderange`` and ``pretrigger``) overrides
        ``<name>`` in the resulting header.
        """
        self._set(entry)
        index = entry['signal']

        cache_miss = self._batch is None or index not in self._batch.columns
        if cache_miss:
            self._read_batch(index)

        samples = self._batch[index]
        header = entry.to_dict()

        for name in ('samplingrate', 'oversampling', 'coderange', 'pretrigger'):
            override = 'generator_' + name
            if override in header:
                header[name] = header[override]

        return _Signal(data=samples, header=header)

    def _set(self, entry: _pd.Series) -> None:
        """Invalidate the cached batch when the data location changes."""
        wanted = (entry['folder'], entry['measurement'], entry['channel'], entry['polarity'])
        cached = (self._folder, self._measurement, self._channel, self._polarity)

        if any(old != new for old, new in zip(cached, wanted)):
            self._batch = None
            self._folder, self._measurement, self._channel, self._polarity = wanted

        return None

    def _read_batch(self, signal: int) -> None:
        """Load the parquet batch that contains *signal* into the cache.

        Raises
        ------
        ValueError
            If the data location has not been fully set.
        """
        name = self._filename_batch(signal)
        location = (self._folder, self._measurement, self._channel, self._polarity)
        if None in location:
            raise ValueError("folder, measurement, channel, and polarity must not be None.")

        self._batch = _pd.read_parquet(_os.path.join(*location, name))
        return None

    def _filename_batch(self, signal: int) -> str:
        """Return ``'<first>_<last>.parquet'`` for the batch holding *signal*."""
        size = self._batchsize
        start = int(signal - signal % size)
        return f'{start}_{start + size - 1}.parquet'


class _PulseReaderDARQ:
    """Read single pulses through DARQ's pulse reader.

    The underlying DARQ reader is created lazily and recreated whenever the
    requested path (``<folder>/<measurement>``), channel or polarity changes.
    """

    # PAQS header fields removed from the returned header.
    _HEADER_DROP = ('ADC channel no.',
                    'Pixel no.',
                    'Signal no. (per ADC channel)',
                    'Polarity',
                    'Timestamp',
                    'NoOfSamples',
                    'NoOfSamplesOnFastClock',
                    'name',
                    'path',
                    'traceless')

    # PAQS header field name -> name used in the returned header.
    _HEADER_RENAME = {'Timestamp of previous signal': 'timestamp_pre',
                      'CPU time': 'time_cpu',
                      'Temperature info': 'temperature_info',
                      'Temperature info - uncertainty': 'temperature_info_std',
                      'Signal height': 'height',
                      'Area': 'area',
                      'Characteristic timescale': 'timescale',
                      'Offset': 'offset',
                      'Pulse onset (relative)': 'onset',
                      'AverageCountStatus': 'rate',
                      'ADC': 'temperature_adc',
                      'muonTrig_prev': 'muontrigger_pre',
                      'muonTrig_post': 'muontrigger_post',
                      'muonTrig': 'muontrigger',
                      'comments': 'comment'}

    # Casts applied to (already renamed) header fields when present.
    # Keys must match the *renamed* names above; previously this table used
    # 'timestamp_previous', which never exists, so that cast never ran.
    # NOTE(review): 'comment' is cast to int; confirm PAQS comments are numeric.
    _HEADER_TYPES = {'onset': int,
                     'rate': int,
                     'timestamp_pre': int,
                     'time_cpu': int,
                     'comment': int}

    def __init__(self):
        self._pr = None  # DARQ reader; created on first use by _set()

    def read(self, entry: _pd.Series) -> _Signal:
        """Return the pulse described by *entry* as a ``Signal``.

        ``entry['signal']`` is passed to DARQ's ``readNthPulse``. The header
        merges the cleaned PAQS header, the fields of *entry*, and sampling
        information from DARQ, with later sources taking precedence.
        """
        self._set(entry)
        signal = entry['signal']
        pulse_darq = self._pr.readNthPulse(signal)

        data = pulse_darq.signalData

        header_paqs = pulse_darq.headerData
        header_darq = {'samplingrate': pulse_darq.samplingRate,
                       'oversampling': pulse_darq.oversampling,
                       'coderange': 2 ** 16,  # NOTE(review): presumably a 16-bit ADC — confirm
                       'pretrigger': pulse_darq.pretrigLength}
        header_fitfiles = entry.to_dict()

        header = self._header(header_paqs, header_darq, header_fitfiles)
        return _Signal(data=data, header=header)

    def _set(self, entry: _pd.Series) -> None:
        """(Re)create the DARQ reader when path, channel or polarity change.

        Raises
        ------
        ImportError
            If the optional ``darq`` package could not be imported.
        """
        folder = entry['folder']
        measurement = entry['measurement']
        channel = entry['channel']
        polarity = entry['polarity']

        path = _os.path.join(folder, measurement)

        if self._pr is None or self._pr.rootPath != path or self._pr.channel != channel or self._pr.polarity != polarity:
            # The module-level darq import is optional; without it the name is
            # undefined, so turn the confusing NameError into a clear ImportError.
            try:
                reader_cls = _PulseReaderPulseprocessingDARQ
            except NameError:
                raise ImportError("DARQ is not installed. Please install the `darq` package to read PAQS pulses.") from None
            self._pr = reader_cls(path, channel, polarity)

        return None

    def _header(self, header_paqs: dict, header_darq: dict, header_fitfiles: dict) -> dict:
        """Build the pulse header from PAQS, fitfiles and DARQ information.

        Unwanted PAQS fields are dropped and the rest renamed and cast (see the
        class-level tables). Then *header_fitfiles* and finally *header_darq*
        are merged on top, so DARQ values take precedence. None of the input
        dicts is modified.
        """
        # Work on a copy: header_paqs is the DARQ pulse object's own header
        # dict, and popping/renaming keys on it would corrupt that object.
        header = dict(header_paqs)

        for key in self._HEADER_DROP:
            header.pop(key, None)

        for key, new_key in self._HEADER_RENAME.items():
            if key in header:
                header[new_key] = header.pop(key)

        # PAQS may deliver the comment wrapped in a tuple; keep its first item.
        comment = header.get('comment')
        if isinstance(comment, tuple):
            header['comment'] = comment[0]

        for key, cast in self._HEADER_TYPES.items():
            if key in header:
                header[key] = cast(header[key])

        header.update(header_fitfiles)
        header.update(header_darq)

        return header