#!/usr/bin/env python3
# -*- coding: utf-8 -*-

"""
scc_export.py
- Build empirical residue graph G_M for odd-only map T_b(n) = (3n+b)/2^{v2(3n+b)}
- Find dominant SCC by INTERNAL EDGE WEIGHT (sum of weights inside SCC)
- Report:
    dom_size: number of residues in dominant SCC
    coverage: dom_size / |V_M|  (V_M = odd residues mod M)
    dom_mass: internal_weight(dominant SCC) / total_weight(G_M)
    lift_cover: |proj(domSCC_M') ∩ domSCC_M| / |domSCC_M|
    jaccard:    |proj(domSCC_M') ∩ domSCC_M| / |proj(domSCC_M') ∪ domSCC_M|
- Default protocol: sample n uniformly from the odd integers in [1, NMAX], using a fixed seed.
"""

import argparse
import math
import random
from dataclasses import dataclass
from typing import Dict, Iterable, List, Optional, Set, Tuple

import networkx as nx


def v2(x: int) -> int:
    """Return the 2-adic valuation of ``x``.

    This is the largest ``k`` such that ``2**k`` divides ``x``; for negative
    ``x`` the result equals that of ``abs(x)``.

    Args:
        x: Any non-zero integer.

    Returns:
        The number of trailing zero bits of ``x``.

    Raises:
        ValueError: If ``x`` is zero. Zero is divisible by every power of
            two, so the valuation is undefined (a shift-until-odd loop would
            never terminate on it).
    """
    if x == 0:
        raise ValueError("v2(0) is undefined")
    # ``x & -x`` isolates the lowest set bit (two's-complement identity, valid
    # for Python's arbitrary-precision ints); its bit position is the answer.
    return (x & -x).bit_length() - 1


def odd_only_step(n: int, b: int) -> Tuple[int, int]:
    """Apply one step of the odd-only map ``T_b(n) = (3n + b) / 2**v2(3n + b)``.

    Args:
        n: Current value; must be odd.
        b: Additive offset of the map.

    Returns:
        A pair ``(next_value, k)`` where ``k = v2(3n + b)`` is the number of
        factors of two removed and ``next_value = (3n + b) >> k``.

    Raises:
        ValueError: If ``n`` is even, or if ``3n + b`` is zero (the 2-adic
            valuation of zero is undefined, so the step has no image).
    """
    if n % 2 == 0:
        raise ValueError("n must be odd")
    num = 3 * n + b
    if num == 0:
        # Reachable with negative offsets (e.g. n=1, b=-3); fail fast rather
        # than handing zero to the valuation routine.
        raise ValueError("3*n + b is zero; the odd-only step is undefined")
    k = v2(num)
    return num >> k, k


def build_residue_graph(M: int, b: int, samples: int, nmax: int, seed: int) -> nx.DiGraph:
    """Build the empirical residue transition graph for the odd-only map.

    Draws ``samples`` odd integers ``n`` from ``[1, nmax]`` with a
    ``random.Random(seed)`` generator and, for each, records the transition
    ``n mod M -> T_b(n) mod M``.

    Args:
        M: Modulus used to reduce both endpoints of every transition.
        b: Offset passed to ``odd_only_step``.
        samples: Number of random draws.
        nmax: Inclusive upper bound of the sampling range.
        seed: Seed for the private random generator.

    Returns:
        A directed graph whose initial node set is every odd ``r`` in
        ``range(M)`` and whose edge attribute ``weight`` counts how many
        sampled transitions landed on that edge.
    """
    graph = nx.DiGraph()
    # Pre-populate the vertex set with the odd residues 1, 3, 5, ... below M.
    graph.add_nodes_from(range(1, M, 2))

    sampler = random.Random(seed)
    # Tally transitions first; dict insertion order keeps edges in first-seen
    # order when they are inserted into the graph below.
    tally: Dict[Tuple[int, int], int] = {}
    for _ in range(samples):
        # Step 2 from 1 keeps every draw odd.
        n = sampler.randrange(1, nmax + 1, 2)
        src = n % M
        image, _shift = odd_only_step(n, b)
        key = (src, image % M)
        tally[key] = tally.get(key, 0) + 1

    for (src, dst), hits in tally.items():
        graph.add_edge(src, dst, weight=hits)
    return graph


