"""
Shared functions for loading and transforming HYCOM data.

'logger' is an optional argument to some functions
so that Dagster's logger can be supplied.
"""

import logging

import gsw
import numpy as np
import pandas as pd
import xarray as xr


class MissingDateError(KeyError):
    """ Raised when the requested date is not present in the source dataset """


class DAPServerError(OSError):
    """ Raised when the DAP (OPeNDAP) server fails to serve the dataset """


# Variable names of the dataset built by hycom_vars. They are module-level
# constants because other code compares against them.
SST = "sst"  # water_temp at depth 0
SST_SD = "sst_sd"  # rolling standard deviation of SST
SSH = "ssh"  # surf_el
SSH_SD = "ssh_sd"  # rolling standard deviation of SSH
SALINITY = "salinity"  # salinity at depth 0
WATER_U = "water_u"  # eastward current at depth 0
WATER_V = "water_v"  # northward current at depth 0
EKE = "eke"  # (u^2 + v^2) / 2 of the surface current
ILD = "ild"  # isothermal layer depth
N2 = "n2"  # Brunt-Vaisala frequency squared, upper 200 m mean

HYCOM_VARIABLES = [
    SST,
    SST_SD,
    SSH,
    SSH_SD,
    SALINITY,
    WATER_U,
    WATER_V,
    EKE,
    ILD,
    N2,
]

# HYCOM experiments reachable over OPeNDAP, one row per experiment:
# the inclusive date range it covers and its THREDDS URL. The newest
# experiment has no end date (NaT) because it is open-ended.
# Dates between 2018-11-21 and 2018-12-03 are not covered by any row.
#
# Previously the same paths were referenced on http://ncss.hycom.org/thredds/dodsC/.
# Not used here: per-component (`uv3z`, `ssh`, `ts3z`) forecast files over FTPS at
# ftps://ftp.hycom.org/datasets/GLBy0.08/expt_93.0/data/forecasts/hycom_glby_930_{year}{month}{day}{hour}_t000_{component}.nc
HYCOM_SOURCES = pd.DataFrame(
    [
        ("1992-10-02", "1995-07-31", "http://tds.hycom.org/thredds/dodsC/GLBu0.08/expt_19.0"),
        ("1995-08-01", "2012-12-31", "http://tds.hycom.org/thredds/dodsC/GLBu0.08/expt_19.1"),
        ("2013-01-01", "2013-08-19", "http://tds.hycom.org/thredds/dodsC/GLBu0.08/expt_90.9"),
        ("2013-08-20", "2014-04-04", "http://tds.hycom.org/thredds/dodsC/GLBu0.08/expt_91.0"),
        ("2014-04-05", "2016-04-17", "http://tds.hycom.org/thredds/dodsC/GLBu0.08/expt_91.1"),
        ("2016-04-18", "2018-11-20", "http://tds.hycom.org/thredds/dodsC/GLBu0.08/expt_91.2"),
        ("2018-12-04", None, "http://tds.hycom.org/thredds/dodsC/GLBy0.08/expt_93.0"),
    ],
    columns=["start", "end", "url"],
).assign(
    # Parse the date columns after construction; None in "end" becomes NaT.
    start=lambda sources: pd.to_datetime(sources["start"]),
    end=lambda sources: pd.to_datetime(sources["end"]),
)


def hycom_url(
    date: str, logger: logging.Logger = None, sources: pd.DataFrame = None
) -> str:
    """
    Generate a valid URL for HYCOM data for the given date (YYYY-MM-DD).

    A source covers a date when ``start <= date <= end``. A missing ``end``
    (NaT) marks the open-ended, current experiment. A date that falls in a
    gap between sources falls back to the latest source that started on or
    before it.

    :param date: Date string parseable by ``pd.to_datetime``.
    :param logger: Optional logger (e.g. Dagster's); defaults to this module's logger.
    :param sources: Table with ``start``, ``end`` and ``url`` columns, ordered
        by ``start``. Defaults to ``HYCOM_SOURCES``.
    :returns: The OPeNDAP URL of the matching source.
    :raises MissingDateError: if the date is before the first source starts.
    """
    if not logger:
        logger = logging.getLogger(__name__)
    if sources is None:
        sources = HYCOM_SOURCES

    dt = pd.to_datetime(date)

    # Inclusive start: a date equal to a source's start belongs to that source.
    started = sources[sources["start"] <= dt]
    if started.empty:
        raise MissingDateError(f"No HYCOM source available for {date}")

    # A NaT end means the source is still running, so it covers every later date.
    covering = started[started["end"].isna() | (dt <= started["end"])]
    row = covering.iloc[0] if not covering.empty else started.iloc[-1]

    url = row["url"]

    logger.info(f"Found HYCOM url ({url}) for {date}")

    return url


