"""
Dagster pipeline that takes ICCAT tagging data, splits it up by date,
enhances each date with HYCOM model data,
and then re-combines the results into a single CSV.
"""

from dataclasses import dataclass
import itertools
from typing import List

from dagster import (
    AssetKey,
    AssetMaterialization,
    asset_sensor,
    EventMetadata,
    Field,
    InputDefinition,
    Noneable,
    Output,
    OutputDefinition,
    pipeline,
    repository,
    RunRequest,
    solid,
)
from dagster.experimental import DynamicOutput, DynamicOutputDefinition
import dask.dataframe as dd
import pandas as pd

from facet_shared.resources import LocalFsMode, S3FsMode, local_preset, do_s3_preset
import hycom_common as hycom


@solid(config_schema={"iccat_csv_path": str})
def iccat_path(context):
    """ Return the ICCAT CSV path supplied through this solid's config """

    solid_config = context.solid_config
    csv_path = solid_config["iccat_csv_path"]
    return csv_path


@dataclass
class DatePath:
    """ Pairs a date with the path of the CSV that holds that date's rows """

    # Date as a string; the solids in this file build it with str() of a date.
    date: str
    # Path of the CSV file for that date, as handed between solids.
    path: str


@solid(
    input_defs=[InputDefinition("iccat_csv_path", str, "Path to ICCAT csv")],
    config_schema={
        "limit_dates": Field(
            Noneable(int),
            default_value=None,
            is_required=False,
            description="Limit the number of dates returned to be processed",
        )
    },
    output_defs=[
        DynamicOutputDefinition(DatePath, "date_path", "Paths & dates to split up data")
    ],
    required_resource_keys={"fs"},
)
def split_by_date(context, iccat_csv_path: str):
    """
    Split ICCAT data by date for dates after 1992-10-02 (start of HYCOM data).

    The source CSV is read through the `fs` resource, grouped by the calendar
    date of its `date` column, and each group is written to `split/<date>.csv`.

    'limit_dates' can be set in solid config to limit processing to the first X number of dates.
    When it is None (the default) every date is processed.

    For every date processed this yields an AssetMaterialization followed by a
    DynamicOutput carrying a DatePath (date string and written file path),
    keyed by the date with dashes replaced by underscores.
    """

    limit_dates = context.solid_config["limit_dates"]

    iccat_csv = context.resources.fs.GetFile(iccat_csv_path)

    with iccat_csv as temp_iccat_csv:
        df = pd.read_csv(temp_iccat_csv, parse_dates=["date"])

    # Filtering dates to match range of hycom
    df = df["1992-10-02" < df["date"]]

    # Group once so the log message can report the number of distinct dates
    # (not the number of rows) and the loop below can reuse the grouping.
    rows_by_date = df.groupby(df["date"].dt.date)
    total_dates = len(rows_by_date)

    # Compare against None explicitly: a limit of 0 is a real limit that
    # yields nothing, and must not be reported as "all dates".
    if limit_dates is None:
        context.log.info(f"Yielding all {total_dates} dates")
    else:
        context.log.info(
            f"Yielding {min(limit_dates, total_dates)}/{total_dates} dates"
        )

    # islice with a stop of None iterates over every group.
    for date, date_df in itertools.islice(rows_by_date, limit_dates):

        date_string = str(date)

        date_csv = context.resources.fs.PutFile(f"split/{date_string}.csv")

        with date_csv as temp_date_csv:
            date_df.to_csv(temp_date_csv, index=False)

        yield AssetMaterialization(
            asset_key=["ICCAT HYCOM", "Split by date"],
            partition=date_string,
            metadata={
                **date_csv.metadata(),
                "date": date_string,
                "rows": len(date_df),
                "First 5 rows": EventMetadata.md(date_df.head().to_markdown()),
            },
        )
        yield DynamicOutput(
            DatePath(date_string, date_csv.dest_path),
            # Dashes are swapped for underscores to form the mapping key.
            mapping_key=date_string.replace("-", "_"),
            output_name="date_path",
        )


