import pandas as _pd
import numpy as _np
import os as _os
import warnings as _warnings
import math as _math
import multiprocessing as _mp
from tqdm import tqdm as _tqdm

from pathlib import Path as _Path

from ._pulsereader import PulseReader as _PulseReader
from ._signal import Signal as _Signal


def multiprocessing(function, 
                    iterables: list, 
                    multiprocessing: bool = False, 
                    processes: int = _os.cpu_count(), 
                    desc: str = '', 
                    unit: str = 'it') -> list:
    """Call ``function(*args)`` for every argument tuple in ``iterables``.

    Args:
        function: Callable applied to each argument tuple. In pool mode it is
            dispatched through ``multiprocessing.Pool.starmap``, so it must be
            picklable.
        iterables: Iterable of argument tuples, unpacked into ``function``.
        multiprocessing: Run in a process pool instead of sequentially.
        processes: Pool size. The default, ``os.cpu_count()``, is evaluated
            once when this module is imported.
        desc: Progress-bar description (sequential mode only).
        unit: Progress-bar unit label (sequential mode only).

    Returns:
        List of results, in the same order as ``iterables``.
    """
    if not multiprocessing:
        # Materialised first, presumably so tqdm knows the total length.
        progress = _tqdm(list(iterables), desc=desc, unit=unit)
        return [function(*arguments) for arguments in progress]

    with _mp.Pool(processes) as pool:
        return pool.starmap(function, iterables)


def signal(entry: _pd.Series, 
           pr: _PulseReader = None, 
           calibration: _pd.DataFrame = None) -> _Signal:
    """Read the pulse described by a single fitfile row.

    Args:
        entry: One row of a fitfile.
        pr: Reader to use; a fresh ``_PulseReader()`` is created when omitted.
        calibration: Optional calibration table. When given it is assigned to
            ``pr.calibration``, which also changes a reader passed in by the
            caller.

    Returns:
        The result of ``pr.read(entry)``.
    """
    reader = _PulseReader() if pr is None else pr

    if calibration is not None:
        reader.calibration = calibration

    return reader.read(entry)


def signals(fitfile: _pd.DataFrame, 
            calibration: _pd.DataFrame = None):
    """Lazily read the pulse for every row of ``fitfile``.

    A single ``_PulseReader`` constructed with ``calibration`` is shared by
    all rows.

    Args:
        fitfile: Fitfile whose rows describe the pulses to read.
        calibration: Optional calibration table handed to the reader.

    Yields:
        Whatever :func:`signal` returns for each row, in row order.
    """
    pr = _PulseReader(calibration = calibration)

    # iterrows() yields every row exactly once as a Series. Looking rows up via
    # fitfile.loc[label] instead returns a whole DataFrame (once per duplicate)
    # when the index contains repeated labels.
    for _, entry in fitfile.iterrows():
        yield signal(entry, pr = pr)


def extend(fitfiles: list, 
           prefix: list = None,
           suffix: list = None) -> _pd.DataFrame:
    """Combine several fitfiles column-wise into a single DataFrame.

    Every fitfile is indexed on ``folder``, ``measurement``, ``channel``,
    ``polarity`` and ``signal`` (these columns must be present), and the frames
    are concatenated along the column axis on that index. A column present in
    more than one fitfile is kept only from the first fitfile that has it. A
    warning is issued when a later copy does not hold the same values.

    Args:
        fitfiles: Sequence of fitfile DataFrames (a single DataFrame is also
            accepted). Neither the caller's container nor its DataFrames are
            modified.
        prefix: Optional prefix(es) passed to :func:`affix` before combining.
        suffix: Optional suffix(es) passed to :func:`affix` after ``prefix``.

    Returns:
        The combined fitfile with the identifying columns restored as regular
        columns, passed through :func:`dtypes`.
    """
    index_columns = ['folder', 'measurement', 'channel', 'polarity', 'signal']

    if isinstance(fitfiles, _pd.DataFrame):
        fitfiles = [fitfiles]
    else:
        # affix() assigns into the list it receives; work on a private copy so
        # the caller's container is left untouched (this also accepts tuples).
        fitfiles = list(fitfiles)

    if prefix is not None:
        fitfiles = affix(fitfiles, prefix, method = 'prefix')

    if suffix is not None:
        fitfiles = affix(fitfiles, suffix, method = 'suffix')

    # Non-inplace set_index returns new frames, so neither this step nor the
    # column drops below alter the caller's DataFrames.
    fitfiles = [fitfile.set_index(index_columns) for fitfile in fitfiles]

    # dict.fromkeys keeps first-seen order, making warning order deterministic.
    columns = dict.fromkeys(column for fitfile in fitfiles for column in fitfile.columns)

    for column in columns:
        reference = None
        for position, fitfile in enumerate(fitfiles):
            if column not in fitfile.columns:
                continue
            if reference is None:
                # The first fitfile holding this column keeps it.
                reference = fitfile[column]
                continue
            if not _same_values(reference, fitfile[column]):
                _warnings.warn(f"Column {column} exists multiple times and has different values.", Warning, stacklevel=2)
            fitfiles[position] = fitfile.drop(columns=column)

    fitfile = _pd.concat(fitfiles, axis=1)
    fitfile = fitfile.reset_index()
    fitfile = dtypes(fitfile)

    return fitfile


