# LICENSE HEADER MANAGED BY add-license-header
# Copyright (c) 2025 Shengyu Kang (Wuhan University)
# Licensed under the Apache License, Version 2.0
# http://www.apache.org/licenses/LICENSE-2.0
#

"""
Base module for CaMa-Flood-GPU using the TensorField / computed_tensor_field
helpers for concise tensor metadata.
"""
from __future__ import annotations

from functools import cached_property
from typing import ClassVar, List, Literal, Optional, Self, Tuple

import torch
from pydantic import Field, computed_field, model_validator

from cmfgpu.modules.abstract_module import (AbstractModule, TensorField,
                                            computed_tensor_field)
from cmfgpu.utils import find_indices_in_torch


def BaseField(
    description: str,
    shape: Tuple[str, ...] = ("num_catchments",),
    dtype: Literal["float", "int", "bool"] = "float",
    group_by: Optional[str] = "catchment_basin_id",
    save_idx: Optional[str] = "catchment_save_idx",
    save_coord: Optional[str] = "catchment_save_id",
    dim_coords: Optional[str] = "catchment_id",
    category: Literal["topology", "param", "init_state"] = "param",
    **kwargs
):
    """Declare a tensor field pre-filled with the base module's defaults.

    Thin wrapper around ``TensorField``. By default the field is a 1-D
    per-catchment tensor (``shape=("num_catchments",)``), grouped by
    ``catchment_basin_id``, saved through ``catchment_save_idx`` /
    ``catchment_save_id`` and indexed by ``catchment_id``.

    Args:
        description: Human-readable description of the field.
        shape: Names of the dimensions making up the tensor shape.
        dtype: Element type tag forwarded to ``TensorField``.
        group_by: Name of the field used for grouping, or ``None``.
        save_idx: Name of the field holding save indices, or ``None``.
        save_coord: Name of the field holding save coordinates, or ``None``.
        dim_coords: Name of the coordinate field for the first dimension,
            or ``None``.
        category: Field category tag forwarded to ``TensorField``.
        **kwargs: Any extra options (e.g. ``default``), passed through verbatim.

    Returns:
        Whatever ``TensorField`` builds from the assembled metadata.
    """
    metadata = dict(
        description=description,
        shape=shape,
        dtype=dtype,
        group_by=group_by,
        save_idx=save_idx,
        save_coord=save_coord,
        dim_coords=dim_coords,
        category=category,
    )
    # ``kwargs`` can never contain a named parameter above, so the merge is
    # collision-free and equivalent to passing everything as keywords.
    metadata.update(kwargs)
    return TensorField(**metadata)

def computed_base_field(
    description: str,
    shape: Tuple[str, ...] = ("num_catchments",),
    dtype: Literal["float", "int", "bool"] = "float",
    save_idx: Optional[str] = "catchment_save_idx",
    save_coord: Optional[str] = "catchment_save_id",
    dim_coords: Optional[str] = "catchment_id",
    category: Literal["topology", "derived_param", "state"] = "derived_param",
    **kwargs
):
    """Build a computed (derived) tensor-field decorator with base defaults.

    Thin wrapper around ``computed_tensor_field``. By default the computed
    tensor is per-catchment (``shape=("num_catchments",)``), is saved through
    ``catchment_save_idx`` / ``catchment_save_id`` and is indexed by
    ``catchment_id``. Unlike ``BaseField``, no ``group_by`` default is
    supplied; pass it via ``**kwargs`` if needed.

    Args:
        description: Human-readable description of the computed field.
        shape: Names of the dimensions making up the tensor shape.
        dtype: Element type tag forwarded to ``computed_tensor_field``.
        save_idx: Name of the field holding save indices, or ``None``.
        save_coord: Name of the field holding save coordinates, or ``None``.
        dim_coords: Name of the coordinate field for the first dimension,
            or ``None``.
        category: Field category tag forwarded to ``computed_tensor_field``.
        **kwargs: Extra options, passed through verbatim.

    Returns:
        Whatever ``computed_tensor_field`` returns (used as a decorator in
        this file).
    """
    options = dict(
        description=description,
        shape=shape,
        dtype=dtype,
        save_idx=save_idx,
        save_coord=save_coord,
        dim_coords=dim_coords,
        category=category,
    )
    # Named parameters are never duplicated inside ``kwargs``, so this merge
    # cannot overwrite any of the explicit options above.
    options.update(kwargs)
    return computed_tensor_field(**options)

