#!/usr/bin/env python3
"""Phase-2 controlled simulator for RBI/GEE.

This is a [C-controlled] testbed, not a constructed scoring toy. It has:
- a hidden cluster/action environment with stochastic outcome receipts;
- a frozen initial real-contact ledger;
- a posterior world model derived only from the ledger;
- a no-new-contact imagination arm that can raise imagined score but cannot
  update the evidence frontier;
- equal real-contact budget allocation arms;
- a disjoint evaluation distribution over clusters.

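Selection policies see only the current ledger posterior (means and standard
deviations), decision-margin proxies derived from it, and evaluation-cluster
mass. Hidden true probabilities are used only to sample contact outcomes (the
initial ledger and each budgeted contact), to report realized utility, and in
the reporting-only VoRC correlation diagnostic.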
"""

from __future__ import annotations

import csv
import json
import math
from dataclasses import dataclass
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np


# Directory containing this script; every output artifact is written here.
ROOT = Path(__file__).resolve().parent

# World seeds are BASE_SEED .. BASE_SEED + N_SEEDS - 1; each defines one
# hidden world plus its initial ledger.
BASE_SEED = 20260606
N_SEEDS = 320
# Recorded in the metrics JSON config so results can be tied to a formula.
FORMULA_VERSION = "phase2_controlled_hidden_env_v2_uncertainty_power"
N_CLUSTERS = 12
N_ACTIONS = 5
# Poisson mean for the number of initial real contacts per (cluster, action)
# pair; every pair receives at least one contact.
INIT_CONTACTS_PER_PAIR_MEAN = 3.0
# Real contacts granted to every allocation arm (equal budget across arms).
CONTACT_BUDGET = 180
# Tolerance for the no-new-contact frontier-lift assertion; also drawn as the
# dashed band in the no-contact panel of the figure.
NO_CONTACT_EPSILON = 0.03
# z multiplier for the posterior lower confidence bound (mean - z * std);
# 1.64 is roughly the one-sided 95% normal quantile.
LCB_Z = 1.64
# Exponent applied to the posterior std in the GEE contact score.
GEE_UNCERTAINTY_POWER = 1.5
# Floor of the GEE-over-strongest-baseline threshold checked in main().
PHASE2_CONTROLLED_MARGIN = 0.03
# Fixed 0.05 margin evaluated only as an audit; the claim boundary in main()
# records it as not claimed.
HARD_MARGIN_AUDIT = 0.05

# Allocation arms. Order matters: run_seed uses each arm's index when deriving
# that arm's RNG seed, and "GEE" is compared pairwise against all the others.
METHODS = ["random", "coverage", "risk_only", "relevance_only", "GEE"]


@dataclass
class World:
    """Hidden environment for one seed.

    Selection code only ever reads ``cluster_mass``; ``p_true`` is reserved for
    sampling contact outcomes and scoring realized utility.
    """

    p_true: np.ndarray  # true success probability per (cluster, action)
    cluster_mass: np.ndarray  # evaluation weight of each cluster


@dataclass
class Posterior:
    """Beta posterior over every (cluster, action) success probability.

    Both arrays start at 1 (a uniform Beta(1, 1) prior) and are incremented
    in place by each real contact's success or failure.
    """

    alpha: np.ndarray  # prior 1 + observed successes, per pair
    beta: np.ndarray  # prior 1 + observed failures, per pair


def sigmoid(x: np.ndarray) -> np.ndarray:
    """Elementwise logistic function ``1 / (1 + exp(-x))``."""
    denominator = 1.0 + np.exp(-x)
    return 1.0 / denominator


def make_world(rng: np.random.Generator) -> World:
    """Draw one hidden environment from ``rng``.

    True success probabilities are ``sigmoid(cluster effect + action effect +
    residual)``. In the four highest-mass clusters one "winner" action gets a
    positive residual boost and a different "challenger" action a penalty.
    RNG draws happen in a fixed order (mass, cluster effects, action effects,
    residuals, then per-cluster winner/challenger draws) so a seed fully
    determines the world.
    """
    mass = rng.dirichlet(np.linspace(2.8, 0.7, N_CLUSTERS))
    per_cluster = rng.normal(0.0, 0.35, size=N_CLUSTERS)
    per_action = rng.normal(0.0, 0.45, size=N_ACTIONS)
    noise = rng.normal(0.0, 0.18, size=(N_CLUSTERS, N_ACTIONS))

    # High-mass clusters carry decision-relevant residuals: this is the
    # setting VoRC targets, where uncertainty matters only where it can change
    # a consequential decision.
    for cluster in np.argsort(mass)[-4:]:
        winner = rng.integers(0, N_ACTIONS)
        # A nonzero offset modulo N_ACTIONS guarantees challenger != winner.
        challenger = (winner + rng.integers(1, N_ACTIONS)) % N_ACTIONS
        noise[cluster, winner] += rng.uniform(0.65, 1.05)
        noise[cluster, challenger] -= rng.uniform(0.35, 0.70)

    logits = per_cluster[:, None] + per_action[None, :] + noise
    return World(p_true=sigmoid(logits), cluster_mass=mass)


def initial_ledger(rng: np.random.Generator, world: World) -> Posterior:
    """Build the frozen starting ledger from sparse real contacts.

    Every pair starts at Beta(1, 1). Pairs are visited in row-major order; each
    receives ``max(1, Poisson(INIT_CONTACTS_PER_PAIR_MEAN))`` Bernoulli draws
    at its true probability, folded into the success/failure counts. The
    ledger is deliberately not optimized and leaves many high-value pairs
    under-certified.
    """
    shape = (N_CLUSTERS, N_ACTIONS)
    wins = np.ones(shape)
    losses = np.ones(shape)
    # np.ndindex walks (cluster, action) with the action index fastest, i.e.
    # the same order as nested cluster/action loops, so RNG draws line up.
    for cluster, action in np.ndindex(*shape):
        n_contacts = max(1, rng.poisson(INIT_CONTACTS_PER_PAIR_MEAN))
        outcomes = rng.binomial(1, world.p_true[cluster, action], size=n_contacts)
        hits = outcomes.sum()
        wins[cluster, action] += hits
        losses[cluster, action] += n_contacts - hits
    return Posterior(alpha=wins, beta=losses)


def posterior_stats(post: Posterior) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
    """Return Beta posterior ``(mean, std, lcb)`` arrays for every pair.

    ``lcb`` is ``mean - LCB_Z * std`` clipped into ``[0, 1]``.
    """
    a, b = post.alpha, post.beta
    n = a + b
    mean = a / n
    # Beta variance: a * b / ((a + b)^2 * (a + b + 1)).
    std = np.sqrt(a * b / (n * n * (n + 1.0)))
    lcb = np.clip(mean - LCB_Z * std, 0.0, 1.0)
    return mean, std, lcb


def evaluate(post: Posterior, world: World) -> dict[str, float]:
    """Score the policies implied by ``post`` as cluster-mass-weighted sums.

    Returns:
        ``U_imag``: posterior-mean value of the argmax-mean policy.
        ``U_real``: true (hidden ``p_true``) value of that same policy.
        ``F_lower``: LCB value of the argmax-LCB policy (certified frontier).
        ``U_imag_upper``: optimistic value (mean + 1.5 std, clipped to
        ``[0, 1]``) of the argmax-optimistic policy.
        ``U_real_upper_policy``: true value of the optimistic policy.
    """
    mean, std, lcb = posterior_stats(post)
    optimistic = np.clip(mean + 1.5 * std, 0.0, 1.0)
    rows = np.arange(N_CLUSTERS)
    weights = world.cluster_mass

    def weighted(values: np.ndarray, picks: np.ndarray) -> float:
        # One chosen action per cluster, weighted by cluster mass.
        return float(np.sum(weights * values[rows, picks]))

    greedy = mean.argmax(axis=1)
    certified = lcb.argmax(axis=1)
    hopeful = optimistic.argmax(axis=1)
    return {
        "U_imag": weighted(mean, greedy),
        "U_real": weighted(world.p_true, greedy),
        "F_lower": weighted(lcb, certified),
        "U_imag_upper": weighted(optimistic, hopeful),
        "U_real_upper_policy": weighted(world.p_true, hopeful),
    }


def decision_relevance(mean: np.ndarray, cluster_mass: np.ndarray) -> np.ndarray:
    """Weight each (cluster, action) estimate by how much it matters to the decision.

    Per cluster (row of ``mean``):
    - the current best action gets ``cluster_mass[c] / (0.025 + margin)``, where
      ``margin`` is its lead over the runner-up;
    - the runner-up gets ``0.75`` times that;
    - every other action gets ``0.15 * cluster_mass[c] / (0.08 + gap)``, with
      ``gap`` its distance below the best.
    Small margins mean high relevance, because extra evidence there is more
    likely to flip the chosen action.

    Args:
        mean: 2-D posterior-mean matrix, one row per cluster and one column
            per action; at least two actions are required.
        cluster_mass: per-cluster weight, indexed by row of ``mean``.

    Returns:
        Array with the same shape (and dtype) as ``mean``.

    Raises:
        ValueError: if ``mean`` is not 2-D or has fewer than two actions.
    """
    if mean.ndim != 2 or mean.shape[1] < 2:
        raise ValueError(
            f"mean must be 2-D with at least two actions, got shape {mean.shape}"
        )
    relevance = np.zeros_like(mean)
    # Iterate over the rows actually supplied rather than the module-wide
    # cluster count, so the function is self-consistent for any input shape.
    for c in range(mean.shape[0]):
        order = np.argsort(mean[c])
        best = order[-1]
        second = order[-2]
        margin = mean[c, best] - mean[c, second]
        base = cluster_mass[c] / (0.025 + margin)
        relevance[c, best] = base
        relevance[c, second] = 0.75 * base
        for a in order[:-2]:
            relevance[c, a] = 0.15 * cluster_mass[c] / (0.08 + abs(mean[c, best] - mean[c, a]))
    return relevance


def score_matrix(rng: np.random.Generator, method: str, post: Posterior, world: World) -> np.ndarray:
    """Return the (cluster, action) contact-priority matrix for ``method``.

    - ``"random"``: fresh uniform noise drawn from ``rng``.
    - ``"risk_only"`` and ``"coverage"``: posterior std. Coverage's cluster
      rotation is applied afterwards in pick_contact.
    - ``"relevance_only"``: decision_relevance of the posterior mean.
    - ``"GEE"``: ``std ** GEE_UNCERTAINTY_POWER * relevance``.

    Raises:
        ValueError: for an unknown method name.
    """
    mean, std, _ = posterior_stats(post)
    rel = decision_relevance(mean, world.cluster_mass)
    builders = {
        "random": lambda: rng.random((N_CLUSTERS, N_ACTIONS)),
        "risk_only": lambda: std,
        "relevance_only": lambda: rel,
        "GEE": lambda: (std ** GEE_UNCERTAINTY_POWER) * rel,
        "coverage": lambda: std,
    }
    build = builders.get(method)
    if build is None:
        raise ValueError(method)
    return build()


def pick_contact(rng: np.random.Generator, method: str, post: Posterior, world: World, coverage_counts: np.ndarray) -> tuple[int, int]:
    """Choose the next ``(cluster, action)`` to contact under ``method``.

    Non-coverage methods take the global argmax of their score matrix (first
    maximum in row-major order on ties). Coverage first picks the cluster with
    the largest ``cluster_mass / (1 + visits)`` and then that cluster's
    highest-std action.
    """
    score = score_matrix(rng, method, post, world)
    if method != "coverage":
        return divmod(int(np.argmax(score)), N_ACTIONS)
    pressure = world.cluster_mass / (1.0 + coverage_counts)
    cluster = int(np.argmax(pressure))
    return cluster, int(np.argmax(score[cluster]))


def run_allocation(seed: int, method: str, base_post: Posterior, world: World) -> tuple[Posterior, dict[str, float]]:
    """Spend CONTACT_BUDGET real contacts under ``method``, starting from a copy of ``base_post``.

    Each step picks a pair, samples a Bernoulli outcome at its true
    probability, and updates the copied ledger. ``base_post`` itself is never
    modified.

    Returns:
        The updated ledger and its ``evaluate`` metrics.
    """
    rng = np.random.default_rng(seed)
    ledger = Posterior(alpha=base_post.alpha.copy(), beta=base_post.beta.copy())
    visits = np.zeros(N_CLUSTERS)
    for _step in range(CONTACT_BUDGET):
        cluster, action = pick_contact(rng, method, ledger, world, visits)
        success = rng.binomial(1, world.p_true[cluster, action])
        ledger.alpha[cluster, action] += success
        ledger.beta[cluster, action] += 1 - success
        visits[cluster] += 1
    return ledger, evaluate(ledger, world)


def allocation_seed(seed: int, method_index: int) -> int:
    """Return the RNG seed for allocation arm ``method_index`` of world ``seed``.

    The previous scheme, ``seed + 1000 + method_index``, made arm ``i`` of
    world ``s`` reuse the exact outcome stream of arm ``i - 1`` of world
    ``s + 1``, because both equal ``s + 1000 + i``. With consecutive world
    seeds that couples rows that ``stats`` treats as independent samples.
    Hashing the pair through ``SeedSequence`` gives a distinct, reproducible
    seed per (world seed, arm). Numbers therefore differ from runs made with
    the old scheme.
    """
    seq = np.random.SeedSequence([seed, 1000, method_index])
    return int(seq.generate_state(1, dtype=np.uint64)[0])


def run_seed(seed: int) -> dict:
    """Simulate one hidden world and every allocation arm against it.

    Draws the world and its initial real-contact ledger from ``seed``, scores
    the no-new-contact imagination arm, runs each method in METHODS with the
    same contact budget from the same initial ledger, and computes the VoRC
    diagnostic correlation on the initial ledger.

    Returns:
        Dict with ``seed``, ``base`` (evaluation of the initial ledger),
        ``no_new_contact``, ``methods`` (per-method lifts) and ``vorc_corr``.
    """
    rng = np.random.default_rng(seed)
    world = make_world(rng)
    base_post = initial_ledger(rng, world)
    base_eval = evaluate(base_post, world)

    # No-new-contact imagination: the loop can act on upper-confidence imagined
    # scores, but it does not update the real-contact ledger, so the certified
    # frontier lift is zero by construction.
    no_contact = {
        "imagined_gain": base_eval["U_imag_upper"] - base_eval["U_imag"],
        "real_gain": base_eval["U_real_upper_policy"] - base_eval["U_real"],
        "frontier_lift": 0.0,
    }

    methods = {}
    for i, method in enumerate(METHODS):
        _, ev = run_allocation(allocation_seed(seed, i), method, base_post, world)
        methods[method] = {
            "frontier_lift": ev["F_lower"] - base_eval["F_lower"],
            "U_real_lift": ev["U_real"] - base_eval["U_real"],
            "U_imag_lift": ev["U_imag"] - base_eval["U_imag"],
            "frontier_after": ev["F_lower"],
            "contacts": CONTACT_BUDGET,
        }

    # VoRC diagnostic: correlation between the (unpowered) uncertainty x
    # relevance score and the true absolute posterior error. It reads hidden
    # p_true, so it is reported only and never feeds selection.
    mean, std, _ = posterior_stats(base_post)
    rel = decision_relevance(mean, world.cluster_mass)
    score = (std * rel).ravel()
    true_gap = np.abs(world.p_true - mean).ravel()
    if np.std(score) > 0 and np.std(true_gap) > 0:
        vorc_corr = float(np.corrcoef(score, true_gap)[0, 1])
    else:
        vorc_corr = 0.0

    return {
        "seed": seed,
        "base": base_eval,
        "no_new_contact": no_contact,
        "methods": methods,
        "vorc_corr": vorc_corr,
    }


def stats(values: list[float]) -> dict[str, float]:
    """Summarize ``values`` by mean, standard error and a normal 95% CI.

    The SEM uses the sample standard deviation (ddof=1) and is defined as 0.0
    when fewer than two values are given.
    """
    data = np.asarray(values, dtype=float)
    count = len(data)
    center = float(data.mean())
    if count > 1:
        sem = float(data.std(ddof=1) / math.sqrt(count))
    else:
        sem = 0.0
    half_width = 1.96 * sem
    return {
        "mean": center,
        "sem": sem,
        "ci95_low": center - half_width,
        "ci95_high": center + half_width,
    }


def summarize(rows: list[dict]) -> dict:
    """Aggregate per-seed results into summary statistics.

    Returns a dict with:
    - ``no_new_contact``: stats for each no-contact quantity;
    - ``methods``: stats for each method's frontier/real/imagined lifts;
    - ``pairwise``: stats of per-seed ``GEE - method`` frontier-lift
      differences, plus ``win_rate`` (fraction of seeds where GEE is strictly
      ahead), for every non-GEE method;
    - ``vorc_corr``: stats of the per-seed VoRC correlation.
    """
    def column(method: str, metric: str) -> list[float]:
        return [r["methods"][method][metric] for r in rows]

    no_contact_keys = ("imagined_gain", "real_gain", "frontier_lift")
    method_metrics = ("frontier_lift", "U_real_lift", "U_imag_lift")

    no_contact = {key: stats([r["no_new_contact"][key] for r in rows]) for key in no_contact_keys}
    vorc = stats([r["vorc_corr"] for r in rows])
    per_method = {
        method: {metric: stats(column(method, metric)) for metric in method_metrics}
        for method in METHODS
    }

    gee_lift = np.asarray(column("GEE", "frontier_lift"))
    pairwise = {}
    for method in (m for m in METHODS if m != "GEE"):
        delta = gee_lift - np.asarray(column(method, "frontier_lift"))
        entry = stats(delta.tolist())
        entry["win_rate"] = float(np.mean(delta > 0))
        pairwise[f"GEE_minus_{method}"] = entry

    return {
        "no_new_contact": no_contact,
        "methods": per_method,
        "pairwise": pairwise,
        "vorc_corr": vorc,
    }


def write_csv(path: Path, rows: list[dict]) -> None:
    """Write ``rows`` to ``path`` as UTF-8 CSV with a header row.

    Columns are the union of all row keys in first-seen order. Rows with
    differing keys are written with missing cells left empty, instead of
    making ``csv.DictWriter`` raise. An empty ``rows`` list produces an empty
    file rather than an ``IndexError``.
    """
    # dict.fromkeys de-duplicates while preserving first-seen order.
    fieldnames = list(dict.fromkeys(key for row in rows for key in row))
    with path.open("w", newline="", encoding="utf-8") as f:
        if not rows:
            # No rows means no header can be derived; leave the file empty.
            return
        writer = csv.DictWriter(f, fieldnames=fieldnames)
        writer.writeheader()
        writer.writerows(rows)


def flatten_seed_rows(rows: list[dict]) -> list[dict]:
    """Expand per-seed results into one flat record per (seed, method).

    Records are ordered by seed, then by METHODS order, and carry every
    per-method metric alongside ``seed`` and ``method``.
    """
    return [
        {"seed": row["seed"], "method": method, **row["methods"][method]}
        for row in rows
        for method in METHODS
    ]


def _baseline_methods() -> list[str]:
    """Return every non-GEE comparison arm, in METHODS order."""
    return [m for m in METHODS if m != "GEE"]


def _build_assertions(summary: dict, gee_margin: float, prereg_threshold: float) -> dict[str, bool]:
    """Evaluate the audit assertions recorded in the metrics JSON.

    ``gee_beats_<method>`` is generated for every baseline arm, so the audit
    stays in sync with METHODS (and with the pairwise entries summarize
    builds) instead of relying on a hand-maintained list.
    """
    no_contact = summary["no_new_contact"]
    assertions = {
        "controlled_hidden_environment": True,
        # score_matrix / pick_contact read only posterior stats and
        # world.cluster_mass, never world.p_true.
        "selection_uses_hidden_outcomes": False,
        "no_new_contact_frontier_lift_within_epsilon": abs(no_contact["frontier_lift"]["mean"]) <= NO_CONTACT_EPSILON,
        "imagined_gain_exceeds_real_gain_no_contact": no_contact["imagined_gain"]["mean"] > no_contact["real_gain"]["mean"],
    }
    for method in _baseline_methods():
        assertions[f"gee_beats_{method}"] = summary["pairwise"][f"GEE_minus_{method}"]["ci95_low"] > 0
    assertions["gee_exceeds_strongest_baseline_threshold"] = gee_margin >= prereg_threshold
    # The next two are complementary for any finite margin: exactly one holds.
    assertions["gee_exceeds_hard_0p05_margin"] = gee_margin >= HARD_MARGIN_AUDIT
    assertions["hard_0p05_margin_not_claimed"] = gee_margin < HARD_MARGIN_AUDIT
    assertions["vorc_positive_mean_correlation"] = summary["vorc_corr"]["mean"] > 0
    return assertions


def _build_metrics(summary: dict, assertions: dict[str, bool], prereg_threshold: float) -> dict:
    """Assemble the metrics JSON payload (tag, config, claim boundary, results)."""
    return {
        "evidence_tag": "[C-controlled]",
        "config": {
            "formula_version": FORMULA_VERSION,
            "base_seed": BASE_SEED,
            "n_seeds": N_SEEDS,
            "n_clusters": N_CLUSTERS,
            "n_actions": N_ACTIONS,
            "initial_contacts_per_pair_mean": INIT_CONTACTS_PER_PAIR_MEAN,
            "contact_budget": CONTACT_BUDGET,
            "no_contact_epsilon": NO_CONTACT_EPSILON,
            "gee_uncertainty_power": GEE_UNCERTAINTY_POWER,
            "phase2_controlled_margin": PHASE2_CONTROLLED_MARGIN,
            "hard_margin_audit": HARD_MARGIN_AUDIT,
            "prereg_threshold_used_for_controlled_sim": prereg_threshold,
        },
        "claim_boundary": {
            "controlled_simulator_claimed": True,
            "hard_0p05_margin_claimed": False,
            "public_data_claimed": False,
            "real_robot_deployment_claimed": False,
            "third_party_replication_claimed": False,
        },
        "summary": summary,
        "assertions": assertions,
    }


def _write_outputs(rows: list[dict], metrics: dict) -> None:
    """Write the metrics JSON and the per-seed CSV files under ROOT."""
    metrics_text = json.dumps(metrics, indent=2, sort_keys=True) + "\n"
    # Explicit encoding, matching write_csv, instead of the platform default.
    (ROOT / "controlled_simulator_metrics.json").write_text(metrics_text, encoding="utf-8")
    write_csv(ROOT / "controlled_simulator_seed_results.csv", flatten_seed_rows(rows))
    no_contact_rows = [{"seed": r["seed"], **r["no_new_contact"], "vorc_corr": r["vorc_corr"]} for r in rows]
    write_csv(ROOT / "controlled_simulator_no_contact.csv", no_contact_rows)


def _plot_summary(summary: dict) -> None:
    """Save the two-panel figure: per-method frontier lift and the no-contact arm."""
    methods = METHODS
    lifts = [summary["methods"][m]["frontier_lift"]["mean"] for m in methods]
    # Error bars are 95% normal intervals (1.96 * SEM).
    err = [summary["methods"][m]["frontier_lift"]["sem"] * 1.96 for m in methods]
    fig, axes = plt.subplots(1, 2, figsize=(10.5, 4.0))
    axes[0].bar(methods, lifts, yerr=err)
    axes[0].set_title("C-controlled: Equal Real-Contact Budget")
    axes[0].set_ylabel("frontier lift")
    axes[0].tick_params(axis="x", rotation=35, labelsize=8)
    nc = summary["no_new_contact"]
    axes[1].bar(["imagined", "real", "frontier"], [nc["imagined_gain"]["mean"], nc["real_gain"]["mean"], nc["frontier_lift"]["mean"]])
    axes[1].axhline(NO_CONTACT_EPSILON, color="tab:red", linestyle="--", linewidth=1)
    axes[1].axhline(-NO_CONTACT_EPSILON, color="tab:red", linestyle="--", linewidth=1)
    axes[1].set_title("No New Contact")
    axes[1].set_ylabel("gain / lift")
    fig.tight_layout()
    fig.savefig(ROOT / "fig_controlled_simulator.png", dpi=180)
    plt.close(fig)


def _print_report(
    seeds: list[int],
    summary: dict,
    strongest_non_gee: float,
    gee_lift: float,
    prereg_threshold: float,
    assertions: dict[str, bool],
) -> None:
    """Print the human-readable run summary to stdout."""
    print("RBI/GEE phase-2 controlled simulator")
    print(f"seeds: {seeds[0]}..{seeds[-1]} n={N_SEEDS}")
    print("no-new-contact:", json.dumps(summary["no_new_contact"], sort_keys=True))
    for method in METHODS:
        s = summary["methods"][method]["frontier_lift"]
        print(f"{method:15s} frontier_lift={s['mean']:.4f} [{s['ci95_low']:.4f},{s['ci95_high']:.4f}]")
    print(f"strongest_non_gee={strongest_non_gee:.4f} gee={gee_lift:.4f} threshold={prereg_threshold:.4f}")
    print("assertions:", json.dumps(assertions, sort_keys=True))


def main() -> None:
    """Run every seed, write JSON/CSV/PNG artifacts under ROOT, and print a report."""
    seeds = list(range(BASE_SEED, BASE_SEED + N_SEEDS))
    rows = [run_seed(seed) for seed in seeds]
    summary = summarize(rows)
    # Baselines are derived from METHODS so a newly added arm is included.
    strongest_non_gee = max(
        summary["methods"][m]["frontier_lift"]["mean"] for m in _baseline_methods()
    )
    gee_lift = summary["methods"]["GEE"]["frontier_lift"]["mean"]
    # GEE must clear at least PHASE2_CONTROLLED_MARGIN, or 10% of the
    # strongest baseline's lift when that is larger.
    prereg_threshold = max(PHASE2_CONTROLLED_MARGIN, 0.10 * abs(strongest_non_gee))
    gee_margin = gee_lift - strongest_non_gee
    assertions = _build_assertions(summary, gee_margin, prereg_threshold)
    metrics = _build_metrics(summary, assertions, prereg_threshold)
    _write_outputs(rows, metrics)
    _plot_summary(summary)
    _print_report(seeds, summary, strongest_non_gee, gee_lift, prereg_threshold, assertions)


if __name__ == "__main__":
    # Run the full simulation only when executed as a script, not on import.
    main()
