"""
Shared functions for loading and transforming CMEMS data.

'logger' is an optional argument to some functions,
so that Dagster's logger can be supplied.
"""

import logging

import gsw
import numpy as np
import pandas as pd
import xarray as xr
import bs4
import lxml

class MissingDateError(KeyError):
    """Raised when the requested date is absent from the source dataset."""


class DAPServerError(OSError):
    """Raised when the DAP server reports a problem."""


# Names of the variables in the transformed dataset. They are module-level
# constants because other code compares against them.
SST = "sst"
SST_SD = "sst_sd"
SSH = "ssh"
SSH_SD = "ssh_sd"
SALINITY = "salinity"
WATER_U = "water_u"
WATER_V = "water_v"
EKE = "eke"
MLD = "mld"
ILD = "ild"
N2 = "n2"

# Every transformed variable name, each listed exactly once.
CMEMS_VARIABLES = [
    SST,
    SST_SD,
    SSH,
    SSH_SD,
    SALINITY,
    WATER_U,
    WATER_V,
    EKE,
    MLD,
    ILD,
    N2,
]

## from https://help.marine.copernicus.eu/en/articles/5182598-how-to-consume-the-opendap-api-and-cas-sso-using-python
#@lru_cache
def copernicusmarine_datastore(dataset, username, password):
    """Open a Copernicus Marine dataset over OPeNDAP as an xarray data store.

    Logs in through the CMEMS CAS single sign-on service, then opens the
    dataset on the "my" THREDDS host, falling back to the "nrt" host if the
    first attempt fails.

    Args:
        dataset: Dataset identifier appended to the THREDDS ``dodsC`` URL.
        username: CAS account name.
        password: CAS account password.

    Returns:
        An ``xr.backends.PydapDataStore`` wrapping the opened dataset.

    Raises:
        KeyError: If the CAS session holds no ``CASTGC`` cookie (presumably a
            failed login -- confirm against the CAS service behaviour).
        Exception: Whatever the "nrt" attempt raises when both hosts fail.
    """
    # Imported at call time, so importing this module does not import pydap.
    from pydap.client import open_url
    from pydap.cas.get_cookies import setup_session

    cas_url = 'https://cmems-cas.cls.fr/cas/login'
    session = setup_session(cas_url, username, password)
    # Re-register the CAS ticket cookie under its bare name, as in the
    # Copernicus help article this function is taken from (presumably so the
    # cookie is also sent to the data hosts -- confirm).
    session.cookies.set("CASTGC", session.cookies.get_dict()['CASTGC'])

    primary_host, fallback_host = 'my', 'nrt'
    url = f'https://{primary_host}.cmems-du.eu/thredds/dodsC/{dataset}'
    try:
        data_store = xr.backends.PydapDataStore(open_url(url, session=session))
    except Exception:
        # `except Exception` rather than a bare `except`, so KeyboardInterrupt
        # and SystemExit are not swallowed by the fallback.
        url = f'https://{fallback_host}.cmems-du.eu/thredds/dodsC/{dataset}'
        data_store = xr.backends.PydapDataStore(open_url(url, session=session))
    return data_store