def hycom_for_date(
    date: str,
    logger: logging.Logger = None,
    quiet_logger: bool = False,
    chunk: bool = True,
) -> xr.Dataset:
    """
    Return an xarray Dataset with HYCOM for the given date (YYYY-MM-DD).

    The dataset is opened from the OPeNDAP URL chosen by ``hycom_url`` and
    reduced to that date. When several time steps match the day, the
    ``T00:00:00`` step is kept. Longitudes in 0-360 are converted to
    -180-180, and ``water_u`` / ``water_v`` are reduced to depth 0.

    :param date: Date string (YYYY-MM-DD).
    :param logger: Optional logger (e.g. Dagster's); defaults to this module's logger.
    :param quiet_logger: When True, skip logging the resulting dataset.
    :param chunk: When True, open with chunks of 500 lat x 500 lon x 1 depth
        level; when False, open without chunks.
    :raises DAPServerError: if the DAP server errors or NetCDF reports an I/O failure.
    :raises MissingDateError: if the date, or its ``T00:00:00`` step when
        several steps match, is not in the dataset.
    """
    if not logger:
        logger = logging.getLogger(__name__)

    url = hycom_url(date, logger=logger)

    chunks = {"lat": 500, "lon": 500, "depth": 1} if chunk else None

    # Dropping tau since we should not need to know the hours since analysis,
    # and it interferes with datetime conversion. The *_bottom variables are
    # not used by hycom_vars.
    try:
        ds = xr.open_dataset(
            url,
            drop_variables=[
                "tau",
                "water_temp_bottom",
                "salinity_bottom",
                "water_u_bottom",
                "water_v_bottom",
            ],
            chunks=chunks,
        )
    except OSError as e:
        message = str(e)
        if "DAP server error" in message:
            logger.error(f"Error accessing DAP server for {date}")
            raise DAPServerError(f"Error accessing DAP server for {date}") from e

        if "NetCDF: I/O failure" in message:
            logger.error(f"NetCDF I/O failure for {date}")
            raise DAPServerError(f"NetCDF I/O failure for {date}") from e

        raise

    try:
        ds = ds.sel(time=date)
    except KeyError as e:
        logger.error(
            f"Error selecting date ({date}) from ds: \n {e} \n {ds} \n {ds['time']}"
        )
        raise MissingDateError(f"Date {date} missing in source dataset") from e

    # Selecting by day may leave a 1-D time dimension with several steps
    # (it may also already be a scalar). Keep only midnight in that case.
    time = ds["time"]
    if time.ndim > 0 and time.size > 1:
        # Count before selecting: afterwards time is a scalar with no len().
        n_times = time.size
        date_with_time = f"{date}T00:00:00"
        try:
            ds = ds.sel(time=date_with_time)
        except KeyError as e:
            logger.error(
                f"No {date_with_time} step among {n_times} time values for {date}"
            )
            raise MissingDateError(
                f"Date {date_with_time} missing in source dataset"
            ) from e

        logger.info(
            f"There were multiple values ({n_times}) for 'time', so selected {date_with_time}"
        )

    # Some (newer) runs have longitude in 0 - 360, while some (older) are -180 - 180
    if 180 < ds["lon"].max():
        logger.info("Longitude was 0-360, converting to -180 to 180")
        ds = ds.assign_coords(lon=(((ds["lon"] + 180) % 360) - 180))

        # Re-sort so lon is monotonic again for label-based slicing (extract_row).
        ds = ds.sortby(["lat", "lon", "depth"])

    # Keep only the surface currents.
    ds["water_u"] = ds["water_u"].sel(depth=0)
    ds["water_v"] = ds["water_v"].sel(depth=0)

    if not quiet_logger:
        logger.info(f"Opened and filtered dataset for date {date}: {ds}")

    return ds


def std(
    da: xr.DataArray, window: int = 3, logger: logging.Logger = None
) -> xr.DataArray:
    """
    Rolling standard deviation of ``da`` over a ``window`` x ``window``
    lat/lon neighbourhood centred on each cell.

    The result's ``long_name`` and ``standard_name`` attrs are derived from
    ``da``'s, so ``da`` must carry both. ``logger`` is accepted for interface
    consistency and is unused.
    """
    neighbourhood = da.rolling(lat=window, lon=window, center=True)
    spread = neighbourhood.std()

    source_attrs = da.attrs
    spread.attrs["long_name"] = "Standard deviation of " + source_attrs["long_name"]
    spread.attrs["standard_name"] = source_attrs["standard_name"] + "_stdev"
    return spread