class BaseModule(AbstractModule):
    """Core hydrodynamic module with the fundamental river/catchment fields.

    Declares the network topology (catchment and downstream IDs), the
    river-channel geometry, catchment properties, Manning roughness
    parameters, the flood-depth lookup table, the initial-state tensors and a
    set of zero-initialised intermediate state buffers. Unless a field states
    otherwise it is a per-catchment tensor of shape ``("num_catchments",)``
    (the ``BaseField`` / ``computed_base_field`` default).

    ``self.device``, ``self.opened_modules`` and ``self._is_batched`` are
    used below but are provided by ``AbstractModule`` (not defined here).
    """

    # --------------------------------------------------------------------- #
    # Metadata
    # --------------------------------------------------------------------- #
    module_name: ClassVar[str] = "base"
    description: ClassVar[str] = (
        "Core hydrodynamic module with fundamental river and catchment variables"
    )
    dependencies: ClassVar[List[str]] = []

    # --------------------------------------------------------------------- #
    # Scalars (dimensions & constants)
    # --------------------------------------------------------------------- #
    gravity: float = Field(
        default=9.8,
        description="Gravitational acceleration constant (m/s²)",
        gt=0.0,
    )

    # --------------------------------------------------------------------- #
    # Network topology
    # --------------------------------------------------------------------- #
    catchment_id: torch.Tensor = BaseField(
        description=(
            "Unique ID of each catchment (e.g., generated by "
            "catchment_id = np.ravel_multi_index((catchment_x, catchment_y), self.map_shape))"
        ),
        dtype="int",
        category="topology",
    )

    downstream_id: torch.Tensor = BaseField(
        description="ID of immediate downstream catchment (points to self at river mouth)",
        dtype="int",
        category="topology",
    )

    # --------------------------------------------------------------------- #
    # River-channel geometry
    # --------------------------------------------------------------------- #
    river_width: torch.Tensor = BaseField(
        description="River-channel width (m)",
        category="param",
        shape=("num_catchments",),
    )

    river_length: torch.Tensor = BaseField(
        description="River-channel length (m)",
        category="param",
        shape=("num_catchments",),
    )

    river_height: torch.Tensor = BaseField(
        description="Bankfull depth of river channel (m)",
        category="param",
        shape=("num_catchments",),
    )

    # --------------------------------------------------------------------- #
    # Catchment properties
    # --------------------------------------------------------------------- #
    catchment_elevation: torch.Tensor = BaseField(
        description="Mean ground elevation (m a.s.l.)",
        category="param",
    )

    catchment_area: torch.Tensor = BaseField(
        description="Surface area of catchment (m²)",
        category="param",
    )

    downstream_distance: torch.Tensor = BaseField(
        description="Downstream distance (m)",
        category="param",
    )

    # Catchment-type flags
    is_river_mouth: torch.Tensor = BaseField(
        description="Boolean mask for river-mouth catchments",
        dtype="bool",
        category="topology",
    )

    # Per-levee (not per-catchment) array; optional, so levee-free runs may
    # omit it entirely.
    levee_catchment_id: Optional[torch.Tensor] = BaseField(
        description="Catchment ID for each levee",
        dtype="int",
        group_by="levee_basin_id",
        dim_coords=None,
        shape=("num_levees",),
        default=None,
        category="topology",
    )

    # Output-control mask; when omitted, every catchment is saved.
    catchment_save_mask: Optional[torch.Tensor] = BaseField(
        description="Boolean mask of catchments for which output will be saved",
        dtype="bool",
        default=None,
        category="topology",
    )

    # --------------------------------------------------------------------- #
    # hydrodynamic parameters
    # --------------------------------------------------------------------- #
    river_manning: torch.Tensor = BaseField(
        description="Manning roughness for rivers (-)",
        default=0.03,
        category="param",
    )

    flood_manning: torch.Tensor = BaseField(
        description="Manning roughness for floodplains (-)",
        default=0.1,
        category="param",
    )

    # --------------------------------------------------------------------- #
    # Lookup tables (dependent on num_flood_levels)
    # --------------------------------------------------------------------- #
    flood_depth_table: torch.Tensor = BaseField(
        description="Lookup table: flood depth vs. fraction of catchment area flooded (m)",
        shape=("num_catchments", "num_flood_levels"),
        category="param",
    )

    # --------------------------------------------------------------------- #
    # State variables (initialised to 0 where not supplied)
    # --------------------------------------------------------------------- #
    river_storage: torch.Tensor = BaseField(
        description="Current water volume in river channels, including any above bankfull depth (m³).",
        default=0,
        category="init_state",
    )

    flood_storage: torch.Tensor = BaseField(
        description="Current water volume stored on floodplains (m³)",
        default=0,
        category="init_state",
    )

    protected_storage: torch.Tensor = BaseField(
        description="Current water volume stored in protected areas (m³)",
        default=0,
        category="init_state",
    )

    protected_depth: torch.Tensor = BaseField(
        description="Current water depth on the protected side relative to river bed (m)",
        default=0,
        category="init_state",
    )

    river_depth: torch.Tensor = BaseField(
        description="Current water depth in rivers (m)",
        default=0,
        category="init_state",
    )

    flood_depth: torch.Tensor = BaseField(
        description="Current water depth on floodplains above river bankfull (m)",
        default=0,
        category="init_state",
    )

    river_outflow: torch.Tensor = BaseField(
        description="Volumetric flow rate out of rivers (m³ s⁻¹)",
        default=0,
        category="init_state",
    )

    flood_outflow: torch.Tensor = BaseField(
        description="Volumetric flow rate out of floodplains (m³ s⁻¹)",
        default=0,
        category="init_state",
    )

    river_cross_section_depth: torch.Tensor = BaseField(
        description="Effective water depth used in river-flow calculations (m)",
        default=0,
        category="init_state",
    )

    flood_cross_section_depth: torch.Tensor = BaseField(
        description="Effective water depth used in flood-flow calculations (m)",
        default=0,
        category="init_state",
    )

    flood_cross_section_area: torch.Tensor = BaseField(
        description="Cross-sectional flow area on floodplains (m²)",
        default=0,
        category="init_state",
    )

    # ------------------------------------------------------------------ #
    # Computed scalar dimensions
    # ------------------------------------------------------------------ #
    @computed_field(
        description="Total number of catchments."
    )
    @cached_property
    def num_catchments(self) -> int:
        # The catchment count is taken from the length of ``catchment_area``.
        return self.catchment_area.shape[0]

    @computed_field(
        description="Number of catchments that will have their output saved."
    )
    @cached_property
    def num_saved_catchments(self) -> int:
        # ``catchment_save_idx`` is None when the save mask selects nothing;
        # report 0 in that case instead of failing on ``len(None)``.
        save_idx = self.catchment_save_idx
        return 0 if save_idx is None else len(save_idx)

    @computed_field(
        description="Number of flood levels represented in the lookup tables."
    )
    @cached_property
    def num_flood_levels(self) -> int:
        # Second axis of the (num_catchments, num_flood_levels) lookup table.
        return self.flood_depth_table.shape[1]

    @computed_field(description="Total number of levees")
    @cached_property
    def num_levees(self) -> int:
        # No levee array supplied means the run has no levees.
        if self.levee_catchment_id is None:
            return 0
        return self.levee_catchment_id.shape[0]

    # ------------------------------------------------------------------ #
    # Computed tensor fields
    # ------------------------------------------------------------------ #
    @computed_base_field(
        description="Indices of immediate downstream catchments",
        dtype="int",
        category="topology",
    )
    @cached_property
    def downstream_idx(self) -> torch.Tensor:
        # Translate downstream IDs into positions within ``catchment_id``;
        # ``validate_downstream_idx`` rejects any result outside
        # [0, num_catchments).
        return find_indices_in_torch(self.downstream_id, self.catchment_id)

    @computed_base_field(
        description="Indices of catchments for which output will be saved",
        shape=("num_saved_catchments",),
        dtype="int",
        category="topology",
    )
    @cached_property
    def catchment_save_idx(self) -> torch.Tensor:
        # Returns every position when no save mask is given, the positions of
        # the True entries otherwise, and None (not an empty tensor) when the
        # mask selects no catchment.
        if self.catchment_save_mask is None:
            return torch.arange(self.num_catchments, dtype=torch.int64, device=self.device)
        save_idx = (
            torch.nonzero(self.catchment_save_mask, as_tuple=False)
            .squeeze(-1)
            .to(self.device)
        )
        return save_idx if save_idx.numel() > 0 else None

    @computed_base_field(
        description="Catchment IDs for which output will be saved",
        shape=("num_saved_catchments",),
        dtype="int",
        category="topology",
    )
    @cached_property
    def catchment_save_id(self) -> torch.Tensor:
        """
        Returns the catchment IDs for which output will be saved.

        Mirrors ``catchment_save_idx``: all IDs when no save mask is set, and
        ``None`` when the mask selects no catchment.
        """
        if self.catchment_save_mask is None:
            return self.catchment_id
        save_idx = self.catchment_save_idx
        if save_idx is None:
            # Indexing with None would insert a new leading axis and return
            # *every* ID, the opposite of an empty selection.
            return None
        return self.catchment_id[save_idx]

    @computed_base_field(
        description="Boolean mask for catchments governed by levee physics",
        dtype="bool",
        category="topology",
    )
    @cached_property
    def is_levee(self) -> torch.Tensor:
        mask = torch.zeros(self.num_catchments, dtype=torch.bool, device=self.device)
        # All-False unless the levee module is enabled and levees are given.
        if "levee" not in self.opened_modules or self.levee_catchment_id is None:
            return mask

        indices = find_indices_in_torch(self.levee_catchment_id, self.catchment_id)
        # Negative indices are skipped; presumably find_indices_in_torch marks
        # IDs absent from catchment_id that way — confirm.
        mask[indices[indices >= 0]] = True
        return mask

    @computed_base_field(
        description="Total water storage per catchment (m³)",
        category="state",
    )
    @cached_property
    def total_storage(self) -> torch.Tensor:
        # NOTE(review): cached_property evaluates this once on first access,
        # so it is a snapshot of river_storage + flood_storage unless it is
        # updated in place elsewhere — confirm.
        return self.river_storage + self.flood_storage

    # ---------------- Hidden / intermediate states ------------------- #
    # Each buffer below starts as zeros shaped like ``river_outflow``; the
    # code that fills them lives outside this file.
    @computed_base_field(
        description="Total outflow via all bifurcation paths (m³ s⁻¹)",
        category="state",
    )
    @cached_property
    def global_bifurcation_outflow(self) -> torch.Tensor:
        return torch.zeros_like(self.river_outflow)

    @computed_base_field(
        description="Levee surface elevation (m a.s.l.)",
        category="state",
    )
    @cached_property
    def global_levee_surface_elevation(self) -> torch.Tensor:
        return torch.zeros_like(self.river_outflow)

    @computed_base_field(
        description=("Total outgoing storage from each catchment (m³). "
                     "Can not be saved, as it is a temporary state."),
        save_idx=None,
        category="state",
    )
    @cached_property
    def outgoing_storage(self) -> torch.Tensor:
        return torch.zeros_like(self.river_outflow)

    @computed_base_field(
        description="Water-surface elevation (m a.s.l.)",
        category="state",
    )
    @cached_property
    def water_surface_elevation(self) -> torch.Tensor:
        return torch.zeros_like(self.river_outflow)

    @computed_base_field(
        description="Protected water-surface elevation (m a.s.l.)",
        category="state",
    )
    @cached_property
    def protected_water_surface_elevation(self) -> torch.Tensor:
        return torch.zeros_like(self.river_outflow)

    @computed_base_field(
        description="Maximum flow-rate limit per catchment (m³ s⁻¹)",
        category="state",
    )
    @cached_property
    def limit_rate(self) -> torch.Tensor:
        return torch.zeros_like(self.river_outflow)

    @computed_base_field(
        description="Total inflow into river channels (m³ s⁻¹)",
        category="state",
    )
    @cached_property
    def river_inflow(self) -> torch.Tensor:
        return torch.zeros_like(self.river_outflow)

    @computed_base_field(
        description="Total flooded area (m²)",
        category="state",
    )
    @cached_property
    def flood_area(self) -> torch.Tensor:
        return torch.zeros_like(self.river_outflow)

    @computed_base_field(
        description="Fraction of catchment area that is flooded (-)",
        category="state",
    )
    @cached_property
    def flood_fraction(self) -> torch.Tensor:
        return torch.zeros_like(self.river_outflow)

    @computed_base_field(
        description="Total inflow to floodplains (m³ s⁻¹)",
        category="state",
    )
    @cached_property
    def flood_inflow(self) -> torch.Tensor:
        return torch.zeros_like(self.river_outflow)

    @computed_base_field(
        description="Total outflow from catchment (river + flood) (m³ s⁻¹)",
        category="state",
    )
    @cached_property
    def total_outflow(self) -> torch.Tensor:
        return torch.zeros_like(self.river_outflow)

    # ------------------------------------------------------------------ #
    # Post-init validation
    # ------------------------------------------------------------------ #
    @model_validator(mode="after")
    def validate_downstream_idx(self) -> Self:
        # Every downstream ID must resolve to a position inside the network.
        if not torch.all(
            (self.downstream_idx >= 0) & (self.downstream_idx < self.num_catchments)
        ):
            raise ValueError("downstream_idx contains invalid indices")
        return self

    @model_validator(mode="after")
    def validate_catchment_id(self) -> Self:
        # Duplicate IDs would make ID -> index lookups ambiguous.
        if torch.unique(self.catchment_id).numel() != self.catchment_id.numel():
            raise ValueError("catchment_id must be unique")
        return self

    @model_validator(mode="after")
    def validate_num_catchments(self) -> Self:
        if self.num_catchments <= 0:
            raise ValueError("num_catchments must be positive")
        return self

    @model_validator(mode="after")
    def validate_num_flood_levels(self) -> Self:
        if self.num_flood_levels < 1:
            raise ValueError("num_flood_levels must be at least 1")
        return self

    @model_validator(mode="after")
    def validate_levee_catchment_id(self) -> Self:
        # torch.any on an empty tensor is False, so no separate emptiness
        # check is needed.
        if self.levee_catchment_id is not None and torch.any(self.levee_catchment_id < 0):
            raise ValueError("levee_catchment_id contains negative values")
        return self

    @model_validator(mode="after")
    def validate_is_river_mouth(self) -> Self:
        # A river mouth must be its own downstream catchment.
        if not torch.all(
            self.catchment_id[self.is_river_mouth] == self.downstream_id[self.is_river_mouth]
        ):
            raise ValueError("is_river_mouth must point to self in downstream_id")
        return self

    @model_validator(mode="after")
    def validate_flood_depth_table_monotonicity(self) -> Self:
        # Depths must be non-decreasing across flood levels (columns).
        if self.num_flood_levels > 1:
            diffs = torch.diff(self.flood_depth_table, dim=1)
            if not torch.all(diffs >= 0):
                raise ValueError("flood_depth_table must be monotonically increasing along the columns (flood levels)")
        return self

    # ------------------------------------------------------------------ #
    # Batched flags
    # ------------------------------------------------------------------ #
    # One flag per parameter tensor, each delegating to
    # ``AbstractModule._is_batched`` (defined outside this file; presumably
    # it detects an extra batch dimension — confirm).
    @computed_field
    @cached_property
    def batched_river_manning(self) -> bool:
        return self._is_batched(self.river_manning)

    @computed_field
    @cached_property
    def batched_flood_manning(self) -> bool:
        return self._is_batched(self.flood_manning)

    @computed_field
    @cached_property
    def batched_river_width(self) -> bool:
        return self._is_batched(self.river_width)

    @computed_field
    @cached_property
    def batched_river_length(self) -> bool:
        return self._is_batched(self.river_length)

    @computed_field
    @cached_property
    def batched_river_height(self) -> bool:
        return self._is_batched(self.river_height)

    @computed_field
    @cached_property
    def batched_catchment_elevation(self) -> bool:
        return self._is_batched(self.catchment_elevation)

    @computed_field
    @cached_property
    def batched_downstream_distance(self) -> bool:
        return self._is_batched(self.downstream_distance)

    @computed_field
    @cached_property
    def batched_flood_depth_table(self) -> bool:
        return self._is_batched(self.flood_depth_table)

    @computed_field
    @cached_property
    def batched_catchment_area(self) -> bool:
        return self._is_batched(self.catchment_area)