from __future__ import annotations

import argparse
from datetime import datetime, timezone
from pathlib import Path

import numpy as np
import pandas as pd

from .utils import ROOT, read_jsonl


def _parse_iso(s: str) -> float | None:
    """Parse an ISO-8601 timestamp string into POSIX seconds.

    Accepted examples: ``2025-08-01T13:45:12Z`` and ``2025-08-01T13:45:12+00:00``.
    Both denote the same instant and yield the same value. A trailing ``Z`` is
    rewritten to ``+00:00``, so it is parsed as UTC instead of being dropped,
    which would make the value depend on the host's local timezone. Timestamps
    with no offset at all are likewise assumed to be UTC.

    Returns None for non-string input or text that is not valid ISO-8601.
    """
    if not isinstance(s, str):
        return None
    text = s.strip()
    if text[-1:] in ("Z", "z"):
        # datetime.fromisoformat() only accepts "Z" on Python 3.11+; normalise it.
        text = text[:-1] + "+00:00"
    try:
        parsed = datetime.fromisoformat(text)
    except ValueError:
        return None
    if parsed.tzinfo is None:
        # Pin naive timestamps to UTC so the result is host-independent.
        parsed = parsed.replace(tzinfo=timezone.utc)
    return parsed.timestamp()


def _load_logs_timestamps(path: Path) -> pd.DataFrame | None:
    """Load per-turn wall-clock timestamps from a JSONL log file.

    Each record is expected to carry ``session_id``, ``turn_id`` and an
    ISO-8601 ``timestamp`` (the project's log template). Records are skipped
    when any of these fields is missing, when ``turn_id`` is not an integer,
    or when ``timestamp`` cannot be parsed by ``_parse_iso``.

    Returns a DataFrame with columns ``session_id``, ``turn_id`` (int) and
    ``t_unix`` (float POSIX seconds). It is de-duplicated on
    (session_id, turn_id), keeping the first occurrence, and sorted by those
    keys. Returns None when the file does not exist or yields no usable rows.
    """
    if not path.exists():
        return None
    rows = []
    # NOTE(review): assumes read_jsonl yields dict-like records (uses .get).
    for rec in read_jsonl(path):
        sid = rec.get("session_id")
        tid = rec.get("turn_id")
        ts = rec.get("timestamp")
        if sid is None or tid is None or ts is None:
            continue
        try:
            turn = int(tid)
        except (TypeError, ValueError):
            # Skip malformed turn ids like any other bad record rather than
            # aborting the whole report on one line.
            continue
        t = _parse_iso(ts)
        if t is None:
            continue
        rows.append({"session_id": sid, "turn_id": turn, "t_unix": float(t)})
    if not rows:
        return None
    return (
        pd.DataFrame(rows)
        .drop_duplicates(["session_id", "turn_id"])
        .sort_values(["session_id", "turn_id"])
        .reset_index(drop=True)
    )


def _estimate_tau_for_segment(dt: np.ndarray, delta: np.ndarray) -> float | None:
    """Fit ``log(Dt) = a + b * Δt`` by least squares and return ``tau = -1/b``.

    Args:
        dt: ``Dt`` values. Only finite values greater than 1e-9 are used,
            because the log must be defined.
        delta: ``Δt`` values aligned with ``dt``. Only finite, strictly
            positive values are used.

    Returns:
        ``-1/b`` as a float, or None when any of the following holds:
        - fewer than 3 usable (Dt, Δt) pairs remain;
        - all usable Δt values are identical, so the slope is unidentifiable.
          This happens, for example, with a constant Δt = 1.0 fallback, where
          ``polyfit`` would otherwise return a meaningless minimum-norm "slope";
        - the fitted slope is not negative (no decay).
    """
    eps = 1e-9
    dt = np.asarray(dt, dtype=float)
    delta = np.asarray(delta, dtype=float)
    ok = np.isfinite(dt) & np.isfinite(delta) & (dt > eps) & (delta > 0)
    if ok.sum() < 3:
        return None
    x = delta[ok]
    if x.max() == x.min():
        # Degenerate design: a line through a single x value has no slope.
        return None
    y = np.log(dt[ok])
    # Ordinary least-squares line; polyfit returns coefficients highest degree first.
    b, _a = np.polyfit(x, y, 1)
    if b >= 0:  # growing or flat: no drift toward zero
        return None
    return float(-1.0 / b)


def estimate_tau_drift(
    indices_path: Path = ROOT / "reports/indices.parquet",
    events_path: Path = ROOT / "reports/alpha_phi.parquet",
    logs_path: Path = ROOT / "data/logs.jsonl",
    out_txt: Path = ROOT / "reports/tau_drift.txt",
    out_md: Path = ROOT / "reports/tau_drift.md",
) -> None:
    """Estimate a per-session drift time constant τ and write a summary report.

    Steps:
    1. Read the per-turn indices table, sorted by (session_id, turn_id). The
       table needs ``session_id``, ``turn_id`` and ``Dt``; ``A_t`` is used only
       as a fallback (see step 3).
    2. ``Delta``: the wall-clock gap to the previous turn of the same session,
       taken from ``logs_path`` timestamps when they are available. Otherwise
       every turn gets ``Delta = 1.0``, which carries no timing information,
       so no τ can then be fitted.
    3. ``call_flag``: 1 for turns whose (session_id, turn_id) appears in
       ``events_path``. If that file is missing or empty, 1 instead marks turns
       whose ``A_t`` exceeds the median ``A_t``. Turns with flag 0 are "silent".
    4. For each session with at least 3 silent turns, fit τ on those turns with
       ``_estimate_tau_for_segment``. Finite estimates in (0, 1e9) are kept.
    5. Write the median τ to ``out_txt`` and a Markdown summary to ``out_md``,
       or a NaN / "no valid" note when no session qualifies.
    """
    idx = (
        pd.read_parquet(indices_path)
        .sort_values(["session_id", "turn_id"])
        .reset_index(drop=True)
    )
    idx["Delta"] = _turn_gaps(idx, logs_path)
    idx["call_flag"] = _call_flags(idx, events_path)
    taus = _session_taus(idx)
    _write_tau_report(taus, out_txt, out_md)
    print(f"tau_drift -> {out_txt}; {out_md}")


