"""Run 25 — matched-pair analysis: isolate price as the only variable.

For each fiscal year, identify pairs of stocks with similar fundamentals
(ROIC, fundamental growth g_f, and optionally sector; OE growth is not
currently used as a matching criterion). Within each matched pair, test
whether the stock with the higher Gap had the higher realized 5-year
forward excess return (vs SPY).

This is a *paired* design — it controls for fundamentals, business
quality, time period, and (optionally) sector. The only systematic
difference between the matched stocks is how the market priced them
(captured in Gap and MoS).

Interpretation:
- Hit rate > 50% with significance → the Brina Gap adds value once
  fundamentals are controlled for. The price signal carries
  information beyond what fundamentals already predict.
- Hit rate ≈ 50% → the Gap is noise once fundamentals are controlled
  for. The framework's univariate accuracy reflects fundamental
  differences, not the Gap signal.

We run three tightness levels:
- LOOSE: ROIC within ±10pp, g_f within ±10pp
- MEDIUM: ROIC within ±5pp, g_f within ±5pp
- STRICT: ROIC within ±3pp, g_f within ±3pp, same sector

Plus Gap-magnitude variants that only test pairs where |Gap_A − Gap_B| > 5pp
(all three levels) or > 10pp (LOOSE only), ensuring the price differential
is meaningful.
"""

from __future__ import annotations

import csv
import math
import random
import sys
from collections import defaultdict
from datetime import date, datetime
from itertools import combinations
from math import comb
from pathlib import Path

# Repository root: the parent of the directory that contains this file.
REPO_ROOT = Path(__file__).resolve().parent.parent
# Put the repository root first on sys.path so the `src.` / `scripts.`
# imports that follow resolve when this file is executed directly.
sys.path.insert(0, str(REPO_ROOT))

from src.compute_forward_return import compute_forward_return, excess_return  # noqa: E402
from src.fetch_prices import PriceClient  # noqa: E402
from scripts.run_statistical_analysis import SECTOR_OF_TICKER  # noqa: E402


def _fnum(s):
    try: return float(s) if s not in (None, "") else None
    except ValueError: return None


def _binom_p_one_sided(n, k):
    if n == 0: return 1.0
    return sum(comb(n, i) for i in range(k, n + 1)) / (2 ** n)


def _bootstrap_ci(hits, n_resamples=5_000, seed=20260518):
    if not hits or len(hits) < 3:
        return float("nan"), float("nan")
    rng = random.Random(seed)
    n = len(hits)
    samples = sorted(sum(rng.choice(hits) for _ in range(n)) / n for _ in range(n_resamples))
    return samples[int(0.025 * n_resamples)], samples[int(0.975 * n_resamples)]


def _within(a, b, tol):
    return abs(a - b) <= tol


def _build_records():
    """Load file_01 and join in 5y forward excess returns for each usable row.

    A CSV row becomes a record only if:
      - the text of ``status`` before the first " |" is "OK";
      - ``roic_pct``, ``gf_pct``, ``gap_pp`` and ``mos_pct`` all parse via
        ``_fnum``;
      - ``period_end`` is a YYYY-MM-DD date and ``fiscal_year`` parses as
        an int (malformed rows are skipped rather than aborting the run);
      - a price series can be fetched for the ticker (fetch failures are
        reported on stderr once per ticker and the ticker is skipped);
      - ``excess_return`` of the ticker's and SPY's 5-year forward returns
        from ``period_end`` is not None.

    Returns:
        list of dicts with keys ``ticker``, ``fy``, ``sector`` (``""`` when
        the ticker is absent from ``SECTOR_OF_TICKER``), ``roic``, ``gf``,
        ``gap``, ``mos`` and ``excess_5y``.
    """
    csv_path = REPO_ROOT / "outputs" / "file_01_full_dataset.csv"
    # `with` closes the handle; newline="" is what the csv module expects.
    with csv_path.open(newline="") as fh:
        rows = list(csv.DictReader(fh))

    pc = PriceClient(cache_dir=REPO_ROOT / "data" / "prices")
    spy = pc.get_series("SPY")
    today = date.today()
    series_cache = {}   # ticker -> price series, or None if unavailable
    spy_fr_cache = {}   # fye -> SPY 5y forward return (shared by all tickers)
    records = []
    for r in rows:
        if r["status"].split(" |")[0].strip() != "OK":
            continue
        roic = _fnum(r.get("roic_pct"))
        gf = _fnum(r.get("gf_pct"))
        gap = _fnum(r.get("gap_pp"))
        mos = _fnum(r.get("mos_pct"))
        if any(v is None for v in (roic, gf, gap, mos)):
            continue
        if not r.get("period_end"):
            continue
        try:
            fye = datetime.strptime(r["period_end"], "%Y-%m-%d").date()
            fy = int(r["fiscal_year"])
        except (TypeError, ValueError):
            # Malformed date or fiscal year: skip this row only.
            continue

        ticker = r["ticker"]
        if ticker not in series_cache:
            try:
                series_cache[ticker] = pc.get_series(ticker)
            except Exception as exc:  # best-effort: a missing ticker is skipped
                print(f"  [warn] no price series for {ticker}: {exc}",
                      file=sys.stderr)
                series_cache[ticker] = None
        series = series_cache[ticker]
        if series is None:
            continue

        stk_fr = compute_forward_return(series, fye, 5, today)
        if fye not in spy_fr_cache:
            # NOTE(review): assumes compute_forward_return is a pure function
            # of its arguments, so one SPY computation per fye suffices.
            spy_fr_cache[fye] = compute_forward_return(spy, fye, 5, today)
        er = excess_return(stk_fr, spy_fr_cache[fye])
        if er is None:
            continue
        records.append({
            "ticker": ticker, "fy": fy,
            "sector": SECTOR_OF_TICKER.get(ticker, ""),
            "roic": roic, "gf": gf, "gap": gap, "mos": mos,
            "excess_5y": er,
        })
    return records