def cmems_for_date(
    date: str,
    logger: logging.Logger = None,
    quiet_logger: bool = False,
    chunk: bool = True,
) -> xr.Dataset:
    """Return an xarray Dataset with CMEMS data for the given date (YYYY-MM-DD).

    Credentials are read from the ``COPERNICUS_USER`` and
    ``COPERNICUS_PASSWORD`` environment variables; they must never be
    committed to source control.

    Args:
        date: Date to select from the dataset's ``time`` coordinate.
        logger: Logger to report to; defaults to this module's logger.
        quiet_logger: When true, skip the final info message.
        chunk: When true, open the dataset with dask chunks of
            500 x 500 in latitude/longitude and 1 in depth.

    Returns:
        The dataset selected at ``date``, with ``uo_mean`` and ``vo_mean``
        reduced to the level nearest depth 0.

    Raises:
        RuntimeError: If either credential environment variable is unset.
        DAPServerError: If opening the dataset fails with a DAP server error
            or a NetCDF I/O failure.
        MissingDateError: If ``date`` cannot be selected from the dataset.
        OSError: Any other OSError raised while opening the dataset.
    """
    import os

    if logger is None:
        logger = logging.getLogger(__name__)

    try:
        username = os.environ["COPERNICUS_USER"]
        password = os.environ["COPERNICUS_PASSWORD"]
    except KeyError as e:
        logger.error("Copernicus credentials are not set in the environment")
        raise RuntimeError(
            "COPERNICUS_USER and COPERNICUS_PASSWORD must be set in the environment"
        ) from e

    data_store = copernicusmarine_datastore(
        'global-reanalysis-phy-001-031-grepv2-mnstd-daily',
        username,
        password,
    )

    if chunk:
        chunks = {"latitude": 500, "longitude": 500, "depth": 1}
    else:
        chunks = None

    # The *_std variables are dropped here; only the *_mean variables are
    # used by the transforms in this module.
    try:
        ds = xr.open_dataset(
            data_store,
            drop_variables=[
                "thetao_std",
                "so_std",
                "uo_std",
                "vo_std",
                "zos_std",
                "mlotst_std",
            ],
            chunks=chunks,
        )
    except OSError as e:
        if "DAP server error" in str(e):
            logger.error(f"Error accessing DAP server for {date}")
            raise DAPServerError(f"Error accessing DAP server for {date}") from e

        if "NetCDF: I/O failure" in str(e):
            logger.error(f"NetCDF I/O failure for {date}")
            raise DAPServerError(f"NetCDF I/O failure for {date}") from e

        # Unrecognised OSError: re-raise unchanged, keeping the traceback.
        raise

    try:
        ds = ds.sel(time=date)
    except KeyError as e:
        logger.error(
            f"Error selecting date ({date}) from ds: \n {e} \n {ds} \n {ds['time']}"
        )
        raise MissingDateError(f"Date {date} missing in source dataset") from e

    # Keep only the level nearest the surface for the current components.
    ds["uo_mean"] = ds["uo_mean"].sel(depth=0, method="nearest")
    ds["vo_mean"] = ds["vo_mean"].sel(depth=0, method="nearest")

    if not quiet_logger:
        logger.info(f"Opened and filtered dataset for date {date}: {ds}")

    return ds


def std(
    da: xr.DataArray, window: int = 3, logger: logging.Logger = None
) -> xr.DataArray:
    """Return the rolling standard deviation of ``da`` over latitude/longitude.

    A centred window of ``window`` cells is used along both the ``latitude``
    and ``longitude`` dimensions. ``long_name`` and ``standard_name`` of the
    result are derived from those of ``da``, so ``da`` must carry both
    attributes. ``logger`` is accepted but not used.
    """
    rolled = da.rolling(latitude=window, longitude=window, center=True).std()

    source_attrs = da.attrs
    rolled.attrs["long_name"] = "Standard deviation of " + source_attrs["long_name"]
    rolled.attrs["standard_name"] = source_attrs["standard_name"] + "_stdev"

    return rolled


def eke(water_u: xr.DataArray, water_v: xr.DataArray) -> xr.DataArray:
    """Return the eddy kinetic energy, (u^2 + v^2) / 2, of the two velocity arrays.

    The result is labelled with ``long_name`` and ``standard_name`` attributes.
    """
    speed_squared = water_u ** 2 + water_v ** 2
    energy = speed_squared / 2

    energy.attrs["long_name"] = "Eddy Kinetic Energy"
    energy.attrs["standard_name"] = "EKE"

    return energy
    
def cmems_eke(ds: xr.Dataset, logger: logging.Logger = None) -> xr.DataArray:
    """Calculate the EKE for the given dataset (output from cmems_for_date).

    Uses ``uo_mean`` and ``vo_mean`` at the level nearest depth 0. If that
    selection raises ValueError or KeyError (for example because the variable
    has already been reduced to a single level), the variable is used as is.
    ``logger`` is accepted but not used.
    """
    try:
        water_u = ds["uo_mean"].sel(depth=0, method="nearest")
    except (ValueError, KeyError):
        water_u = ds["uo_mean"]
    try:
        water_v = ds["vo_mean"].sel(depth=0, method="nearest")
    except (ValueError, KeyError):
        water_v = ds["vo_mean"]

    # Pass the components in the order eke() declares them (u, then v).
    return eke(water_u, water_v)