def eke(ds: xr.Dataset, logger: logging.Logger = None) -> xr.DataArray:
    """
    Calculate EKE, ``(u**2 + v**2) / 2``, from the surface currents of ``ds``
    (output from hycom_for_date).

    ``water_u``/``water_v`` are taken at depth 0 when they still have a depth
    to select; otherwise they are used as-is. No mean flow is subtracted, so
    this is the kinetic energy of the full surface current. ``logger`` is unused.
    """

    def _at_surface(name):
        # hycom_for_date has usually already reduced these to the surface.
        try:
            return ds[name].sel(depth=0)
        except (ValueError, KeyError):
            return ds[name]

    u_surface = _at_surface("water_u")
    v_surface = _at_surface("water_v")

    energy = (u_surface ** 2 + v_surface ** 2) / 2
    energy.attrs["long_name"] = "Eddy Kinetic Energy"
    energy.attrs["standard_name"] = "EKE"
    return energy


def ild(
    ds: xr.Dataset, delta_t: float = 0.5, logger: logging.Logger = None
) -> xr.DataArray:
    """
    Calculate the ILD for the given dataset (output from hycom_for_date).

    For each lat/lon column, returns the depth whose ``water_temp`` is closest
    to (surface temperature - ``delta_t``), using ``idxmin`` over depth with
    NaNs skipped.

    :param ds: Dataset with a ``water_temp`` variable that has a depth dimension
        including depth 0.
    :param delta_t: Temperature drop from the surface, in the units of water_temp.
    :param logger: Unused; accepted for interface consistency.
    :returns: DataArray of depths (units "m").
    """
    water_temp = ds["water_temp"]
    target_temp = water_temp.sel(depth=0) - delta_t
    distance = np.abs(water_temp - target_temp)
    da = distance.idxmin(dim="depth", skipna=True)

    # Reflect the actual threshold (identical text for the default 0.5).
    da.attrs["long_name"] = f"Isothermal layer depth for {delta_t}deg C"
    da.attrs["units"] = "m"
    # Identity packing attrs: a scale_factor of 0.0 would make CF-aware readers
    # decode every value as add_offset (i.e. all zeros).
    da.attrs["add_offset"] = 0.0
    da.attrs["scale_factor"] = 1.0
    # np.NaN was removed in NumPy 2.0; np.nan works on all versions.
    da.attrs["missing_value"] = np.nan

    return da


def n2_transform_raw_variables(ds):
    """
    Coerce raw variables to n2 input types (from Alex K function).

    Returns ``(z, p, SA, CT)``:
      z  -- height, the negated ``depth`` coordinate
      p  -- sea pressure from ``gsw.p_from_z(z, lat)``
      SA -- Absolute Salinity from ``gsw.SA_from_SP`` (practical ``salinity``)
      CT -- Conservative Temperature from ``gsw.CT_from_t`` (in-situ ``water_temp``)
    """
    # gsw works with height z, which is negative below the sea surface.
    height = -1 * ds["depth"]
    pressure = gsw.p_from_z(height, ds["lat"])
    absolute_salinity = gsw.SA_from_SP(ds["salinity"], pressure, ds["lon"], ds["lat"])
    conservative_temp = gsw.CT_from_t(absolute_salinity, ds["water_temp"], pressure)
    return height, pressure, absolute_salinity, conservative_temp


def n2_build_inputs(ds):
    """
    Flatten n2 inputs from 3D to 2D for ``gsw.stability.Nsquared``.

    Each output row is one (lat, lon) water column, in lat-major order, and
    each output column is one depth level. The code assumes CT/SA from
    n2_transform_raw_variables are ordered (depth, lat, lon).
    NOTE(review): confirm this holds for inputs other than hycom_for_date output.

    :returns: ``(SA, CT, p, lat)``. SA, CT and p are shaped
        (n_lat * n_lon, n_depth); lat is shaped (n_lat * n_lon, 1).
    """
    z, p, SA, CT = n2_transform_raw_variables(ds)
    n_lon, n_lat = CT.shape[2], CT.shape[1]
    n_columns = n_lat * n_lon

    # Reshape data for gsw N2 function: (depth, lat, lon) -> (lat * lon, depth).
    CT = CT.data.transpose(1, 2, 0).reshape(n_columns, len(z))
    SA = SA.data.transpose(1, 2, 0).reshape(n_columns, len(z))

    # Latitude per column: each latitude repeated once per longitude (lat-major),
    # matching the row order of CT/SA above.
    lat = np.repeat(ds["lat"].values, n_lon).reshape(n_columns, 1)
    # Pressure per column. p is presumably (depth, lat) after gsw broadcasting;
    # transpose to (lat, depth) and repeat each latitude row per longitude.
    p = np.repeat(p.values.T, repeats=n_lon, axis=0)
    return SA, CT, p, lat