def total_weight(G: nx.DiGraph) -> int:
    """Return the sum of the ``weight`` attribute over all edges of ``G``.

    Edges without a ``weight`` attribute contribute 0.
    """
    running = 0
    for _src, _dst, attrs in G.edges(data=True):
        running += attrs.get("weight", 0)
    return running


def internal_weight(G: nx.DiGraph, nodes: Set[int]) -> int:
    """Return the total ``weight`` of edges of ``G`` with both ends in ``nodes``.

    Edges without a ``weight`` attribute contribute 0.
    """
    # The induced subgraph keeps exactly the edges whose endpoints are both
    # members of ``nodes``.
    induced = G.subgraph(nodes)
    return sum(attrs.get("weight", 0) for _, _, attrs in induced.edges(data=True))


def dominant_scc_by_internal_weight(G: nx.DiGraph) -> Tuple[Set[int], int]:
    """Pick the strongly connected component with the largest internal weight.

    Components are scored with ``internal_weight``. On ties, the component
    that ``nx.strongly_connected_components`` yields first wins, because a
    later component must be strictly heavier to replace the current leader.

    Returns:
        ``(nodes, weight)`` for the winning component, or ``(set(), 0)`` when
        no component scores above the initial threshold of -1 (in particular
        when ``G`` has no nodes).
    """
    leader: Set[int] = set()
    leader_weight = -1
    found = False

    for members in nx.strongly_connected_components(G):
        candidate = set(members)
        score = internal_weight(G, candidate)
        if score <= leader_weight:
            continue
        leader, leader_weight, found = candidate, score, True

    if not found:
        return set(), 0
    return leader, leader_weight


@dataclass
class SccMetrics:
    """One summary row describing the dominant SCC for a pair ``(b, M)``."""

    # Offset b of the map (3n + b) / 2^v2(3n + b).
    b: int
    # Modulus of the residue graph.
    M: int
    # Number of residues in the dominant SCC.
    dom_size: int
    # dom_size divided by the number of odd residues mod M (0.0 if there are none).
    coverage: float
    # Internal weight of the dominant SCC divided by the graph's total weight
    # (0.0 if the total is zero).
    dom_mass: float
    # Refinement metrics; left as None unless this row's M was compared against
    # a smaller modulus that divides it. None is written as an empty CSV cell.
    lift_cover: Optional[float] = None
    jaccard: Optional[float] = None


def proj_mod(nodes_mod_Mp: Set[int], M: int) -> Set[int]:
    """Project a set of residues down to modulus ``M``.

    Each element is reduced mod ``M``; elements that collide after reduction
    collapse into a single entry of the returned set.
    """
    projected: Set[int] = set()
    for residue in nodes_mod_Mp:
        projected.add(residue % M)
    return projected


