from __future__ import annotations

import argparse
from pathlib import Path

# Use the non-interactive Agg backend so figures are saved reliably, even on headless machines with no display.
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt  # noqa: E402
import pandas as pd  # noqa: E402

from .utils import ROOT


def _plot_series(y, title: str, out_path: Path, xlabel: str = "Turn (ordered)", ylabel: str = "Score") -> None:
    """Draw ``y`` as a single line chart and save it as an image at ``out_path``.

    Args:
        y: Sequence of values to plot against their positional index.
        title: Figure title.
        out_path: Destination file; its parent directory is created if missing.
        xlabel: X-axis label.
        ylabel: Y-axis label.
    """
    out_path.parent.mkdir(parents=True, exist_ok=True)
    # Work on an explicit figure handle instead of pyplot's implicit
    # "current figure", so we close exactly the figure we created.
    fig, ax = plt.subplots()
    try:
        ax.plot(y)
        ax.set_xlabel(xlabel)
        ax.set_ylabel(ylabel)
        ax.set_title(title)
        fig.tight_layout()
        fig.savefig(out_path)
    finally:
        # Release the figure even if layout/saving raises; otherwise every
        # failed call would leave an open figure behind.
        plt.close(fig)


def plots(
    indices_path: Path = ROOT / "reports/indices.parquet",
    out_dir: Path = ROOT / "reports/figs",
) -> None:
    """Render one time-series PNG per metric from the indices parquet file.

    Rows are sorted by ``(session_id, turn_id)`` so the x-axis is a stable
    turn ordering. Each metric column is then coerced to numeric, has its
    gaps filled, and is plotted to ``out_dir / "timeseries_<metric>.png"``.

    Args:
        indices_path: Parquet file expected to contain the ``session_id``,
            ``turn_id``, ``E_score``, ``H_t`` and ``A_t`` columns.
        out_dir: Directory the PNGs are written to (created if missing).

    Raises:
        KeyError: If any required column is missing from the parquet file.
    """
    out_dir.mkdir(parents=True, exist_ok=True)

    df = pd.read_parquet(indices_path)

    sort_keys = ["session_id", "turn_id"]
    metrics = ("E_score", "H_t", "A_t")

    # Validate every required column before doing any work, so a missing
    # sort key fails with the same clear message as a missing metric.
    for col in (*sort_keys, *metrics):
        if col not in df.columns:
            raise KeyError(f"required column missing: {col}")

    # Stabilize row order: plot in (session_id, turn_id) order.
    df = df.sort_values(sort_keys).reset_index(drop=True)

    for col in metrics:
        # Non-numeric entries become NaN. Gaps are then forward-filled, and any
        # leading NaNs are back-filled. ``fillna(method=...)`` is deprecated
        # since pandas 2.1, so ``ffill``/``bfill`` are called directly.
        # NOTE(review): the fill runs over the concatenated sessions, so a gap
        # at the start of one session is filled from the previous session's
        # last value. Confirm this is intended.
        series = pd.to_numeric(df[col], errors="coerce").ffill().bfill()
        _plot_series(
            series.to_numpy(),
            f"Time series: {col}",
            out_dir / f"timeseries_{col}.png",
        )


def main() -> None:
    """CLI entry point: read ``--indices`` and ``--out_dir``, then render the plots."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--indices", type=Path, default=ROOT / "reports/indices.parquet")
    parser.add_argument("--out_dir", type=Path, default=ROOT / "reports/figs")
    ns = parser.parse_args()
    plots(indices_path=ns.indices, out_dir=ns.out_dir)


# Run the CLI only when this module is executed directly, not when it is imported.
# Because of the relative import of ROOT above, it presumably has to be run as a
# module (e.g. `python -m <package>.<module>`). TODO confirm the package path.
if __name__ == "__main__":
    main()