def map_n2(block: xr.Dataset, *args, logger: logging.Logger, **kwargs) -> xr.DataArray:
    """
    Compute N2 for one block and average it over the upper 200 m.

    Used as the ``map_blocks`` callback in ``n2``. ``logger`` and any extra
    positional/keyword arguments are accepted for that interface but unused.

    :param block: Dataset block with the variables read by n2_build_inputs.
    :returns: DataArray over (lat, lon).
    """
    SA, CT, p, lat = n2_build_inputs(block)

    # One row per (lat, lon) column; Nsquared differences adjacent depth levels.
    n2_columns = gsw.stability.Nsquared(SA, CT, p, lat=lat, axis=1)[0]

    lat_values = block["lat"].values
    lon_values = block["lon"].values
    levels = block["depth"].values
    # N2 lives between levels, so label it with the depth midpoints.
    mid_depths = (levels[1:] + levels[:-1]) / 2

    # Undo the (lat * lon, depth) flattening, then move depth to the front.
    n2_cube = n2_columns.reshape(len(lat_values), len(lon_values), len(mid_depths))
    n2_cube = n2_cube.transpose(2, 0, 1)

    profile = xr.DataArray(
        n2_cube,
        coords=[mid_depths, lat_values, lon_values],
        dims=["depth", "lat", "lon"],
    )

    # Mean over midpoints between 0 and 200 m depth.
    upper_layer = profile.sel(depth=slice(0, 200))
    return upper_layer.mean("depth")


def n2(ds: xr.Dataset, logger: logging.Logger = None) -> xr.DataArray:
    """
    Calculate N2 (buoyancy frequency) using map_blocks to process smaller chunks at a time.

    Each block is handled by ``map_n2``, which averages N2 over the upper 200 m.

    :param ds: Dataset with ``water_temp``/``salinity`` on depth/lat/lon and
        scalar ``time``/``depth`` coordinates removable from the template
        (as produced by hycom_for_date).
    :param logger: Passed through to ``map_n2``.
    :returns: DataArray over (lat, lon).
    """
    # map_n2 differences adjacent depth levels, so every block needs the whole
    # water column. hycom_for_date chunks depth one level at a time, which would
    # give each block a single level; merge the depth chunks first.
    if len(ds.chunks.get("depth", ())) > 1:
        ds = ds.chunk({"depth": -1})

    # Template for the results of map_blocks (xarray/dask needs to predict the
    # return shape): water_temp's lat/lon grid with the scalar time and depth
    # coordinates dropped. drop_vars replaces the deprecated Dataset.drop.
    template = ds["water_temp"].isel(depth=0).drop_vars(["time", "depth"])

    # Map calculation over individual blocks
    da = ds.map_blocks(map_n2, template=template, kwargs={"logger": logger})

    # Replace rather than update attrs: the template carries water_temp's attrs,
    # which do not describe N2.
    da.attrs = {
        "long_name": "Brunt-Vaisala Frequency",
        "units": "Square of buoyancy frequency [radian^2/s^2]",
    }

    return da