def _turn_gaps(idx: pd.DataFrame, logs_path: Path) -> pd.Series:
    """Return per-turn time gaps (seconds) within each session, aligned to ``idx``."""
    tsdf = _load_logs_timestamps(logs_path)
    if tsdf is None:
        # Fallback: no timestamps available, use a constant gap of 1.0.
        return pd.Series(1.0, index=idx.index)
    # Merge only the keys so an existing ``t_unix`` column in idx cannot collide.
    # A left merge keeps idx's row order, and tsdf is unique per key, so the
    # row count is unchanged.
    merged = idx[["session_id", "turn_id"]].merge(tsdf, on=["session_id", "turn_id"], how="left")
    gaps = merged.groupby("session_id")["t_unix"].diff()
    return pd.Series(gaps.to_numpy(), index=idx.index)


def _call_flags(idx: pd.DataFrame, events_path: Path) -> pd.Series:
    """Return a 0/1 per-turn flag: 1 = call turn, 0 = silent turn."""
    calls = None
    if events_path.exists():
        ev = pd.read_parquet(events_path)[["session_id", "turn_id"]].dropna()
        calls = {(sid, int(tid)) for sid, tid in zip(ev["session_id"], ev["turn_id"])}

    if calls:
        flags = [int((sid, int(tid)) in calls) for sid, tid in zip(idx["session_id"], idx["turn_id"])]
        return pd.Series(flags, index=idx.index, dtype=int)

    # No call events: treat turns with A_t above the median as calls.
    if "A_t" not in idx.columns:
        return pd.Series(0, index=idx.index, dtype=int)
    a_t = pd.to_numeric(idx["A_t"], errors="coerce")
    if a_t.notna().sum() == 0:
        # No usable A_t values: nothing can be flagged, so every turn is silent.
        return pd.Series(0, index=idx.index, dtype=int)
    med_a = float(a_t.median())  # Series.median skips NaN
    return (a_t > med_a).astype(int)


def _session_taus(idx: pd.DataFrame) -> list[dict]:
    """Fit τ per session on its silent turns; return ``[{session_id, tau_drift}, ...]``."""
    taus = []
    for sid, g in idx.groupby("session_id", sort=False):
        g = g.sort_values("turn_id").reset_index(drop=True)
        # All silent turns of the session are pooled into a single fit; they
        # need not form one contiguous run.
        silent = g["call_flag"] == 0
        if silent.sum() < 3:
            continue
        dt = pd.to_numeric(g.loc[silent, "Dt"], errors="coerce").to_numpy()
        # Missing gaps (e.g. a session's first turn) become 0 and are then
        # dropped by the Δt > 0 filter inside the fit.
        delta = pd.to_numeric(g.loc[silent, "Delta"], errors="coerce").fillna(0).to_numpy()
        tau = _estimate_tau_for_segment(dt, delta)
        if tau is not None and np.isfinite(tau) and 0 < tau < 1e9:
            taus.append({"session_id": sid, "tau_drift": tau})
    return taus


def _write_tau_report(taus: list[dict], out_txt: Path, out_md: Path) -> None:
    """Write the median-τ text file and Markdown summary, creating parent dirs."""
    out_txt.parent.mkdir(parents=True, exist_ok=True)
    out_md.parent.mkdir(parents=True, exist_ok=True)

    if not taus:
        out_txt.write_text("tau_drift_median ~ NaN\n", encoding="utf-8")
        out_md.write_text("# Tau-drift Summary\n\n- no valid silent segments found\n", encoding="utf-8")
        return

    med = float(np.median([t["tau_drift"] for t in taus]))
    out_txt.write_text(f"tau_drift_median ~ {med:.3f}\n", encoding="utf-8")
    out_md.write_text(
        "\n".join(
            [
                "# Tau-drift Summary",
                "",
                f"- sessions with estimate: **{len(taus)}**",
                f"- median(τ_drift): **{med:.3f}**",
                "",
            ]
        ),
        encoding="utf-8",
    )


def main() -> None:
    """Command-line entry point.

    Registers one path option per input/output file, each defaulting to a
    location under ``ROOT``, then forwards the parsed paths to
    ``estimate_tau_drift`` by keyword.
    """
    parser = argparse.ArgumentParser()
    option_defaults = (
        ("--indices", "reports/indices.parquet"),
        ("--events", "reports/alpha_phi.parquet"),
        ("--logs", "data/logs.jsonl"),
        ("--txt", "reports/tau_drift.txt"),
        ("--md", "reports/tau_drift.md"),
    )
    for flag, rel_path in option_defaults:
        parser.add_argument(flag, type=Path, default=ROOT / rel_path)
    opts = parser.parse_args()
    estimate_tau_drift(
        indices_path=opts.indices,
        events_path=opts.events,
        logs_path=opts.logs,
        out_txt=opts.txt,
        out_md=opts.md,
    )


if __name__ == "__main__":
    # Script entry point. The module uses a relative import (`from .utils ...`),
    # so it must be run as `python -m <package>.<module>`, not as a bare file.
    main()