def ild(
    ds: xr.Dataset, delta_t: float = 0.5, logger: logging.Logger = None
) -> xr.DataArray:
    """Calculate the isothermal layer depth (ILD) for the given dataset.

    For each cell, returns the ``depth`` coordinate at which ``thetao_mean``
    is closest to its value at the level nearest depth 0 minus ``delta_t``.

    Args:
        ds: Dataset containing ``thetao_mean`` with a ``depth`` dimension
            (output from cmems_for_date).
        delta_t: Temperature drop from the surface value that defines the
            layer.
        logger: Accepted but not used.

    Returns:
        DataArray of depths, with descriptive attributes attached.
    """
    water_temp = ds["thetao_mean"]
    target_temp = water_temp.sel(depth=0, method="nearest") - delta_t
    distance_from_target = np.abs(water_temp - target_temp)
    da = distance_from_target.idxmin(dim="depth", skipna=True)

    # Report the threshold actually used instead of a hard-coded 0.5.
    da.attrs["long_name"] = f"Isothermal layer depth for {delta_t}deg C"
    da.attrs["units"] = "m"
    da.attrs["add_offset"] = 0.0
    # Identity scaling: a scale_factor of 0 would make any reader that applies
    # CF unpacking (value * scale_factor + add_offset) turn every depth into 0.
    da.attrs["scale_factor"] = 1.0
    # np.nan rather than np.NaN: the NaN alias was removed in NumPy 2.0.
    da.attrs["missing_value"] = np.nan

    return da


def n2_transform_raw_variables(ds):
    """Derive the gsw inputs needed for N2 from the raw dataset variables.

    Returns a tuple ``(z, p, SA, CT)``: the negated ``depth`` coordinate, and
    the results of ``gsw.p_from_z``, ``gsw.SA_from_SP`` and ``gsw.CT_from_t``
    computed from ``so_mean``, ``thetao_mean`` and the coordinates.
    (Adapted from a function by Alex K.)
    """
    latitude = ds["latitude"]
    longitude = ds["longitude"]

    # Depth is negated before being handed to gsw.p_from_z (presumably because
    # gsw expects height, which is negative below the surface -- confirm).
    height = -1 * ds["depth"]
    pressure = gsw.p_from_z(height, latitude)
    abs_salinity = gsw.SA_from_SP(ds["so_mean"], pressure, longitude, latitude)
    cons_temp = gsw.CT_from_t(abs_salinity, ds["thetao_mean"], pressure)
    return height, pressure, abs_salinity, cons_temp


def n2_build_inputs(ds):
    """Flatten the N2 inputs from 3D grids to 2D (point, depth) arrays.

    Returns ``(SA, CT, p, lat)``. SA, CT and p have one row per
    latitude/longitude point and one column per depth level; lat is a single
    column with the latitude of each point. Rows are ordered with latitude
    varying slowest.
    """
    z, pressure, abs_salinity, cons_temp = n2_transform_raw_variables(ds)
    # Assumes CT is laid out as (depth, latitude, longitude), as implied by
    # the shape indexing here and the transpose below -- TODO confirm.
    n_lat, n_lon = cons_temp.shape[1], cons_temp.shape[2]
    n_points = n_lon * n_lat
    n_levels = len(z)

    # Reshape the data for the gsw N2 function.
    ct_flat = cons_temp.data.transpose(1, 2, 0).reshape(n_points, n_levels)
    sa_flat = abs_salinity.data.transpose(1, 2, 0).reshape(n_points, n_levels)

    # One latitude value per flattened point, in the same row order as above.
    lat_column = (
        np.tile(ds["latitude"].values, (n_lon, 1)).T.ravel().reshape(n_points, 1)
    )
    # Repeat each latitude's pressure profile once per longitude.
    pressure_flat = np.repeat(pressure.values.T, repeats=n_lon, axis=0)
    return sa_flat, ct_flat, pressure_flat, lat_column