@solid(
    required_resource_keys={"fs"},
    config_schema={
        "skip_cache": Field(
            bool,
            default_value=False,
            is_required=False,
            description="If True, download data for all dates, instead of using cache.",
        )
    },
)
def enhance_with_hycom(context, date_path: DatePath) -> DatePath:
    """
    Add HYCOM variables to the ICCAT rows of a single date.

    Unless `skip_cache` is set, a previously written `enhance/<date>.csv` is
    reused when it has the same number of rows as the input and its HYCOM
    variables are not all NA. In every other case (no file, different length,
    missing columns, all-NA variables) the data is extracted again.

    If the date is missing from the HYCOM dataset, the variables are filled with NAs.
    """

    fs = context.resources.fs
    log = context.log
    day = date_path.date
    output_path = f"enhance/{day}.csv"

    with fs.GetFile(date_path.path) as temp_date_csv:
        iccat_df = pd.read_csv(temp_date_csv)

    if not context.solid_config["skip_cache"]:
        # Everything in this try shares one FileNotFoundError handler, which
        # is how a missing cache file is detected.
        try:
            cached_path = fs.with_base_path(output_path)

            with fs.GetFile(cached_path) as temp_existing_csv:
                cached_df = pd.read_csv(temp_existing_csv)

            if len(cached_df) != len(iccat_df):
                log.warning(
                    f"There is an existing file at {cached_path} for {day}, but it doesn't match the length of {date_path.path}, so updating."
                )
            else:
                # A KeyError here means the cached file lacks HYCOM columns.
                try:
                    cache_is_all_na = (
                        cached_df[hycom.HYCOM_VARIABLES].isna().all().all()
                    )

                    if not cache_is_all_na:
                        log.warning(
                            f"{cached_path} already exists for {day} and is the same length as {date_path.path}. Skipping enhancing with HYCOM data."
                        )

                        # Cache hit: hand back the existing file and stop.
                        yield Output(DatePath(day, cached_path))
                        return

                    log.warning(
                        f"{cached_path} already exists for {day}, but the existing HYCOM variables ({', '.join(hycom.HYCOM_VARIABLES)}) are all NA likely due to a missing day previously. Retrying."
                    )
                except KeyError:
                    log.warning(
                        f"{cached_path} already exists for {day}, but is missing some required output columns, so updating."
                    )

        except FileNotFoundError:
            log.info(
                f"There is no existing file on {day}, continuing to extract HYCOM data"
            )
    else:
        log.info(
            "`skip_cache` is set to True on the solid config, ignoring data in cache and regenerating all data"
        )

    # Extra materialization metadata, only populated when the date is missing.
    metadata_extra = {}

    try:
        hycom_ds = hycom.hycom_for_date(
            day, logger=log, quiet_logger=True, chunk=False
        )
        extracted = hycom.extract_by_apply(hycom_ds, iccat_df, logger=log)
        iccat_df = pd.concat([iccat_df, extracted], axis=1)

    except hycom.MissingDateError:
        log.error(f"HYCOM data is missing for {day}. Replacing variables with NA.")

        for var in hycom.HYCOM_VARIABLES:
            iccat_df[var] = pd.NA

        metadata_extra["Errors"] = f"{day} is missing. Variables replaced with NA."
        metadata_extra["NA"] = "True"

    output_date_csv = fs.PutFile(output_path)

    with output_date_csv as temp_output_csv:
        iccat_df.to_csv(temp_output_csv, index=False)

    materialization_metadata = {
        **output_date_csv.metadata(),
        **metadata_extra,
        "date": day,
        "rows": len(iccat_df),
        "First 5 rows": EventMetadata.md(iccat_df.head().to_markdown()),
    }

    yield AssetMaterialization(
        asset_key=["ICCAT HYCOM", "Enhance dates with HYCOM"],
        partition=day,
        metadata=materialization_metadata,
    )
    yield Output(DatePath(day, output_date_csv.dest_path))