def _same_values(left: _pd.Series, right: _pd.Series) -> bool:
    """Return True if two Series hold equal values after outer alignment.

    Missing values in the same position count as equal. A label present in
    only one of the Series counts as a difference (it aligns against NaN).
    """
    left, right = left.align(right)

    try:
        equal = left == right
    except TypeError:
        # e.g. categoricals with differing category sets refuse to compare.
        left, right = left.astype(object), right.astype(object)
        equal = left == right

    same = equal | (left.isna() & right.isna())

    # fillna(False) turns nullable-boolean <NA> into a mismatch instead of
    # letting .all() silently skip it.
    return bool(same.fillna(False).all())


def append(fitfiles: list) -> _pd.DataFrame:
    """Stack fitfiles row-wise into a single DataFrame.

    Columns that are categorical in *every* input are first re-encoded onto the
    union of their categories, so the concatenated column stays categorical.
    Note that the input DataFrames are updated in place by that re-encoding.

    Args:
        fitfiles: List of fitfile DataFrames.

    Returns:
        The concatenated fitfile (fresh RangeIndex), passed through
        :func:`dtypes`.
    """
    shared = set.intersection(*[set(frame.select_dtypes(include='category').columns) for frame in fitfiles])

    for name in shared:
        merged = _pd.api.types.union_categoricals([frame[name] for frame in fitfiles]).categories

        for frame in fitfiles:
            frame[name] = _pd.Categorical(frame[name].values, categories=merged)

    combined = _pd.concat(fitfiles, axis=0, ignore_index=True)

    return dtypes(combined)


def affix(fitfiles: list, 
          affix: list, 
          method: str = 'prefix'):
    """Rename the data columns of one fitfile or a list of fitfiles.

    Args:
        fitfiles: A single DataFrame, or a list of DataFrames. A list is updated
            in place (its elements are replaced) and also returned.
        affix: For a DataFrame, the affix string. For a list, either one string
            applied to every fitfile or a per-fitfile list (``None`` entries are
            skipped).
        method: ``'prefix'`` or ``'suffix'``, forwarded to
            :func:`affix_fitfile`.

    Returns:
        The renamed DataFrame, or the (mutated) input list.
    """
    if isinstance(fitfiles, _pd.DataFrame):
        return affix_fitfile(fitfiles, affix, method = method)

    affixes = [affix] * len(fitfiles) if isinstance(affix, str) else affix

    for position, frame in enumerate(fitfiles):
        current = affixes[position]
        if current is None:
            continue
        fitfiles[position] = affix_fitfile(frame, current, method = method)

    return fitfiles


def affix_fitfile(fitfile: _pd.DataFrame, 
                  affix: str, 
                  method: str = 'prefix') -> _pd.DataFrame:
    """Return ``fitfile`` with ``affix`` joined onto its data column names.

    The identifying columns (folder, measurement, channel, polarity, signal)
    and ``timestamp`` keep their names. Every other column ``c`` becomes
    ``affix + '_' + c`` (prefix) or ``c + '_' + affix`` (suffix).

    Args:
        fitfile: Fitfile to rename; it is not modified.
        affix: Text to attach. ``None`` returns ``fitfile`` itself unchanged.
        method: ``'prefix'`` or ``'suffix'``.

    Returns:
        A renamed copy of ``fitfile`` (or ``fitfile`` when ``affix`` is None).

    Raises:
        NotImplementedError: ``method`` is unknown and at least one column
            would be renamed.
    """
    if affix is None:
        return fitfile

    protected = ('folder', 'measurement', 'channel', 'polarity', 'signal', 'timestamp')
    renames = {}

    for name in fitfile.columns:
        if name in protected:
            continue
        if method == 'prefix':
            renames[name] = affix + '_' + name
        elif method == 'suffix':
            renames[name] = name + '_' + affix
        else:
            raise NotImplementedError

    return fitfile.rename(columns = renames)


def load(files: list) -> _pd.DataFrame:
    """Load one or more fitfiles from disk and combine them column-wise.

    Args:
        files: A single path or a list of paths; each is read by
            :func:`_load_file`.

    Returns:
        The result of :func:`extend` over the loaded fitfiles.
    """
    # ndmin=1 lets a single path be passed without wrapping it in a list.
    frames = [_load_file(path) for path in _np.array(files, ndmin = 1)]

    return extend(frames)