def _matched_pair_test(records, roic_tol, gf_tol, same_sector=False,
                       gap_min_diff=0.0, label=""):
    """Score Gap and MoS differentials on fundamentally matched pairs.

    Records are grouped by fiscal year (``fy``). Each unordered pair within
    a year is scored when all of these hold:
      - the two records are different tickers;
      - |ΔROIC| <= roic_tol and |Δg_f| <= gf_tol;
      - if ``same_sector``, both share the same *non-empty* sector
        (``""`` is the unknown-sector default and never matches);
      - |ΔGap| > gap_min_diff, strictly, matching the "> 5pp" labels. The
        default 0.0 therefore drops exact Gap ties, which carry no
        predicted direction;
      - the realized 5y excess returns differ (exact ties are dropped).

    For each scored pair, the record with the higher Gap is predicted to
    have the higher realized excess return, and a hit is 1 when that holds.
    MoS is scored the same way on the same pairs, with MoS ties omitted
    from the MoS tally only.

    Prints the scored-pair count and, per signal, the hit rate, bootstrap
    95% CI and one-sided binomial p-value against 50%.

    NOTE(review): pairs that share a stock are not independent, so both
    the binomial p-value and the i.i.d. pair bootstrap likely overstate
    precision.

    Returns:
        (hits_pred_by_gap, hits_pred_by_mos): lists of 0/1 hit indicators.
    """
    by_year = defaultdict(list)
    for rec in records:
        by_year[rec["fy"]].append(rec)

    hits_pred_by_gap = []
    hits_pred_by_mos = []
    n_pairs = 0
    for rows in by_year.values():
        for a, b in combinations(rows, 2):
            # Never pair a ticker with itself (duplicate rows in one year).
            if a["ticker"] == b["ticker"]:
                continue
            # Similarity filter
            if not _within(a["roic"], b["roic"], roic_tol):
                continue
            if not _within(a["gf"], b["gf"], gf_tol):
                continue
            if same_sector and (not a["sector"] or a["sector"] != b["sector"]):
                continue
            # Gap floor (strict); a zero differential has no direction.
            gap_diff = a["gap"] - b["gap"]
            if gap_diff == 0 or abs(gap_diff) <= gap_min_diff:
                continue
            # Realized direction; exact ties carry no information.
            if a["excess_5y"] == b["excess_5y"]:
                continue
            n_pairs += 1
            a_outperformed = a["excess_5y"] > b["excess_5y"]
            hits_pred_by_gap.append(1 if (gap_diff > 0) == a_outperformed else 0)
            # MoS comparison on the same pairs, skipping MoS ties.
            if a["mos"] != b["mos"]:
                hits_pred_by_mos.append(
                    1 if (a["mos"] > b["mos"]) == a_outperformed else 0)

    print(f"\n  {label}")
    print(f"    Pairs tested: {n_pairs}")
    for tag, hits in [("Gap-differential predicts", hits_pred_by_gap),
                      ("MoS-differential predicts", hits_pred_by_mos)]:
        if not hits:
            print(f"    {tag}: no usable pairs")
            continue
        n = len(hits)
        k = sum(hits)
        p = _binom_p_one_sided(n, k)
        lo, hi = _bootstrap_ci(hits)
        sig = "***" if p < 0.01 else ("**" if p < 0.05 else ("*" if p < 0.10 else ""))
        print(f"    {tag}: hits={k}/{n} = {k/n*100:.1f}% CI [{lo*100:.0f}%, {hi*100:.0f}%] p={p:.4f} {sig}")
    return hits_pred_by_gap, hits_pred_by_mos


