"""
Download daily HYCOM data to use for forecasting.
"""

from datetime import datetime, time
import logging
from pathlib import Path
from tempfile import TemporaryDirectory
from typing import Optional

from dask.distributed import Client, LocalCluster, progress
from dask.diagnostics import ProgressBar
import xarray as xr

from dagster import (
    AssetMaterialization,
    daily_schedule,
    EventMetadata,
    Field,
    Output,
    OutputDefinition,
    pipeline,
    repository,
    RetryRequested,
    solid,
)

from facet_shared.resources import LocalFsMode, S3FsMode, local_preset, do_s3_preset

import hycom_common as common


@solid(
    config_schema={
        "date": Field(
            str,
            description="YYYY-MM-DD string to specify what HYCOM data should be downloaded",
        )
    },
    output_defs=[OutputDefinition(str, description="YYYY-MM-DD string from context")],
)
def input_date(context) -> str:
    """Return the YYYY-MM-DD date string from the solid config.

    The value is validated up front, so a malformed date fails here with a
    clear message instead of being passed on to the download step.

    Raises:
        ValueError: if the configured date cannot be parsed as YYYY-MM-DD.
    """

    date = context.solid_config["date"]

    try:
        datetime.strptime(date, "%Y-%m-%d")
    except ValueError as err:
        raise ValueError(
            f"Solid config 'date' must be a YYYY-MM-DD string, got {date!r}"
        ) from err

    context.log.info(f"Found and returning date '{date}' in solid config.")

    return date


class WritableLogger:
    """Minimal file-like adapter that forwards written text to a logger.

    Used as the ``out`` stream of ``dask.diagnostics.ProgressBar`` so that
    progress output goes to the solid's log. Text passed to ``write`` is
    buffered and emitted as one ``info`` record on ``flush``.
    """

    logger: logging.Logger

    def __init__(self, logger: logging.Logger, prefix: Optional[str] = None):
        """
        Args:
            logger: Destination for buffered messages. Only its ``info``
                method is called, so logger-like objects such as Dagster's
                ``context.log`` also work.
            prefix: Optional label prepended (followed by a space) to every
                logged message.
        """
        self.logger = logger
        self.prefix = prefix
        # Per-instance buffer of text written since the last flush.
        self.msg = ""

    def write(self, msg: str) -> int:
        """Buffer ``msg`` until the next ``flush``; return its length like a file."""
        self.msg += msg
        return len(msg)

    def flush(self) -> None:
        """Log the buffered text (if any) and clear the buffer.

        Terminal progress bars write carriage returns and a trailing newline to
        redraw in place. Surrounding whitespace and control characters are
        stripped, and whitespace-only output is dropped so it does not produce
        blank or prefix-only log records.
        """
        text = self.msg.strip()
        self.msg = ""
        if not text:
            return
        if self.prefix:
            text = f"{self.prefix} {text}"
        self.logger.info(text)


def _compute_netcdf_with_progress(obj, path, log, label: str) -> None:
    """Write an xarray Dataset/DataArray to NetCDF at ``path``, logging progress under ``label``.

    NOTE(review): ``dask.diagnostics.ProgressBar`` hooks dask's local
    schedulers; with a ``dask.distributed`` Client active it may log nothing.
    Consider ``dask.distributed.progress`` instead — confirm.
    """
    delayed_write = obj.to_netcdf(path, compute=False)
    with ProgressBar(dt=10, out=WritableLogger(log, label)):
        delayed_write.compute()