def map_n2(block: xr.Dataset, *args, logger: logging.Logger, **kwargs) -> xr.DataArray:
    """Compute the mean N2 over depths 0-200 for a single input block.

    Returns a (latitude, longitude) DataArray. Extra positional and keyword
    arguments, and ``logger``, are accepted but not used.
    """
    abs_salinity, cons_temp, pressure, lat_column = n2_build_inputs(block)

    # Only the first element of the Nsquared result is used.
    flat_n2 = gsw.stability.Nsquared(
        abs_salinity, cons_temp, pressure, lat=lat_column, axis=1
    )[0]

    latitudes = block["latitude"].values
    longitudes = block["longitude"].values
    depths = block["depth"].values
    # Values are placed at the midpoints between consecutive depth levels, so
    # the depth axis is one shorter than the input.
    mid_depths = (depths[1:] + depths[:-1]) / 2

    # Back from (point, depth) rows to a (depth, latitude, longitude) grid.
    grid = flat_n2.reshape(len(latitudes), len(longitudes), len(mid_depths))
    grid = grid.transpose(2, 0, 1)

    gridded = xr.DataArray(
        grid,
        coords=[mid_depths, latitudes, longitudes],
        dims=["depth", "latitude", "longitude"],
    )
    # Average over the depths between 0 and 200.
    return gridded.sel(depth=slice(0, 200)).mean("depth")


def n2(ds: xr.Dataset, logger: logging.Logger = None) -> xr.DataArray:
    """Calculate N2 (buoyancy frequency squared) block by block.

    The work is done by ``map_n2`` through ``Dataset.map_blocks``, so only
    one chunk of ``ds`` is processed at a time.

    Args:
        ds: Dataset to process (output from cmems_for_date).
        logger: Forwarded to ``map_n2``.

    Returns:
        A DataArray shaped like one depth level of ``thetao_mean``, with
        ``long_name`` and ``units`` attributes set.
    """
    # Template for the result of map_blocks: xarray/dask needs to know the
    # shape of the output in advance.
    template = ds["thetao_mean"].sel(depth=0, method="nearest")

    # The result carries neither a time nor a depth coordinate. drop_vars
    # replaces the deprecated DataArray.drop.
    template = template.drop_vars(["time", "depth"])

    # Map the calculation over the individual blocks.
    da = ds.map_blocks(map_n2, template=template, kwargs={"logger": logger})

    da.attrs["long_name"] = "Brunt-Vaisala Frequency"
    da.attrs["units"] = "Square of buoyancy frequency [radian^2/s^2]"

    return da


def _surface_or_unchanged(da: xr.DataArray) -> xr.DataArray:
    """Select the level nearest depth 0, or return ``da`` as is if that fails."""
    try:
        return da.sel(depth=0, method="nearest")
    except (ValueError, KeyError):
        return da


def cmems_vars(
    ds: xr.Dataset, logger: logging.Logger = None, quiet_logger: bool = False
) -> xr.Dataset:
    """Transform a CMEMS dataset opened via OpenDAP into our preferred format.

    Args:
        ds: Dataset to transform (output from cmems_for_date).
        logger: Logger to report to; defaults to this module's logger.
        quiet_logger: When true, skip the final info message.

    Returns:
        A Dataset whose variables are named by the module-level constants
        (SST, SST_SD, SSH, SSH_SD, SALINITY, WATER_U, WATER_V, EKE, MLD,
        ILD, N2).
    """
    if logger is None:
        logger = logging.getLogger(__name__)

    sst = _surface_or_unchanged(ds["thetao_mean"])
    sst_sd = std(sst, logger=logger)

    ssh = ds["zos_mean"]
    ssh_sd = std(ssh, logger=logger)

    salinity = _surface_or_unchanged(ds["so_mean"])
    water_u = _surface_or_unchanged(ds["uo_mean"])
    water_v = _surface_or_unchanged(ds["vo_mean"])
    mld = _surface_or_unchanged(ds["mlotst_mean"])

    # cmems_eke, not eke: eke() takes the two velocity arrays and no logger,
    # so calling it with the whole dataset raises TypeError.
    eke_da = cmems_eke(ds, logger=logger)
    ild_da = ild(ds, logger=logger)
    n2_da = n2(ds, logger=logger)

    transformed = xr.Dataset(
        {
            # dataset variable names as 'constants' as they are used elsewhere
            # for comparison
            SST: sst,
            SST_SD: sst_sd,
            SSH: ssh,
            SSH_SD: ssh_sd,
            SALINITY: salinity,
            WATER_U: water_u,
            WATER_V: water_v,
            EKE: eke_da,
            MLD: mld,
            ILD: ild_da,
            N2: n2_da,
        }
    )

    if not quiet_logger:
        logger.info(f"Transformed dataset into: {transformed}")

    return transformed