def _load_file(file: str) -> _pd.DataFrame:
    file = _Path(file)

    _, ext = _os.path.splitext(file)

    if ext == '.pkl':
        fitfile = _pd.read_pickle(file)
    elif ext == '.csv':
        fitfile = _pd.read_csv(file)
    elif ext == '.parquet':
        fitfile = _pd.read_parquet(file)
    else:
        raise NotImplementedError

    return fitfile


def save(fitfile: _pd.DataFrame, 
         path: str) -> None:
    """Write a fitfile to disk, choosing the format from the file extension.

    Supported extensions (case-insensitive): ``.pkl``, ``.csv`` (written
    without the index) and ``.parquet`` (brotli-compressed).

    Args:
        fitfile: DataFrame to write.
        path: Destination path (str or path-like).

    Raises:
        NotImplementedError: The extension is not supported.
    """
    path = _Path(path)

    ext = path.suffix.lower()

    if ext == '.pkl':
        fitfile.to_pickle(path)
    elif ext == '.csv':
        fitfile.to_csv(path, index=False)
    elif ext == '.parquet':
        fitfile.to_parquet(path, compression='brotli')
    else:
        raise NotImplementedError(
            f"Unsupported file extension '{path.suffix}' for {path}; "
            "expected .pkl, .csv or .parquet."
        )

    return None


def dtypes(fitfile: _pd.DataFrame, 
           category: bool = False) -> _pd.DataFrame:
    """Normalise column dtypes and move the identifying columns to the front.

    ``folder``, ``measurement``, ``channel`` and ``polarity`` become
    ``category`` and ``signal`` becomes ``int`` (each only if present); these
    lead the result in that order, followed by the remaining columns in their
    original order.

    Args:
        fitfile: Fitfile to convert; it is not modified.
        category: Also convert any other column holding a single distinct value
            (across more than one row) to ``category``.

    Returns:
        The converted, reordered DataFrame.
    """
    fixed = {'folder': 'category', 'measurement': 'category', 'channel': 'category', 'polarity': 'category', 'signal': 'int'}

    leading = [name for name in fixed if name in fitfile.columns]
    for name in leading:
        fitfile = fitfile.astype({name : fixed[name]})

    # dict.fromkeys drops repeated labels so each name is listed only once.
    trailing = [name for name in dict.fromkeys(fitfile.columns) if name not in leading]

    if category:
        for name in trailing:
            distinct = set(fitfile[name])
            if len(distinct) == 1 and len(fitfile) > 1:
                fitfile = fitfile.astype({name : 'category'})

    return fitfile[leading + trailing]


def paths(fitfile: _pd.DataFrame) -> list:
    """Return the distinct ``folder``/``measurement`` paths in ``fitfile``.

    Each entry is ``os.path.join(folder, measurement)`` for a pair that occurs
    together in at least one row; the list order is arbitrary.
    """
    unique = {
        _os.path.join(folder, measurement)
        for folder in folders(fitfile)
        for measurement in measurements(fitfile[fitfile['folder'] == folder])
    }

    return list(unique)


def folders(fitfile: _pd.DataFrame) -> list:
    """Return the distinct values of the ``folder`` column, in arbitrary order."""
    return list(set(fitfile['folder']))


def measurements(fitfile: _pd.DataFrame) -> list:
    """Return the distinct values of the ``measurement`` column, in arbitrary order."""
    return list(set(fitfile['measurement']))


def channels(fitfile: _pd.DataFrame) -> list:
    """Return the distinct values of the ``channel`` column, in arbitrary order."""
    return list(set(fitfile['channel']))


def polarities(fitfile: _pd.DataFrame) -> list:
    """Return the distinct values of the ``polarity`` column, in arbitrary order."""
    return list(set(fitfile['polarity']))