def hycom_vars(
    ds: xr.Dataset, logger: logging.Logger = None, quiet_logger: bool = False
) -> xr.Dataset:
    """
    Transform a HYCOM dataset via OpenDAP into our preferred format.

    Builds a Dataset keyed by the module's variable-name constants:
    - surface temperature, salinity and currents (depth 0 when selectable);
    - sea surface height (``surf_el``);
    - rolling standard deviations of SST and SSH;
    - EKE, ILD and N2 derived from ``ds``.

    :param ds: Dataset as returned by hycom_for_date (or a lat/lon subset of it).
    :param logger: Optional logger; defaults to this module's logger.
    :param quiet_logger: When True, skip logging the transformed dataset.
    """
    if not logger:
        logger = logging.getLogger(__name__)

    def surface(name):
        # Take depth 0 when there is a depth to select; otherwise use as-is.
        try:
            return ds[name].sel(depth=0)
        except (ValueError, KeyError):
            return ds[name]

    sst = surface("water_temp")
    sst_sd = std(sst, logger=logger)

    ssh = ds["surf_el"]
    ssh_sd = std(ssh, logger=logger)

    # Dataset variable names are 'constants' because they are used elsewhere
    # for comparison. Values are evaluated in this listed order.
    variables = {
        SST: sst,
        SST_SD: sst_sd,
        SSH: ssh,
        SSH_SD: ssh_sd,
        SALINITY: surface("salinity"),
        WATER_U: surface("water_u"),
        WATER_V: surface("water_v"),
        EKE: eke(ds, logger=logger),
        ILD: ild(ds, logger=logger),
        N2: n2(ds, logger=logger),
    }
    transformed = xr.Dataset(variables)

    if not quiet_logger:
        logger.info(f"Transformed dataset into: {transformed}")

    return transformed


def extract_by_diagonal(ds: xr.Dataset, df: pd.DataFrame) -> pd.DataFrame:
    """
    A less effective process for extracting data.

    Selecting with the ``lat`` and ``lon`` columns gathers the full product
    of all latitudes and longitudes. The diagonal of that grid then gives the
    value for each (lat, lon) pair.

    ``.diagonal()`` works on the first two axes, so this presumably expects
    every non-coordinate variable to be 2-D over lat/lon.
    NOTE(review): confirm before using on 3-D data.

    It may be useful to hand to NOAA for their data extraction.
    """
    grid = ds.sel(lat=df["lat"], lon=df["lon"], method="nearest")

    return pd.DataFrame(
        {
            name: grid[name].data.diagonal()
            for name in grid.variables
            if name not in grid.coords
        }
    )


def extract_row(
    row: pd.Series,
    ds: xr.Dataset,
    logger: logging.Logger = None,
    lat_col: str = "lat",
    lon_col: str = "lon",
):
    """
    Extract a single row of a dataframe from an xarray dataset.

    A +/-1 degree lat/lon window around the point is cut out of ``ds``. It is
    transformed with ``hycom_vars``, so the rolling statistics have
    neighbours. Then the grid cell nearest the point is taken.

    :param row: One row of the input table (as passed by ``DataFrame.apply``).
    :param ds: Dataset as returned by hycom_for_date.
    :param logger: Optional logger forwarded to hycom_vars.
    :param lat_col: Column holding the latitude.
    :param lon_col: Column holding the longitude, expected in -180 - 180.
    :returns: Dict mapping each non-coordinate variable name to its value.
    :raises Exception: with diagnostic context, chained to the underlying
        KeyError/ValueError, if selection or transformation fails.

    NOTE(review): windows near +/-180 longitude are not wrapped, so they are
    truncated there.
    """
    lat = row[lat_col]
    lon = row[lon_col]
    lat_slice = slice(lat - 1, lat + 1)
    lon_slice = slice(lon - 1, lon + 1)

    try:
        # slice a window to calculate rolling from
        subset = ds.sel(lat=lat_slice, lon=lon_slice)
    except KeyError as e:
        raise Exception(f"{row=}, {ds=}, {lat_slice=}, {lon_slice=}, {e=}") from e

    try:
        subset_vars = hycom_vars(subset, logger, quiet_logger=True)
    except ValueError as e:
        raise Exception(
            f"{row=}, {ds=}, {subset=} {lat_slice=}, {lon_slice=}, {e=}"
        ) from e

    point = subset_vars.sel(lat=lat, lon=lon, method="nearest")

    return {
        name: point[name].values
        for name in point.variables
        if name not in point.coords
    }


def extract_by_apply(
    ds: xr.Dataset,
    df: pd.DataFrame,
    logger: logging.Logger = None,
    lat_col: str = "lat",
    lon_col: str = "lon",
) -> pd.DataFrame:
    """
    Extract data from an xarray dataset for each row in a pandas table.

    Every row is passed to ``extract_row``. Its per-variable dict is expanded
    into columns (``result_type="expand"``).

    Longitudes are expected to be from -180 - 180 in the ``lon_col`` column,
    latitudes in the ``lat_col`` column (defaults 'lon' and 'lat').
    """

    def _extract(row):
        return extract_row(row, ds, logger, lat_col=lat_col, lon_col=lon_col)

    return df.apply(_extract, axis=1, result_type="expand")

