"""Run 26 — head-to-head: Brina Gap vs popular value signals.

For each fiscal year, rank stocks by each value signal and form top-1/3
("cheap") and bottom-1/3 ("expensive") portfolios. Predict that cheap stocks
outperform SPY and expensive stocks underperform it over a 5-year horizon,
then score directional accuracy.

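Signals are computed from the same pool of records. Each signal skips records
where it is undefined (e.g. P/E when net income is not positive), so sample
sizes differ across signals: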
- **P/E**         = MarketCap / NetIncome           (lower = cheaper)
- **EV/EBITDA**   = EV / (EBIT + D&A)               (lower = cheaper)
- **P/B**         = MarketCap / BookEquity          (lower = cheaper)
- **FCF yield**   = (CFO − Capex) / MarketCap       (higher = cheaper)
- **Earnings yld** = NetIncome / MarketCap (=1/P/E)  (higher = cheaper)
- **EV/NOPAT**    = EV / NOPAT                       (lower = cheaper)
- **OE-DCF MoS**  = (IV − MarketCap) / IV           (higher = cheaper)
- **Brina Gap**   = g_f − g*                         (higher = cheaper)

Practitioner framing: which of these 8 signals best predicts 5-year forward
excess returns (vs. SPY) across the same pool of US large-cap observations,
2014-2024?
"""

from __future__ import annotations

import csv
import random
import sys
from collections import defaultdict
from datetime import date, datetime
from math import comb, isfinite
from pathlib import Path

# Put the parent of this script's directory on sys.path so that the `src.*`
# imports below resolve when the script is run directly.
REPO_ROOT = Path(__file__).resolve().parent.parent
sys.path.insert(0, str(REPO_ROOT))

from src.compute_forward_return import compute_forward_return, excess_return  # noqa: E402
from src.fetch_prices import PriceClient  # noqa: E402
from src.fetch_sec import SECClient, extract_inputs  # noqa: E402
from src.compute_nopat import normalized_tax_rate  # noqa: E402


def _fnum(s):
    try: return float(s) if s not in (None, "") else None
    except ValueError: return None


def _binom_p_one_sided(n, k):
    if n == 0: return 1.0
    return sum(comb(n, i) for i in range(k, n + 1)) / (2 ** n)


def _bootstrap_ci(hits, n_resamples=5_000, seed=20260518):
    if not hits or len(hits) < 3: return float("nan"), float("nan")
    rng = random.Random(seed)
    n = len(hits)
    samples = sorted(sum(rng.choice(hits) for _ in range(n)) / n for _ in range(n_resamples))
    return samples[int(0.025 * n_resamples)], samples[int(0.975 * n_resamples)]


def _bootstrap_ci_by_year(hits_by_year, n_resamples=5_000, seed=20260518):
    """Year-clustered bootstrap: resample whole fiscal years (with replacement),
    not individual observations. Proper inference under within-year
    correlation in stock returns.

    Each resample picks Y years with replacement (where Y is the number of
    fiscal years observed); the resampled accuracy aggregates hits across
    those years."""
    years = list(hits_by_year.keys())
    if len(years) < 3: return float("nan"), float("nan")
    rng = random.Random(seed)
    Y = len(years)
    samples = []
    for _ in range(n_resamples):
        picked = [rng.choice(years) for _ in range(Y)]
        all_hits = [h for yr in picked for h in hits_by_year[yr]]
        if all_hits:
            samples.append(sum(all_hits) / len(all_hits))
    samples.sort()
    if not samples: return float("nan"), float("nan")
    n = len(samples)
    return samples[int(0.025 * n)], samples[int(0.975 * n)]


# Each signal is a (label, direction, extractor) tuple:
#   label     -- display name used in the printed tables;
#   direction -- "higher_better" or "lower_better": which end of the per-year
#                ranking counts as "cheap" (predicted to outperform SPY);
#   extractor -- maps a record dict to a float, or None when the signal is
#                undefined for that record (missing inputs, or a non-positive
#                denominator that would invert the ratio's ordering). A None
#                drops the record for that signal only, so per-signal sample
#                sizes can differ.
SIGNAL_DEFS = [
    ("P/E", "lower_better",
     lambda r: r["mcap"] / r["ni"] if r.get("ni") and r["ni"] > 0 else None),
    ("EV/EBITDA", "lower_better",
     lambda r: (r["ev"] / (r["ebit"] + r["da"])
                if r.get("ebit") is not None and r.get("da") is not None
                and (r["ebit"] + r["da"]) > 0 else None)),
    ("P/B", "lower_better",
     lambda r: r["mcap"] / r["book_eq"] if r.get("book_eq") and r["book_eq"] > 0 else None),
    # NOTE(review): CFO - Capex assumes capex is reported as a positive outflow
    # by extract_inputs — confirm the sign convention there.
    ("FCF yield", "higher_better",
     lambda r: ((r["cfo"] - r["capex"]) / r["mcap"]
                if r.get("cfo") is not None and r.get("capex") is not None
                and r["mcap"] > 0 else None)),
    # Earnings yield stays defined for zero and negative earnings (it is
    # monotonic in NI, unlike P/E); only a missing NI makes it undefined.
    ("Earnings yld", "higher_better",
     lambda r: (r["ni"] / r["mcap"]
                if r.get("ni") is not None and r["mcap"] > 0 else None)),
    ("EV/NOPAT", "lower_better",
     lambda r: r["ev"] / r["nopat"] if r.get("nopat") and r["nopat"] > 0 else None),
    ("OE-DCF MoS", "higher_better", lambda r: r["mos"]),
    ("Brina Gap", "higher_better", lambda r: r["gap"]),
]


def _load_records():
    """Load file_01 rows and join in 5y forward excess returns and raw SEC inputs.

    Reads ``outputs/file_01_full_dataset.csv``. A row is kept only when:
    - the status token before the first " |" is "OK";
    - gap / MoS / NOPAT / EV / market cap all parse as numbers, with the last
      three strictly positive;
    - ``period_end`` (YYYY-MM-DD) and ``fiscal_year`` both parse;
    - SEC facts and a price series can be loaded for the ticker;
    - a 5-year excess return versus SPY can be computed from the period end.

    Returns a list of record dicts. ``mcap`` and ``ev`` are converted from
    billions and ``nopat`` from millions to USD. ``ni`` / ``book_eq`` / ``cfo`` /
    ``capex`` / ``da`` / ``ebit`` come straight from ``extract_inputs`` and may
    be None (the signal extractors guard for this).
    """
    csv_path = REPO_ROOT / "outputs" / "file_01_full_dataset.csv"
    # newline="" lets the csv module handle line endings (including newlines
    # inside quoted fields) itself; the with-block closes the file once the
    # rows are materialised.
    with csv_path.open(newline="") as fh:
        rows = list(csv.DictReader(fh))

    pc = PriceClient(cache_dir=REPO_ROOT / "data" / "prices")
    sec = SECClient(
        cache_dir=REPO_ROOT / "data" / "raw_sec_filings",
        user_agent="Brina Gap Research fb.95@live.it",
        offline=True,  # presumably restricts reads to the local cache — confirm in SECClient
    )
    spy = pc.get_series("SPY")
    today = date.today()
    series_cache = {}
    facts_cache = {}
    records = []
    for r in rows:
        # Only the first " |"-separated token is the status code; a short row
        # can leave the cell as None, which is treated as "not OK".
        if (r["status"] or "").split(" |")[0].strip() != "OK":
            continue
        gap = _fnum(r.get("gap_pp"))
        mos = _fnum(r.get("mos_pct"))
        nopat_m = _fnum(r.get("nopat_usd_m"))
        ev_b = _fnum(r.get("ev_usd_b"))
        mcap_b = _fnum(r.get("market_cap_usd_b"))
        if any(v is None for v in (gap, mos, nopat_m, ev_b, mcap_b)):
            continue
        if mcap_b <= 0 or ev_b <= 0 or nopat_m <= 0:
            continue
        if not r.get("period_end"):
            continue
        try:
            fye = datetime.strptime(r["period_end"], "%Y-%m-%d").date()
        except ValueError:
            continue
        ticker = r["ticker"]
        try:
            fy = int(r["fiscal_year"])
        except (TypeError, ValueError):  # TypeError: empty/short row gives None
            continue

        # Pull NI, book equity, CFO, capex, D&A, EBIT from the SEC cache.
        # Best-effort: a ticker whose facts fail to load is cached as None so
        # all of its rows are skipped without retrying.
        if ticker not in facts_cache:
            try:
                facts_cache[ticker] = sec.get_facts(ticker)
            except Exception:
                facts_cache[ticker] = None
        if facts_cache[ticker] is None:
            continue
        inp = extract_inputs(facts_cache[ticker], fy)

        # 5-year forward excess return vs SPY, measured from the fiscal year end.
        # Same best-effort caching as above for the ticker's price series.
        if ticker not in series_cache:
            try:
                series_cache[ticker] = pc.get_series(ticker)
            except Exception:
                series_cache[ticker] = None
        if series_cache[ticker] is None:
            continue
        stk_fr = compute_forward_return(series_cache[ticker], fye, 5, today)
        spy_fr = compute_forward_return(spy, fye, 5, today)
        er = excess_return(stk_fr, spy_fr)
        if er is None:
            continue

        records.append({
            "ticker": ticker, "fy": fy,
            "mcap": mcap_b * 1e9,
            "ev": ev_b * 1e9,
            "nopat": nopat_m * 1e6,
            "ni": inp.net_income,
            "book_eq": inp.total_equity,
            "cfo": inp.cfo,
            "capex": inp.capex,
            "da": inp.depreciation_amortization,
            "ebit": inp.ebit,
            "gap": gap,        # already in pp
            "mos": mos,        # already in %
            "excess_5y": er,
        })
    return records


def _test_signal(records, label, direction, extractor, n_signals=1):
    """Within each fiscal year, rank stocks by the signal; top-1/3 →
    predict outperform, bottom-1/3 → predict underperform. Score
    directional accuracy.

    Reports:
    - Naive binomial test (treats observations as independent — anticonservative
      under within-year correlation).
    - Year-clustered bootstrap CI (resamples whole years, not individual
      observations — proper inference under cross-sectional dependence).
    - Bonferroni-corrected p-value (raw p × n_signals)."""
    by_year = defaultdict(list)
    for r in records:
        v = extractor(r)
        if v is None: continue
        by_year[r["fy"]].append((v, r))

    hits = []
    hits_by_year = defaultdict(list)
    for fy, rows in by_year.items():
        n = len(rows)
        if n < 6: continue  # need at least 6 rows to form non-empty terciles
        if direction == "higher_better":
            rows.sort(key=lambda x: -x[0])  # highest first = cheapest first
        else:
            rows.sort(key=lambda x: x[0])   # lowest first = cheapest first
        tercile = n // 3
        top_cheap = rows[:tercile]
        bottom_expensive = rows[-tercile:]
        for _, r in top_cheap:
            h = 1 if r["excess_5y"] > 0 else 0
            hits.append(h); hits_by_year[fy].append(h)
        for _, r in bottom_expensive:
            h = 1 if r["excess_5y"] < 0 else 0
            hits.append(h); hits_by_year[fy].append(h)

    if not hits:
        print(f"  {label:14}: no usable observations")
        return None
    n = len(hits); k = sum(hits)
    p_raw = _binom_p_one_sided(n, k)
    p_bonf = min(p_raw * n_signals, 1.0)
    lo, hi = _bootstrap_ci(hits)
    lo_yr, hi_yr = _bootstrap_ci_by_year(hits_by_year)
    sig = "***" if p_raw < 0.01 else ("**" if p_raw < 0.05 else ("*" if p_raw < 0.10 else ""))
    bonf_sig = "***" if p_bonf < 0.01 else ("**" if p_bonf < 0.05 else ("*" if p_bonf < 0.10 else ""))
    print(f"  {label:14}  n={n:>4}  acc={k/n*100:>4.1f}%  "
          f"naive CI [{lo*100:>3.0f}%, {hi*100:>3.0f}%]  "
          f"year-clustered CI [{lo_yr*100:>3.0f}%, {hi_yr*100:>3.0f}%]  "
          f"p_raw={p_raw:.4f}{sig}  p_bonf={p_bonf:.4f}{bonf_sig}")
    return {"label": label, "n": n, "k": k, "acc": k/n,
            "p_raw": p_raw, "p_bonf": p_bonf,
            "ci_naive": (lo, hi), "ci_year_clustered": (lo_yr, hi_yr)}


def _test_signal_split_sides(records, label, direction, extractor):
    """Same as _test_signal but reports long/short side separately."""
    by_year = defaultdict(list)
    for r in records:
        v = extractor(r)
        if v is None: continue
        by_year[r["fy"]].append((v, r))

    long_hits = []; short_hits = []
    for fy, rows in by_year.items():
        n = len(rows)
        if n < 6: continue
        if direction == "higher_better":
            rows.sort(key=lambda x: -x[0])
        else:
            rows.sort(key=lambda x: x[0])
        tercile = n // 3
        for _, r in rows[:tercile]:
            long_hits.append(1 if r["excess_5y"] > 0 else 0)
        for _, r in rows[-tercile:]:
            short_hits.append(1 if r["excess_5y"] < 0 else 0)

    all_hits = long_hits + short_hits
    if not all_hits: return
    n = len(all_hits); k = sum(all_hits); p = _binom_p_one_sided(n, k)
    lo, hi = _bootstrap_ci(all_hits)
    sig = "***" if p < 0.01 else ("**" if p < 0.05 else ("*" if p < 0.10 else ""))
    print(f"  {label:14}  total n={n:>3}  acc={k/n*100:>4.1f}% [{lo*100:>3.0f}%, {hi*100:>3.0f}%] p={p:.4f} {sig}     "
          f"LONG (cheap) n={len(long_hits):>3} acc={sum(long_hits)/len(long_hits)*100 if long_hits else 0:>3.0f}%     "
          f"SHORT (expensive) n={len(short_hits):>3} acc={sum(short_hits)/len(short_hits)*100 if short_hits else 0:>3.0f}%")


def _power_analysis(n, alpha=0.05, beta=0.20):
    """Minimum detectable effect size above 50% chance baseline.

    Uses normal approximation to the binomial: at n observations, a one-sided
    test at alpha rejects H0:p=0.5 when k/n > 0.5 + z_alpha * sqrt(0.25/n).
    The minimum effect with power (1 − beta) is then
       p_alt - 0.5 = (z_alpha + z_beta) * sqrt(0.25/n)
    where z_alpha=1.645 (one-sided α=0.05), z_beta=0.842 (β=0.20).
    """
    from math import sqrt
    z_alpha = 1.645
    z_beta = 0.842
    return 0.5 + (z_alpha + z_beta) * sqrt(0.25 / n)


def main():
    """Run every report section of the head-to-head comparison.

    Prints, in order: the per-signal tercile test, a power analysis for each
    tested signal, the long/short split, and an interpretation guide.
    Returns None, so ``SystemExit(main())`` exits with status 0.
    """
    records = _load_records()
    n_signals = len(SIGNAL_DEFS)
    # Records on which *every* signal is defined. Each signal drops only the
    # records it cannot score, so per-signal n is usually larger than this.
    n_common = sum(
        1
        for rec in records
        if all(extractor(rec) is not None for _, _, extractor in SIGNAL_DEFS)
    )
    print(f"Records (OK-status rows with 5y forward returns): {len(records)}")
    print(f"Records with all {n_signals} signals defined (common sample): {n_common}")

    print()
    print("=" * 110)
    print(f" HEAD-TO-HEAD: {n_signals} value signals on a shared record pool (per-signal n varies)")
    print(" Tercile method: top-1/3 cheap → predict outperform, bottom-1/3 expensive → underperform")
    print(f" Bonferroni correction: p_bonf = min(p_raw × {n_signals}, 1.0)")
    print("=" * 110)
    results = []
    for label, direction, extractor in SIGNAL_DEFS:
        result = _test_signal(records, label, direction, extractor, n_signals=n_signals)
        if result:
            results.append(result)

    print()
    print("=" * 110)
    print(" POWER ANALYSIS (minimum detectable accuracy above chance)")
    print("=" * 110)
    for r in results:
        mde = _power_analysis(r["n"])
        sig_mde = "✓" if r["acc"] >= mde else " "
        print(f"  {r['label']:14}  n={r['n']:>4}  acc={r['acc']*100:>4.1f}%  "
              f"MDE_at_80%_power={mde*100:>4.1f}%  {sig_mde} above MDE")

    print()
    print("=" * 110)
    print(" SAME COMPARISON, split by long-side / short-side")
    print("=" * 110)
    for label, direction, extractor in SIGNAL_DEFS:
        _test_signal_split_sides(records, label, direction, extractor)

    print()
    print("=" * 110)
    print(" INTERPRETATION GUIDE")
    print("=" * 110)
    print(" - Tercile-based signal test is the conventional factor-portfolio framework.")
    print(" - All signals draw on the same record pool, but each skips records where it")
    print("   is undefined (e.g. P/E with NI <= 0), so per-signal n differs; see the")
    print("   common-sample count above before comparing accuracies across signals.")
    print(" - p_raw: naive binomial p-value (treats observations as i.i.d.).")
    print(f" - p_bonf: Bonferroni-corrected for {n_signals} simultaneous tests.")
    print(" - naive CI: per-observation bootstrap (anticonservative; assumes independence).")
    print(" - year-clustered CI: per-year bootstrap (allows for within-year correlation")
    print("   in stock returns; the 5y windows of adjacent years still overlap, so it is")
    print("   the better interval here, not an exact one).")
    print(" - MDE = true accuracy detectable with 80% power at α=0.05 (one-sided).")
    print("   acc < MDE does not show a signal is useless: at this sample size the test")
    print("   is underpowered to detect true accuracies below the MDE.")
    print(" - Practitioner takeaway: a signal that survives Bonferroni AND has a")
    print("   year-clustered CI excluding 50% is robust to standard reviewer objections.")


if __name__ == "__main__":
    # sys.exit(x) raises SystemExit(x); main() returns None, i.e. exit status 0.
    sys.exit(main())
