"""Run 26 — head-to-head: Brina Gap vs popular value signals.

For each fiscal year, rank stocks by each value signal and form top-1/3
("cheap") and bottom-1/3 ("expensive") portfolios. Predict cheap-to-
outperform-SPY at 5y horizon. Score directional accuracy.

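Signals are computed from the same pool of records; each signal drops the
rows where it is undefined (e.g. P/E needs NetIncome > 0), so the per-signal
sample size n can differ: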
- **P/E**         = MarketCap / NetIncome           (lower = cheaper)
- **EV/EBITDA**   = EV / (EBIT + D&A)               (lower = cheaper)
- **P/B**         = MarketCap / BookEquity          (lower = cheaper)
- **FCF yield**   = (CFO − Capex) / MarketCap       (higher = cheaper)
- **Earnings yld**= NetIncome / MarketCap           (higher = cheaper; equals 1/(P/E) when NI > 0, but loss-makers are kept)
- **EV/NOPAT**    = EV / NOPAT                       (lower = cheaper)
- **OE-DCF MoS**  = (IV − MarketCap) / IV           (higher = cheaper)
- **Brina Gap**   = g_f − g*                         (higher = cheaper)

Practitioner framing: which of these 8 signals best predicts forward
returns on the same set of US large-cap observations 2014-2024?
"""

from __future__ import annotations

import csv
import random
import sys
from collections import defaultdict
from datetime import date, datetime
from math import comb, isfinite
from pathlib import Path

# Make the repository root (two levels above this file) importable so the
# project-local ``src`` package resolves regardless of the working directory.
_HERE = Path(__file__).resolve()
REPO_ROOT = _HERE.parent.parent
sys.path.insert(0, str(REPO_ROOT))

from src.compute_forward_return import compute_forward_return, excess_return  # noqa: E402
from src.fetch_prices import PriceClient  # noqa: E402
from src.fetch_sec import SECClient, extract_inputs  # noqa: E402
from src.compute_nopat import normalized_tax_rate  # noqa: E402


def _fnum(s):
    """Parse a CSV cell as a finite float.

    Returns None for None, blank, unparsable, or non-finite ("nan", "inf")
    values. NaN must be rejected here: it compares False against every
    threshold (so ``x <= 0`` filters let it through) and it makes the
    per-year rank sorts meaningless.
    """
    if s is None:
        return None
    try:
        value = float(s)
    except (TypeError, ValueError):
        # Blank strings also land here: float("") raises ValueError.
        return None
    return value if isfinite(value) else None


def _binom_p_one_sided(n, k):
    """Exact one-sided tail P(X >= k) for X ~ Binomial(n, 1/2).

    The tail is accumulated in exact integer arithmetic and divided by
    2**n only once at the end. By convention n == 0 yields 1.0.
    """
    if n == 0:
        return 1.0
    upper_tail = 0
    for successes in range(k, n + 1):
        upper_tail += comb(n, successes)
    return upper_tail / 2 ** n


def _bootstrap_ci(hits, n_resamples=5_000, seed=20260518):
    """Percentile-bootstrap 95% interval for the mean of a list of 0/1 hits.

    Draws ``n_resamples`` resamples (with replacement, same size as ``hits``)
    from a ``random.Random`` seeded with ``seed``, so the result is
    deterministic. Returns the 2.5th and 97.5th percentile resample means,
    or (nan, nan) when fewer than 3 hits are supplied.
    """
    size = len(hits) if hits else 0
    if size < 3:
        return float("nan"), float("nan")
    draw = random.Random(seed).choice
    means = []
    for _ in range(n_resamples):
        total = 0
        for _ in range(size):
            total += draw(hits)
        means.append(total / size)
    means.sort()
    return means[int(0.025 * n_resamples)], means[int(0.975 * n_resamples)]


# Signal table: each entry is (label, direction, extractor).
#   direction -- "higher_better" or "lower_better": which end of the ranking
#                counts as "cheap".
#   extractor -- maps one record dict (as built by _load_records) to a float,
#                or None when the signal is undefined for that record
#                (e.g. a missing input or a non-positive denominator).


def _pe(r):
    """P/E = mcap / ni, defined only for strictly positive net income."""
    ni = r.get("ni")
    if ni and ni > 0:
        return r["mcap"] / ni
    return None


def _ev_ebitda(r):
    """EV / (EBIT + D&A), defined only when both inputs exist and the sum is > 0."""
    ebit, da = r.get("ebit"), r.get("da")
    if ebit is None or da is None:
        return None
    ebitda = ebit + da
    if ebitda > 0:
        return r["ev"] / ebitda
    return None


def _pb(r):
    """P/B = mcap / book equity, defined only for strictly positive equity."""
    book = r.get("book_eq")
    if book and book > 0:
        return r["mcap"] / book
    return None


def _fcf_yield(r):
    """(CFO - capex) / mcap; capex is subtracted as stored (sign convention not checked here)."""
    cfo, capex = r.get("cfo"), r.get("capex")
    if cfo is None or capex is None:
        return None
    mcap = r["mcap"]
    if mcap > 0:
        return (cfo - capex) / mcap
    return None


def _earnings_yield(r):
    """NI / mcap; unlike P/E, negative NI is kept (only missing or zero NI is dropped)."""
    ni = r.get("ni")
    if ni and r["mcap"] > 0:
        return ni / r["mcap"]
    return None


def _ev_nopat(r):
    """EV / NOPAT, defined only for strictly positive NOPAT."""
    nopat = r.get("nopat")
    if nopat and nopat > 0:
        return r["ev"] / nopat
    return None


def _mos(r):
    """Owner-earnings DCF margin of safety, passed through unchanged."""
    return r["mos"]


def _gap(r):
    """Brina Gap, passed through unchanged."""
    return r["gap"]


SIGNAL_DEFS = [
    ("P/E",          "lower_better",  _pe),
    ("EV/EBITDA",    "lower_better",  _ev_ebitda),
    ("P/B",          "lower_better",  _pb),
    ("FCF yield",    "higher_better", _fcf_yield),
    ("Earnings yld", "higher_better", _earnings_yield),
    ("EV/NOPAT",     "lower_better",  _ev_nopat),
    ("OE-DCF MoS",   "higher_better", _mos),
    ("Brina Gap",    "higher_better", _gap),
]


def _load_records():
    """Load file_01 and join in 5y forward excess returns + per-row raw inputs.

    Reads ``outputs/file_01_full_dataset.csv`` and keeps a row only when:
      * its status (text before the first " |") is "OK";
      * gap_pp, mos_pct, nopat_usd_m, ev_usd_b and market_cap_usd_b all parse
        via ``_fnum``, with market cap, EV and NOPAT strictly positive;
      * period_end parses as YYYY-MM-DD and fiscal_year as an int;
      * SEC facts and a price series can be loaded for the ticker;
      * ``excess_return`` of the 5y forward returns (stock vs SPY, both
        starting at the fiscal-year end) is not None.

    Returns:
        list of dicts with keys ticker, fy, mcap, ev, nopat (scaled to USD
        from the *_usd_b / *_usd_m columns), ni, book_eq, cfo, capex, da,
        ebit (from ``extract_inputs``), gap, mos and excess_5y.
    """
    csv_path = REPO_ROOT / "outputs" / "file_01_full_dataset.csv"
    # Read inside a with-block so the handle is closed (a bare .open() leaked
    # it); newline="" is what the csv module requires for its reader.
    with csv_path.open(newline="") as fh:
        rows = list(csv.DictReader(fh))

    pc = PriceClient(cache_dir=REPO_ROOT / "data" / "prices")
    sec = SECClient(
        cache_dir=REPO_ROOT / "data" / "raw_sec_filings",
        user_agent="Brina Gap Research fb.95@live.it",
        offline=True,
    )
    spy = pc.get_series("SPY")
    today = date.today()
    series_cache = {}
    facts_cache = {}
    # SPY's forward return depends only on the start date, and many rows share
    # a fiscal-year end, so compute it once per date instead of once per row.
    spy_fr_cache = {}
    records = []
    for r in rows:
        if r["status"].split(" |")[0].strip() != "OK":
            continue
        gap = _fnum(r.get("gap_pp"))
        mos = _fnum(r.get("mos_pct"))
        nopat_m = _fnum(r.get("nopat_usd_m"))
        ev_b = _fnum(r.get("ev_usd_b"))
        mcap_b = _fnum(r.get("market_cap_usd_b"))
        if any(v is None for v in (gap, mos, nopat_m, ev_b, mcap_b)):
            continue
        if mcap_b <= 0 or ev_b <= 0 or nopat_m <= 0:
            continue
        if not r.get("period_end"):
            continue
        try:
            fye = datetime.strptime(r["period_end"], "%Y-%m-%d").date()
        except ValueError:
            continue
        ticker = r["ticker"]
        try:
            fy = int(r["fiscal_year"])
        except ValueError:
            continue

        # Pull NI, book equity, CFO, capex, D&A, EBIT from the SEC cache.
        # Best-effort: a ticker whose facts fail to load is skipped entirely.
        if ticker not in facts_cache:
            try:
                facts_cache[ticker] = sec.get_facts(ticker)
            except Exception:
                facts_cache[ticker] = None
        facts = facts_cache[ticker]
        if facts is None:
            continue
        inp = extract_inputs(facts, fy)

        # Forward return; a ticker whose price series fails to load is skipped.
        if ticker not in series_cache:
            try:
                series_cache[ticker] = pc.get_series(ticker)
            except Exception:
                series_cache[ticker] = None
        series = series_cache[ticker]
        if series is None:
            continue
        stk_fr = compute_forward_return(series, fye, 5, today)
        if fye not in spy_fr_cache:
            spy_fr_cache[fye] = compute_forward_return(spy, fye, 5, today)
        er = excess_return(stk_fr, spy_fr_cache[fye])
        if er is None:
            continue

        records.append({
            "ticker": ticker, "fy": fy,
            "mcap": mcap_b * 1e9,
            "ev": ev_b * 1e9,
            "nopat": nopat_m * 1e6,
            "ni": inp.net_income,
            "book_eq": inp.total_equity,
            "cfo": inp.cfo,
            "capex": inp.capex,
            "da": inp.depreciation_amortization,
            "ebit": inp.ebit,
            "gap": gap,        # already in pp
            "mos": mos,        # already in %
            "excess_5y": er,
        })
    return records


def _sig_stars(p):
    """Conventional significance stars: *** p<0.01, ** p<0.05, * p<0.10."""
    if p < 0.01:
        return "***"
    if p < 0.05:
        return "**"
    if p < 0.10:
        return "*"
    return ""


def _tercile_hits(records, direction, extractor, min_rows=6):
    """Score per-fiscal-year tercile predictions for one signal.

    Within each fiscal year, usable rows are ranked cheapest-first. The top
    third ("cheap") is predicted to beat SPY over 5y, and the bottom third
    ("expensive") to lag it. Both terciles have size n // 3, so a random
    ranking scores ~50% on the pooled hits regardless of the base rate.

    Args:
        records: dicts with at least "fy" and "excess_5y", plus whatever
            ``extractor`` reads.
        direction: "higher_better" (high value = cheap) or "lower_better".
        extractor: record -> float, or None when the signal is undefined.
        min_rows: fiscal years with fewer usable rows are skipped.

    Returns:
        (long_hits, short_hits) as lists of 0/1. A long hit means
        excess_5y > 0 and a short hit means excess_5y < 0; excess_5y == 0
        counts as a miss on both sides.

    Raises:
        ValueError: if ``direction`` is not a recognised value. Previously any
            other string was silently treated as "lower_better".
    """
    if direction not in ("higher_better", "lower_better"):
        raise ValueError(f"unknown direction: {direction!r}")
    descending = direction == "higher_better"

    by_year = defaultdict(list)
    for r in records:
        v = extractor(r)
        # NaN/inf would make the sort order meaningless; treat them as missing.
        if v is None or not isfinite(v):
            continue
        by_year[r["fy"]].append((v, r))

    long_hits, short_hits = [], []
    for rows in by_year.values():
        n = len(rows)
        tercile = n // 3
        # tercile == 0 must be skipped explicitly: rows[-0:] is the WHOLE list.
        if n < min_rows or tercile == 0:
            continue
        # Stable sort: ties keep record order, the same as sorting on -value.
        rows.sort(key=lambda x: x[0], reverse=descending)
        long_hits.extend(1 if r["excess_5y"] > 0 else 0 for _, r in rows[:tercile])
        short_hits.extend(1 if r["excess_5y"] < 0 else 0 for _, r in rows[-tercile:])
    return long_hits, short_hits


def _test_signal(records, label, direction, extractor):
    """Print pooled long+short tercile accuracy for one signal.

    Hits come from ``_tercile_hits`` and are pooled long-then-short, the same
    order the split-side report uses, so both tables show the same bootstrap
    CI for a signal. Prints a single "no usable observations" line when no
    fiscal year qualifies.

    NOTE(review): the binomial p-value and bootstrap CI treat every hit as
    independent. Consecutive fiscal years of one ticker share overlapping 5y
    windows, and every row in a year shares the SPY benchmark, so both are
    likely optimistic.
    """
    long_hits, short_hits = _tercile_hits(records, direction, extractor)
    hits = long_hits + short_hits
    if not hits:
        print(f"  {label:14}: no usable observations")
        return
    n, k = len(hits), sum(hits)
    p = _binom_p_one_sided(n, k)
    lo, hi = _bootstrap_ci(hits)
    print(f"  {label:14}  n={n:>4}  hits={k:>3}  acc={k/n*100:>4.1f}%  CI [{lo*100:>3.0f}%, {hi*100:>3.0f}%]  p={p:.4f} {_sig_stars(p)}")


def _test_signal_split_sides(records, label, direction, extractor):
    """Same as _test_signal, but also reports long/short-side accuracy separately."""
    long_hits, short_hits = _tercile_hits(records, direction, extractor)
    all_hits = long_hits + short_hits
    if not all_hits:
        # Previously returned silently; keep the output aligned with _test_signal.
        print(f"  {label:14}: no usable observations")
        return
    n, k = len(all_hits), sum(all_hits)
    p = _binom_p_one_sided(n, k)
    lo, hi = _bootstrap_ci(all_hits)
    long_acc = sum(long_hits) / len(long_hits) * 100 if long_hits else 0
    short_acc = sum(short_hits) / len(short_hits) * 100 if short_hits else 0
    print(f"  {label:14}  total n={n:>3}  acc={k/n*100:>4.1f}% [{lo*100:>3.0f}%, {hi*100:>3.0f}%] p={p:.4f} {_sig_stars(p)}     "
          f"LONG (cheap) n={len(long_hits):>3} acc={long_acc:>3.0f}%     "
          f"SHORT (expensive) n={len(short_hits):>3} acc={short_acc:>3.0f}%")


def main():
    """Run the head-to-head comparison and print both result tables.

    Loads the joined records, scores every entry of ``SIGNAL_DEFS`` with the
    pooled tercile test and the long/short split, then prints a short guide.

    Returns:
        0, used as the process exit status by the ``__main__`` guard.
    """
    rule = "=" * 90

    def banner(*lines):
        # Section header: blank line, rule, title line(s), rule.
        print()
        print(rule)
        for line in lines:
            print(line)
        print(rule)

    records = _load_records()
    print(f"Records (VALID rows with 5y forward returns + all primitives): {len(records)}")

    # Report the fiscal-year span actually tested instead of a fixed
    # "2014-2024": _load_records drops rows without a 5y excess return
    # (presumably including recent years whose horizon has not elapsed).
    years = sorted({r["fy"] for r in records})
    span = f"{years[0]}-{years[-1]}" if years else "no fiscal years"

    banner(
        f" HEAD-TO-HEAD: {len(SIGNAL_DEFS)} value signals on a shared record pool, {span}",
        " Tercile method: top-1/3 cheap → predict outperform, bottom-1/3 expensive → underperform",
    )
    for label, direction, extractor in SIGNAL_DEFS:
        _test_signal(records, label, direction, extractor)

    banner(" SAME COMPARISON, split by long-side / short-side")
    for label, direction, extractor in SIGNAL_DEFS:
        _test_signal_split_sides(records, label, direction, extractor)

    banner(" INTERPRETATION GUIDE")
    print(" - Tercile-based signal test is the conventional factor-portfolio framework.")
    # The old text claimed identical observations, but each extractor drops
    # the rows where it is undefined, so n differs by signal.
    print(" - All signals draw on the same record pool, but each drops the rows where")
    print("   it is undefined (e.g. P/E needs NI > 0), so n differs by signal.")
    print(" - 'Higher accuracy' = the signal better identifies outperformers (long side)")
    print("   AND/OR better identifies underperformers (short side).")
    print(" - p-values and CIs treat every hit as independent; overlapping 5y windows")
    print("   and the shared SPY benchmark make them likely optimistic.")
    print(" - Practitioner takeaway: if Brina Gap is on top, you have a compelling")
    print("   improvement over the conventional value signals readers already trust.")
    return 0


if __name__ == "__main__":
    # main()'s return value becomes the exit status (SystemExit(None) exits 0).
    exit_status = main()
    raise SystemExit(exit_status)