@solid(
    output_defs=[
        OutputDefinition(
            str, "hycom_daily_netcdf_path", "Path to generated daily HYCOM NetCDF"
        )
    ],
    required_resource_keys={"fs"},
)
def hycom_ds_for_date(context, date: str) -> str:
    """Download, transform, and process HYCOM data for a given date.

    The process runs in two stages:

    1. The remote dataset from ``common.hycom_for_date`` is copied one
       variable at a time into temporary NetCDFs, because the HYCOM server
       limits simultaneous access.
    2. Those files are reopened locally and passed through
       ``common.hycom_vars``. The result is written via the ``fs`` resource
       to ``daily/{date}.nc``.

    Yields:
        An ``AssetMaterialization`` for ``["HYCOM", "Daily"]`` partitioned by
        ``date``, then an ``Output`` carrying the destination path.

    Raises:
        RetryRequested: when the HYCOM DAP server errors while opening or
            downloading the dataset.
    """
    netcdf_file = context.resources.fs.PutFile(f"daily/{date}.nc")

    with TemporaryDirectory() as temp_dir:
        temp_dir_path = Path(temp_dir)
        netcdfs = []

        # Download stage: several small workers.
        with LocalCluster(
            n_workers=4, memory_limit="1.5GB", processes=True, threads_per_worker=2
        ) as cluster, Client(cluster):

            try:
                ds = common.hycom_for_date(date, logger=context.log)
            except common.DAPServerError as e:
                raise RetryRequested(max_retries=3, seconds_to_wait=10 * 60) from e
            # TODO: decide whether to handle common.MissingDateError here; it is
            # not caught, so it currently fails the step.

            context.log.info(
                "Saving temporary NetCDF to work from, as the HYCOM server limits the amount of simultaneous access."
            )

            try:
                for var in ds.data_vars:
                    var_file = temp_dir_path / f"{var}.nc"
                    netcdfs.append(str(var_file))
                    context.log.info(
                        f"Saving variable ({var}) to NetCDF {var_file}: {ds[var]}"
                    )
                    _compute_netcdf_with_progress(
                        ds[var], var_file, context.log, str(var)
                    )
                    context.log.info(f"Saved variable ({var}) to NetCDF {var_file}")
            except OSError as e:
                # Transient server-side failures surface as OSError from the
                # NetCDF/DAP backend; retry those, re-raise anything else as-is.
                if "DAP server error" in str(e):
                    raise RetryRequested(max_retries=3, seconds_to_wait=30) from e
                raise

        context.log.info(f"Reopening from temporary NetCDFs ({', '.join(netcdfs)})")

        # Processing stage: one worker with a larger memory limit, reading the
        # local copies. ``netcdf_file`` yields the local path to write to.
        with LocalCluster(
            n_workers=1, memory_limit="7GB", processes=True, threads_per_worker=1
        ) as cluster, Client(cluster), netcdf_file as temp_path, xr.open_mfdataset(
            netcdfs, chunks={"lat": 500, "lon": 500}
        ) as raw_ds:

            context.log.info(
                f"Reopened dataset from temporary files {raw_ds}. Calculating derived data"
            )

            ds = common.hycom_vars(raw_ds, logger=context.log)

            context.log.info(
                f"Derived data graph generated. Calculating and saving dataset to {temp_path}"
            )

            _compute_netcdf_with_progress(ds, temp_path, context.log, "Computed output")

            ds_summary = ds._repr_html_()

    yield AssetMaterialization(
        asset_key=["HYCOM", "Daily"],
        partition=date,
        metadata={
            **netcdf_file.metadata(),
            "dataset_summary": EventMetadata.md(ds_summary),
            "date": date,
        },
    )
    yield Output(netcdf_file.dest_path, "hycom_daily_netcdf_path")


# Root directory for the local-filesystem preset. The daily schedule in this
# module reuses this preset's run config.
_LOCAL_HYCOM_ROOT = "/data/dagster/hycom/"
local_fs_preset = local_preset(_LOCAL_HYCOM_ROOT)


@pipeline(
    mode_defs=[LocalFsMode, S3FsMode],
    preset_defs=[local_fs_preset, do_s3_preset("dagster-test/hycom")],
    tags={"hycom/download": "hycom/download"},
)
def daily_hycom_download():
    """ Download, process, and transform HYCOM data for a date """

    # The configured date feeds straight into the download/processing solid.
    hycom_ds_for_date(input_date())


@daily_schedule(
    pipeline_name=daily_hycom_download.__name__,
    start_date=datetime(1992, 10, 2),
    execution_time=time(4, 0),
    execution_timezone="US/Eastern",
    # Run in the same mode the preset's run config was written for.
    mode=local_fs_preset.mode,
)
def daily_hycom_schedule(date):
    """ Download, process, and transform HYCOM data for the previous day """

    preset_config = local_fs_preset.run_config
    return {
        **preset_config,
        "solids": {
            # Keep any solid config the preset already provides; only the
            # input date varies per scheduled run.
            **preset_config.get("solids", {}),
            "input_date": {"config": {"date": date.strftime("%Y-%m-%d")}},
        },
    }


@repository()
def hycom_repository():
    # Definitions exposed by this module: the download pipeline and the daily
    # schedule that launches it.
    definitions = [
        daily_hycom_download,
        daily_hycom_schedule,
    ]
    return definitions