@solid(
    output_defs=[
        OutputDefinition(
            str,
            "iccat_hycom_csv_path",
            asset_key=AssetKey(["ICCAT_HYCOM", "Enhanced ICCAT data with HYCOM"]),
        )
    ],
    required_resource_keys={"fs"},
)
def combine_dates(context, enhanced_paths: List[DatePath]) -> str:
    """
    Combine the individual date CSVs into a single CSV.

    Each path in `enhanced_paths` is read through the `fs` resource, the
    frames are concatenated in the order given, and the result is written to
    `iccat_with_hycom.csv`. The written file's path is yielded as the
    `iccat_hycom_csv_path` output along with summary metadata.

    Raises ValueError if `enhanced_paths` is empty, since there is nothing to
    combine.
    """
    if not enhanced_paths:
        raise ValueError("No enhanced date CSVs were provided to combine.")

    daily_frames = []

    for enhanced_date_path in enhanced_paths:

        enhanced_csv = context.resources.fs.GetFile(enhanced_date_path.path)
        with enhanced_csv as temp_csv:
            daily_frames.append(pd.read_csv(temp_csv))

    # Concatenate once instead of calling DataFrame.append per file:
    # append was removed in pandas 2.0, and the AttributeError it raises there
    # was previously swallowed, leaving only the last day's rows in the output.
    df = pd.concat(daily_frames)

    output_csv = context.resources.fs.PutFile("iccat_with_hycom.csv")

    with output_csv as temp_output_csv:
        df.to_csv(temp_output_csv, index=False)

    yield Output(
        output_csv.dest_path,
        "iccat_hycom_csv_path",
        metadata={
            **output_csv.metadata(),
            "CSV row count": len(df),
            "First 5 rows": EventMetadata.md(df.head().to_markdown()),
            "Column statistics": EventMetadata.md(df.describe().to_markdown()),
            "Date count": len(df["date"].unique()),
        },
    )


# Preset built by facet_shared's `local_preset` for "/data/dagster/iccat_hycom/"
# (presumably the local base directory for the `fs` resource; confirm in facet_shared).
# It is registered on the pipeline and its run_config is reused by the sensor in this file.
local_fs_preset = local_preset("/data/dagster/iccat_hycom/")


@pipeline(
    mode_defs=[LocalFsMode, S3FsMode],
    preset_defs=[local_fs_preset, do_s3_preset("dagster-test/iccat-hycom")],
    tags={"hycom/download": "hycom/download"},
)
def iccat_hycom():
    """
    Enhance a source ICCAT csv with HYCOM data.

    The csv is split by date, each date is enhanced independently,
    and the enhanced dates are collected and combined into one csv.
    """

    per_date_paths = split_by_date(iccat_path())
    enhanced_dates = per_date_paths.map(enhance_with_hycom)
    combine_dates(enhanced_dates.collect())


@asset_sensor(
    asset_key=AssetKey(["ICCAT", "With pseudoabsences"]),
    pipeline_name=iccat_hycom.__name__,
)
def test_iccat_hycom_sensor(context, asset_event):
    """
    Request an `iccat_hycom` run for a materialization of the watched asset.

    The csv path is taken from the materialization's "Path" metadata entry
    and the run is limited to 150 dates.
    """
    materialization = asset_event.dagster_event.event_specific_data.materialization

    matching_paths = [
        entry.entry_data.path
        for entry in materialization.metadata_entries
        if entry.label == "Path"
    ]
    # The last "Path" entry wins; with no such entry the path is an empty string.
    path = matching_paths[-1] if matching_paths else ""

    solids_config = {
        "iccat_path": {"config": {"iccat_csv_path": path}},
        "split_by_date": {"config": {"limit_dates": 150}},
    }
    run_config = {**local_fs_preset.run_config, "solids": solids_config}
    tags = {
        "source_pipeline": asset_event.pipeline_name,
        "source_run_id": asset_event.run_id,
        "path": path,
    }

    yield RunRequest(run_key=context.cursor, run_config=run_config, tags=tags)


@repository()
def iccat_hycom_repository():
    """ Expose the ICCAT/HYCOM pipeline and its asset sensor as a repository """
    definitions = [iccat_hycom, test_iccat_hycom_sensor]
    return definitions
