from __future__ import annotations

import argparse
from dataclasses import dataclass
from pathlib import Path

import numpy as np
import pandas as pd

from .utils import ROOT, set_seed


@dataclass()
class Params:
    """Tuning knobs for detecting A_t spikes ("calls").

    Both values are read by ``_detect_calls`` to build a centred rolling
    z-score over ``A_t``.
    """

    # Rolling-window length in rows; an odd value is recommended because the
    # window is centred on each row.
    win: int = 3
    # Z-score threshold: a row whose rolling z-score is >= this value is a call.
    z_thresh: float = 1.0


def _detect_calls(df: pd.DataFrame, p: Params) -> pd.Series:
    """セッション内で A_t のスパイク（呼称）を検出し、bool Series を返す。"""
    x = pd.to_numeric(df["A_t"], errors="coerce").fillna(0.0)
    ma = x.rolling(p.win, center=True, min_periods=1).mean()
    sd = x.rolling(p.win, center=True, min_periods=1).std().fillna(0.0)
    z = (x - ma) / (sd.replace(0.0, np.nan))
    is_call = (z >= p.z_thresh).fillna(False)
    return is_call


# Columns that must be present in the indices table.
_REQUIRED_COLUMNS = ("session_id", "turn_id", "Dt", "A_t")
# Column order of the events parquet (also used for the empty-file schema).
_EVENT_COLUMNS = ("session_id", "turn_id", "alpha_phi", "Dt_before", "Dt_after", "A_t")


def _session_events(sid, g: pd.DataFrame, p: Params) -> list[dict]:
    """Return the alpha_phi event records for a single session.

    ``g`` must contain one session's rows, already ordered by ``turn_id``.
    A row yields an event when all of the following hold:

    - ``_detect_calls`` flags it as a call.
    - It has a neighbour on both sides.
    - Both neighbouring ``Dt`` values are finite.
    - The preceding ``Dt`` is positive.
    - ``alpha = Dt_after / Dt_before`` lies in the open interval (0, 10).
    """
    g = g.reset_index(drop=True)
    is_call = _detect_calls(g, p).to_numpy()
    # Coerce Dt like _detect_calls coerces A_t. A non-numeric or missing Dt
    # becomes NaN and is skipped by the finiteness check instead of making
    # float() raise.
    dt = pd.to_numeric(g["Dt"], errors="coerce").to_numpy(dtype=float, na_value=np.nan)
    last = len(g) - 1

    records: list[dict] = []
    for i in np.flatnonzero(is_call):
        # A call needs a turn on each side to form the before/after ratio.
        if i == 0 or i == last:
            continue
        dt_before = float(dt[i - 1])
        dt_after = float(dt[i + 1])
        if not (np.isfinite(dt_before) and np.isfinite(dt_after)):
            continue
        if dt_before <= 0:
            continue
        alpha = dt_after / dt_before
        # Loose sanity bound that discards outlier ratios.
        if not 0.0 < alpha < 10.0:
            continue
        records.append(
            {
                "session_id": sid,
                "turn_id": int(g.at[i, "turn_id"]),
                "alpha_phi": float(alpha),
                "Dt_before": dt_before,
                "Dt_after": dt_after,
                "A_t": float(g.at[i, "A_t"]),
            }
        )
    return records


def _write_summary(summary_out: Path, n: int, med: float, p: Params) -> None:
    """Write the Markdown summary: event count, median alpha_phi, detection parameters."""
    med_line = (
        f"- median(alpha_phi): **{med:.4f}**" if np.isfinite(med) else "- median(alpha_phi): **NaN**"
    )
    lines = [
        "# Alpha-phi Summary",
        "",
        f"- events: **{n}**",
        med_line,
        "",
        "_params_: "
        + f"`win={p.win}`, `z_thresh={p.z_thresh}` (spike on A_t via rolling z-score)",
    ]
    summary_out.parent.mkdir(parents=True, exist_ok=True)
    summary_out.write_text("\n".join(lines), encoding="utf-8")


def estimate_alpha_phi(
    indices_path: Path = ROOT / "reports/indices.parquet",
    events_out: Path = ROOT / "reports/alpha_phi.parquet",
    summary_out: Path = ROOT / "reports/alpha_phi_summary.md",
    win: int = 3,
    z_thresh: float = 1.0,
) -> None:
    """Estimate alpha_phi around detected A_t calls and write the results.

    For every session, spikes in ``A_t`` ("calls") are detected with a
    rolling z-score (see ``_detect_calls``). For each call, alpha_phi is
    ``Dt`` of the next turn divided by ``Dt`` of the previous turn.

    Args:
        indices_path: Parquet file with at least ``session_id``, ``turn_id``,
            ``Dt`` and ``A_t`` columns.
        events_out: Destination parquet for per-call events. It is always
            written, even when there are no events, so downstream steps can
            rely on it existing.
        summary_out: Destination Markdown file with the event count and the
            median alpha_phi.
        win: Rolling-window length passed to ``Params``.
        z_thresh: Z-score threshold passed to ``Params``.

    Raises:
        KeyError: if a required column is missing from the indices table.
    """
    set_seed()
    p = Params(win=win, z_thresh=z_thresh)

    df = pd.read_parquet(indices_path)
    # Validate before sorting. Otherwise sort_values raises pandas' own
    # KeyError for a missing session_id/turn_id and this message never shows.
    for col in _REQUIRED_COLUMNS:
        if col not in df:
            raise KeyError(f"missing column in indices: {col}")
    df = df.sort_values(["session_id", "turn_id"]).reset_index(drop=True)

    # df is already ordered by (session_id, turn_id), and groupby keeps row
    # order within each group, so no per-group re-sort is needed.
    events: list[dict] = []
    for sid, g in df.groupby("session_id", sort=False):
        events.extend(_session_events(sid, g, p))

    # With explicit columns, an empty result still has the expected schema.
    ev = pd.DataFrame(events, columns=list(_EVENT_COLUMNS))
    events_out.parent.mkdir(parents=True, exist_ok=True)
    ev.to_parquet(events_out, index=False)

    n = int(len(ev))
    med = float(np.median(ev["alpha_phi"])) if n else float("nan")
    _write_summary(summary_out, n, med, p)
    print(f"alpha_phi -> {events_out} (n={n}); summary -> {summary_out}")


def main() -> None:
    """Command-line entry point: parse options and run ``estimate_alpha_phi``.

    ``--z`` maps to the ``z_thresh`` argument. The other options map to the
    parameter of the same name (``--indices`` -> ``indices_path``).
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--indices", type=Path, default=ROOT / "reports/indices.parquet")
    parser.add_argument("--events_out", type=Path, default=ROOT / "reports/alpha_phi.parquet")
    parser.add_argument("--summary_out", type=Path, default=ROOT / "reports/alpha_phi_summary.md")
    parser.add_argument("--win", type=int, default=3)
    parser.add_argument("--z", type=float, default=1.0)
    opts = parser.parse_args()

    estimate_alpha_phi(
        indices_path=opts.indices,
        events_out=opts.events_out,
        summary_out=opts.summary_out,
        win=opts.win,
        z_thresh=opts.z,
    )


# Script entry point. The module uses a relative import (``from .utils``), so
# it presumably has to be launched as part of its package (e.g. ``python -m``).
if __name__ == "__main__":
    main()