"""
baselines/persistence/model.py
==============================
Persistence baseline for Task 1 (water quality forecasting).
Predicts the last observed value for all future timesteps.
This is the naive baseline against which Skill Score is computed.

Benchmark result (Task 1, paper Table 5):
  Site A RMSE: 1.31 | Site B: 1.08 | Site C: 2.94 | Fleet mean: 1.78
  Skill Score: 0.00 (reference)
"""

import numpy as np
from typing import Optional


class PersistenceBaseline:
    """
    Persistence (last-observation-carried-forward) baseline.
    Predicts the last observed value for all forecast horizons.
    """

    def predict(self, x: np.ndarray, n_steps: int = 288) -> np.ndarray:
        """
        Repeat the last observed timestep of each sequence ``n_steps`` times.

        Parameters
        ----------
        x       : (B, T_in, P) — input sequences; T_in must be >= 1
        n_steps : forecast horizon (default 288 = 72 h at 15-min res.)

        Returns
        -------
        (B, n_steps, P) — last observed value repeated for all horizons

        Raises
        ------
        ValueError
            If ``x`` is not 3-D, has no input timesteps, or ``n_steps`` is
            negative.
        """
        x = np.asarray(x)
        if x.ndim != 3:
            raise ValueError(
                f"x must have shape (B, T_in, P); got ndim={x.ndim}"
            )
        # With T_in == 0 the slice below is empty and np.repeat would return
        # (B, 0, P) regardless of n_steps, silently breaking the shape contract.
        if x.shape[1] == 0:
            raise ValueError("x must contain at least one input timestep (T_in >= 1)")
        if n_steps < 0:
            raise ValueError(f"n_steps must be non-negative; got {n_steps}")

        last = x[:, -1:, :]                     # (B, 1, P) — keep time axis
        return np.repeat(last, n_steps, axis=1)  # (B, n_steps, P)

    def evaluate(self, x: np.ndarray, y_true: np.ndarray) -> float:
        """
        Return the RMSE of channel 0 (DO) at the final forecast step.

        The horizon is taken from ``y_true.shape[1]``, so this is the 72 h
        DO RMSE only when ``y_true`` spans 288 steps at 15-min resolution.

        Parameters
        ----------
        x      : (B, T_in, P) — input sequences
        y_true : (B, T_out, P) — ground-truth targets; T_out must be >= 1

        Returns
        -------
        float — RMSE over the batch of channel 0 at the last target step

        Raises
        ------
        ValueError
            If ``y_true`` is not 3-D with at least one step, or if the batch
            sizes of ``x`` and ``y_true`` differ (which would otherwise
            broadcast silently into a meaningless score).
        """
        y_true = np.asarray(y_true)
        if y_true.ndim != 3 or y_true.shape[1] == 0:
            raise ValueError(
                f"y_true must have shape (B, T_out, P) with T_out >= 1; "
                f"got shape {y_true.shape}"
            )
        # Go through self.predict so subclasses overriding it are evaluated
        # with their own forecasts.
        preds = self.predict(x, n_steps=y_true.shape[1])
        do_pred = preds[:, -1, 0]
        do_true = y_true[:, -1, 0]
        if do_pred.shape != do_true.shape:
            raise ValueError(
                f"batch size mismatch: predictions {do_pred.shape} vs "
                f"targets {do_true.shape}"
            )
        return float(np.sqrt(np.mean((do_pred - do_true) ** 2)))