def compute_metrics_for_b(
    b: int,
    Ms: List[int],
    samples: int,
    nmax: int,
    seed: int,
) -> List[SccMetrics]:
    """Compute dominant-SCC metrics for one offset ``b`` across several moduli.

    For each modulus a residue graph is built (every graph uses the same
    ``seed``), its dominant SCC is located, and ``coverage`` / ``dom_mass``
    are derived. Afterwards the distinct moduli are sorted and each adjacent
    pair ``(M, M')`` with ``M' % M == 0`` is treated as a refinement: the
    dominant SCC of ``M'`` is projected mod ``M`` and compared with the
    dominant SCC of ``M``. The resulting ``lift_cover`` and ``jaccard`` are
    stored on the row(s) for ``M'``.

    Args:
        b: Offset of the odd-only map.
        Ms: Moduli to analyse, in the order the rows should be returned.
        samples: Number of random draws per graph.
        nmax: Inclusive upper bound of the sampling range.
        seed: Seed handed to every graph build.

    Returns:
        One ``SccMetrics`` per entry of ``Ms``, in input order. Rows whose
        modulus is not the larger member of a refinement pair keep
        ``lift_cover`` and ``jaccard`` as ``None``.

    Note:
        A modulus repeated in ``Ms`` is computed once and yields identical
        rows; it is never compared against itself as a "refinement".
    """
    # Per-modulus results. Graphs are not retained once their summary
    # numbers have been extracted, so only one graph is alive at a time.
    doms: Dict[int, Set[int]] = {}
    ratios: Dict[int, Tuple[float, float]] = {}
    rows_by_mod: Dict[int, List[SccMetrics]] = {}
    out: List[SccMetrics] = []

    for M in Ms:
        if M not in doms:
            G = build_residue_graph(M, b, samples=samples, nmax=nmax, seed=seed)
            dom, dom_w = dominant_scc_by_internal_weight(G)
            total = total_weight(G)
            # Number of odd r in range(M); zero for M <= 1.
            vsize = max(M, 0) // 2
            coverage = (len(dom) / vsize) if vsize else 0.0
            mass = (dom_w / total) if total else 0.0
            doms[M] = dom
            ratios[M] = (coverage, mass)

        coverage, mass = ratios[M]
        row = SccMetrics(b=b, M=M, dom_size=len(doms[M]), coverage=coverage, dom_mass=mass)
        out.append(row)
        rows_by_mod.setdefault(M, []).append(row)

    # Refinement pass over adjacent *distinct* moduli in ascending order.
    moduli = sorted(doms)
    for M, Mp in zip(moduli, moduli[1:]):
        if Mp % M != 0:
            # Not a refinement: residues mod Mp do not project onto mod M.
            continue

        dom_m = doms[M]
        proj = proj_mod(doms[Mp], M)
        inter = proj & dom_m
        union = proj | dom_m

        lift_cover = (len(inter) / len(dom_m)) if dom_m else 0.0
        jacc = (len(inter) / len(union)) if union else 0.0

        # The metrics describe the refinement result, so they go on Mp's rows.
        for row in rows_by_mod[Mp]:
            row.lift_cover = lift_cover
            row.jaccard = jacc

    return out


def write_csv(path: str, rows: List[SccMetrics]) -> None:
    """Write ``rows`` to ``path`` as UTF-8 CSV with a fixed header line.

    Float columns are rendered with six decimal places; ``lift_cover`` and
    ``jaccard`` become empty cells when they are ``None``. An existing file at
    ``path`` is overwritten.
    """

    def _optional_cell(value: Optional[float]) -> str:
        # Missing refinement metrics are emitted as empty cells.
        if value is None:
            return ""
        return f"{value:.6f}"

    def _data_lines():
        for r in rows:
            lift = _optional_cell(r.lift_cover)
            jac = _optional_cell(r.jaccard)
            yield f"{r.b},{r.M},{r.dom_size},{r.coverage:.6f},{r.dom_mass:.6f},{lift},{jac}\n"

    with open(path, "w", encoding="utf-8") as f:
        f.write("b,M,dom_size,coverage,dom_mass,lift_cover,jaccard\n")
        # The generator is consumed lazily, so rows are formatted as they are written.
        f.writelines(_data_lines())


def main():
    """Command-line entry point.

    Parses the options, computes metrics for every offset in ``--bs`` over the
    moduli in ``--mods``, sorts the rows by ``(b, M)``, writes them to
    ``--out`` and prints a confirmation line.
    """

    def _int_list(text: str) -> List[int]:
        # Comma-separated integers; blank entries (e.g. a trailing comma) are skipped.
        pieces = (piece.strip() for piece in text.split(","))
        return [int(piece) for piece in pieces if piece]

    ap = argparse.ArgumentParser()
    ap.add_argument("--samples", type=int, default=200_000)
    ap.add_argument("--seed", type=int, default=42)
    ap.add_argument("--nmax", type=int, default=10_000_000)
    ap.add_argument("--mods", type=str, default="36,72")
    ap.add_argument("--bs", type=str, default="1,5")
    ap.add_argument("--out", type=str, default="scc_summary.csv")
    args = ap.parse_args()

    # Moduli are parsed before offsets, so a malformed --mods is reported first.
    moduli = _int_list(args.mods)
    offsets = _int_list(args.bs)

    collected: List[SccMetrics] = []
    for offset in offsets:
        per_offset = compute_metrics_for_b(
            b=offset,
            Ms=moduli,
            samples=args.samples,
            nmax=args.nmax,
            seed=args.seed,
        )
        collected.extend(per_offset)

    # Deterministic output order: by offset first, then by modulus.
    collected.sort(key=lambda row: (row.b, row.M))
    write_csv(args.out, collected)
    print(f"[OK] wrote {args.out}")


# Run the CLI only when executed as a script, not when imported as a module.
if __name__ == "__main__":
    main()