def main():
    """Run every matched-pair configuration and print the results.

    Loads the records once via ``_build_records``. It then calls
    ``_matched_pair_test`` for the three base tightness levels and for the
    Gap-floor variants, each group under a banner heading, and ends with
    a printed interpretation guide. Returns None.
    """
    records = _build_records()
    print(f"Records (VALID rows with 5y forward returns): {len(records)}")

    rule = "=" * 64

    def banner(*heading_lines):
        # Blank line + rule, the heading lines, then a closing rule.
        print("\n" + rule)
        for text in heading_lines:
            print(text)
        print(rule)

    base_levels = [
        # Level 1 — Loose: ROIC ±10pp, g_f ±10pp, any sector
        {"roic_tol": 10.0, "gf_tol": 10.0,
         "label": "L1 LOOSE  — ROIC ±10pp, g_f ±10pp, any sector"},
        # Level 2 — Medium: ROIC ±5pp, g_f ±5pp, any sector
        {"roic_tol": 5.0, "gf_tol": 5.0,
         "label": "L2 MEDIUM — ROIC ±5pp, g_f ±5pp, any sector"},
        # Level 3 — Strict: ROIC ±3pp, g_f ±3pp, SAME sector
        {"roic_tol": 3.0, "gf_tol": 3.0, "same_sector": True,
         "label": "L3 STRICT — ROIC ±3pp, g_f ±3pp, SAME sector"},
    ]
    gap_floor_levels = [
        {"roic_tol": 10.0, "gf_tol": 10.0, "gap_min_diff": 5.0,
         "label": "L1 LOOSE  + |ΔGap| > 5pp"},
        {"roic_tol": 5.0, "gf_tol": 5.0, "gap_min_diff": 5.0,
         "label": "L2 MEDIUM + |ΔGap| > 5pp"},
        {"roic_tol": 3.0, "gf_tol": 3.0, "same_sector": True,
         "gap_min_diff": 5.0,
         "label": "L3 STRICT + |ΔGap| > 5pp"},
        # Aggressive Gap filter — require the framework to commit
        {"roic_tol": 10.0, "gf_tol": 10.0, "gap_min_diff": 10.0,
         "label": "L1 LOOSE  + |ΔGap| > 10pp"},
    ]
    interpretation = (
        " - If Gap-differential accuracy > 50% with significance after",
        "   matching, the Brina Gap adds price-discovery value beyond",
        "   what fundamentals already predict.",
        " - If accuracy ≈ 50%, the Gap signal is noise once fundamentals",
        "   are controlled for; the framework's univariate accuracy",
        "   reflects fundamental differences, not the price signal.",
        " - Compare Gap accuracy vs MoS accuracy at each level — they",
        "   should diverge if our earlier finding (MoS is at chance) is",
        "   robust within matched pairs.",
    )

    banner(" MATCHED-PAIR ANALYSIS",
           " Within fiscal year, control for fundamentals, isolate price")
    for config in base_levels:
        _matched_pair_test(records, **config)

    banner(" + GAP-DIFFERENTIAL FLOOR (require |Gap_A − Gap_B| > 5pp)",
           " Tests whether the framework predicts when the price signal",
           " is loud, not just whenever fundamentals are matched.")
    for config in gap_floor_levels:
        _matched_pair_test(records, **config)

    banner(" INTERPRETATION")
    for text in interpretation:
        print(text)


if __name__ == "__main__":
    # main() has no return statement, so this raises SystemExit(None),
    # i.e. exit status 0, unless an exception propagates out of main().
    raise SystemExit(main())