def masks(fitfile : _pd.DataFrame,
          iterables : list = None):
    """Yield row masks for every non-empty combination of the requested keys.

    For each key among ``folder``, ``measurement``, ``channel`` and
    ``polarity`` contained in ``iterables``, the distinct values of that column
    are enumerated. Keys that are not requested contribute a single all-True
    mask. Only combinations that select at least one row are yielded.

    Args:
        fitfile: Fitfile to split.
        iterables: Keys to iterate over. ``None`` (the default) requests no
            keys, giving a single all-True mask for a non-empty fitfile.

    Yields:
        An iterator over the requested key values, in the fixed order folder,
        measurement, channel, polarity, followed by the boolean mask Series.
    """
    if iterables is None:
        iterables = ()

    def _select_all():
        # Placeholder level for a key that is not being iterated.
        return [(None, _pd.Series(True, index=fitfile.index))]

    keys = ('folder', 'measurement', 'channel', 'polarity', 'mask')

    for folder, mask_folder in (_masks_folder(fitfile) if 'folder' in iterables else _select_all()):
        for measurement, mask_measurement in (_masks_measurement(fitfile) if 'measurement' in iterables else _select_all()):
            for channel, mask_channel in (_masks_channel(fitfile) if 'channel' in iterables else _select_all()):
                for polarity, mask_polarity in (_masks_polarity(fitfile) if 'polarity' in iterables else _select_all()):
                    mask = mask_folder & mask_measurement & mask_channel & mask_polarity

                    # Vectorised check (builtin sum() iterates in Python and
                    # fails on nullable-boolean <NA>; .any() skips NA).
                    if mask.any():
                        values = (folder, measurement, channel, polarity, mask)
                        yield (value for value, key in zip(values, keys) if key in iterables or key == 'mask')


def fitfiles(fitfile : _pd.DataFrame,
             iterables : list):
    """Split ``fitfile`` into sub-frames for each combination of ``iterables``.

    Args:
        fitfile: Fitfile to split.
        iterables: Keys to group by; each must be one of ``'folder'``,
            ``'measurement'``, ``'channel'``, ``'polarity'``.

    Yields:
        Tuples of the requested key values (in the fixed order folder,
        measurement, channel, polarity) followed by the matching rows of
        ``fitfile``.

    Raises:
        ValueError: An entry of ``iterables`` is not a valid key (raised when
            iteration starts).
    """
    allowed = ('folder', 'measurement', 'channel', 'polarity')

    for entry in iterables:
        if entry in allowed:
            continue
        raise ValueError(f'Entry {entry} is not a valid iterable.')

    for values in masks(fitfile, iterables):
        *keys, mask = values
        yield (*keys, fitfile[mask])


def batch(fitfile : _pd.DataFrame, 
          batch_size : int = 10000):
    """Yield consecutive row slices of ``fitfile`` of roughly equal size.

    At each step the remaining rows are divided into the fewest batches of at
    most ``batch_size`` rows, and the next batch takes an even share of them.
    This avoids a tiny trailing batch (e.g. 25 rows with ``batch_size=10``
    gives 9, 8, 8).

    Args:
        fitfile: DataFrame (or Series) to slice; rows are taken by position.
        batch_size: Maximum rows per batch; must be positive.

    Yields:
        Positional slices of ``fitfile``, in order, covering every row once.

    Raises:
        ValueError: ``batch_size`` is not positive (raised when iteration
            starts).
    """
    # A zero size divided by zero; a negative one could loop forever yielding
    # empty batches. ``not > 0`` also rejects NaN.
    if not batch_size > 0:
        raise ValueError(f'batch_size must be positive, got {batch_size!r}.')

    remaining = fitfile
    while len(remaining) > 0:
        count = _math.ceil(len(remaining) / batch_size)
        size = _math.ceil(len(remaining) / count)

        # .iloc keeps slicing positional for every index type (plain [] slices
        # by label on a float index).
        current = remaining.iloc[:size]
        remaining = remaining.iloc[size:]

        yield current


def summary(fitfile : _pd.DataFrame) -> _pd.DataFrame:
    """Count the rows of ``fitfile`` for each channel/polarity combination.

    Returns:
        A DataFrame with one column per channel and one row per polarity,
        holding integer row counts. Combinations absent from the data are 0.
    """
    counts = {}

    for channel, polarity, subset in fitfiles(fitfile, iterables=['channel', 'polarity']):
        counts.setdefault(channel, {})[polarity] = len(subset)

    table = _pd.DataFrame(counts)

    # Missing combinations appear as NaN (which forces a float dtype); they are
    # genuine zero counts, so fill them and keep the counts integral.
    return table.fillna(0).astype(int)


def _masks_folder(fitfile : _pd.DataFrame):
    """Yield ``(folder, mask)`` for each distinct folder; ``mask`` selects its rows."""
    yield from ((value, fitfile['folder'] == value) for value in folders(fitfile))


def _masks_measurement(fitfile : _pd.DataFrame):
    """Yield ``(measurement, mask)`` for each distinct measurement; ``mask`` selects its rows."""
    yield from ((value, fitfile['measurement'] == value) for value in measurements(fitfile))


def _masks_channel(fitfile : _pd.DataFrame):
    """Yield ``(channel, mask)`` for each distinct channel; ``mask`` selects its rows."""
    yield from ((value, fitfile['channel'] == value) for value in channels(fitfile))


def _masks_polarity(fitfile : _pd.DataFrame):
    """Yield ``(polarity, mask)`` for each distinct polarity; ``mask`` selects its rows."""
    yield from ((value, fitfile['polarity'] == value) for value in polarities(fitfile))