import logging
from typing import Callable, Dict, Type, Union

import pandas as pd
import timeout_decorator
from tqdm import tqdm
import xarray as xr


# Module-level logger; reports dataset-access timeouts during group enhancement.
logger: logging.Logger = logging.getLogger(name=__name__)


class EnhanceMethodBase:
    """Common interface for dataframe enhancement strategies.

    Subclasses override :meth:`enhance`.
    """

    def enhance(self, df: pd.DataFrame) -> pd.DataFrame:
        """Base method for enhancement so that all methods share one signature.

        Implementations should take a dataframe, enhance it with the dataset
        supplied at class initialization, then return the dataframe with the
        enhanced columns added.

        Raises:
            NotImplementedError: always, on the base class.
        """
        # Name the concrete class so a missing override is easy to locate.
        raise NotImplementedError(f"{type(self).__name__} must implement enhance()")


class EnhanceGroupByApply(EnhanceMethodBase):
    """Enhance a dataframe by opening one dataset per group of rows.

    Use this class when windowed calculations should be done around points,
    when the dataset is inefficient to access pointwise, or when different
    datasets should be accessed based on the grouping.

    access_dataset_fn: Callable returning an ``xr.Dataset``. It is called once
        per group with the keyword arguments built from ``group_open_kwarg``.
    group_open_kwarg: Maps each dataframe column to group by onto the keyword
        argument name passed to ``access_dataset_fn``. A falsy value (e.g.
        ``None``) means the column is used for grouping only and is not passed.
    transform_dataset_fn: Callable applied to each opened dataset before the
        group's values are selected from it.
    """

    def __init__(
        self,
        access_dataset_fn: Callable[..., xr.Dataset],
        group_open_kwarg: Dict[str, Union[str, None]],
        transform_dataset_fn: Callable[[xr.Dataset], xr.Dataset],
    ):
        self.access_dataset_fn = access_dataset_fn
        self.group_open_kwarg = group_open_kwarg
        self.transform_dataset_fn = transform_dataset_fn

    def enhance(self, df: pd.DataFrame) -> pd.DataFrame:
        """Group rows by the ``group_open_kwarg`` columns and enhance each group.

        For every group the dataset is opened with ``access_dataset_fn``,
        transformed with ``transform_dataset_fn``, and its data variables are
        added as columns of that group's rows.

        Rows are returned concatenated group by group, so the output order
        follows the grouping rather than the original order of ``df``. Rows
        whose grouping values are NaN are dropped (pandas' default
        ``groupby(dropna=True)``).
        """
        group_keys = list(self.group_open_kwarg.keys())

        groups = [
            self._process_group_df(group, group_df, group_keys)
            for group, group_df in df.groupby(group_keys)
        ]

        if not groups:
            # pd.concat raises ValueError on an empty list; this happens for an
            # empty input or when every grouping value is NaN.
            return df.iloc[0:0].copy()

        return pd.concat(groups)

    def _process_group_df(
        self, group: tuple, group_df: pd.DataFrame, group_keys: list
    ) -> pd.DataFrame:
        """Open, transform and select the dataset for one group of rows.

        Returns ``group_df`` with one new column per data variable of the
        selected dataset, or ``group_df`` unchanged if opening the dataset
        raised a timeout.

        Raises:
            TypeError: if ``access_dataset_fn`` raises ``TypeError`` (e.g. it
                does not accept the built keyword arguments); chained to the
                original error.
        """
        # Depending on the pandas version, grouping by a one-element list
        # yields either scalar keys or 1-tuples; normalise to a tuple so that
        # indexing never reaches into a scalar such as a string.
        if not isinstance(group, tuple):
            group = (group,)

        open_kwargs = {}
        for key, value in zip(group_keys, group):
            open_kwarg = self.group_open_kwarg[key]
            if open_kwarg:
                open_kwargs[open_kwarg] = value

        try:
            group_ds = self.access_dataset_fn(**open_kwargs)
        except TypeError as err:
            raise TypeError(
                f"Tried to call {self.access_dataset_fn=} with {open_kwargs=}"
            ) from err
        except timeout_decorator.timeout_decorator.TimeoutError:
            logger.warning("Timeout Error with group %s", open_kwargs)
            return group_df

        transformed_ds = self.transform_dataset_fn(group_ds)

        # Only select along coordinates that exist on the transformed dataset
        # and have at least one dimension. Scalar (0-d) coordinates are skipped,
        # and keys used only to open the dataset may not be coordinates at all.
        transform_kwargs = {
            key: value
            for key, value in open_kwargs.items()
            if key in transformed_ds.coords and transformed_ds.coords[key].ndim > 0
        }

        refined_ds = transformed_ds.sel(**transform_kwargs, method="nearest")

        # Work on a copy so new columns are not written into a groupby slice.
        group_df = group_df.copy()
        for var in refined_ds.data_vars:
            group_df[var] = refined_ds[var].values

        return group_df


class EnhancePointwise(EnhanceMethodBase):
    """Enhance a dataframe by selecting dataset values at each row's coordinates.

    ds: Dataset to select values from.
    columns_to_coords: Maps dataframe column names to the dataset coordinate
        each column is matched against (nearest neighbour).
    transform_ds_fn: Optional callable applied to the point-selected dataset
        before it is joined back onto the dataframe.
    """

    def __init__(
        self, ds: xr.Dataset, columns_to_coords: Dict[str, str], transform_ds_fn=None
    ):
        self.transform_ds_fn = transform_ds_fn
        self.columns_to_coords = columns_to_coords
        self.ds = ds

    def enhance(self, df: pd.DataFrame) -> pd.DataFrame:
        """Append the nearest dataset values for every row as new columns.

        Each mapped column becomes a vectorised indexer, so all rows are looked
        up in a single ``sel`` call. The selected data variables are then
        concatenated onto ``df`` column-wise.
        """
        indexer_ds = xr.Dataset(
            {coord: df[column] for column, coord in self.columns_to_coords.items()}
        )
        indexers = {name: indexer_ds[name] for name in indexer_ds.data_vars}

        selected = self.ds.sel(**indexers, method="nearest")
        if self.transform_ds_fn:
            selected = self.transform_ds_fn(selected)

        value_columns = list(selected.data_vars)
        selected_df = selected.to_dataframe()[value_columns]

        return pd.concat([df, selected_df], axis=1)
