#!/usr/bin/env python3
"""Build real-data tables for the Crohn's causal-circuit figures.

Data sources used directly in this script:
1) Open Targets Platform GraphQL API (Crohn's disease EFO_0000384)
2) STRING v12 API for functional coupling between node genes
3) Public CELLxGENE ileal single-cell datasets with Crohn/normal labels:
   - output/data/ti_epithelial_crohns_normal.h5ad
   - output/data/ti_immune_crohns_normal.h5ad
   - output/data/ti_stromal_crohns_normal.h5ad

Outputs:
- output/tables/node_rank_table.csv
- output/tables/edge_evidence_scores.csv
- output/tables/phenotype_mapping_raw_metrics.csv
- output/tables/phenotype_mapping_scores.csv
- output/data/*.csv and output/data/missing_genes_in_scrna.json intermediate artifacts
"""

from __future__ import annotations

import json
import math
from dataclasses import dataclass
from pathlib import Path
from typing import Dict, Iterable, List, Tuple

import anndata as ad
import numpy as np
import pandas as pd
import requests
from scipy import sparse as sp


# Base directory for all inputs/outputs: parents[2] of this file's resolved
# path, i.e. two levels above the directory that contains this script.
ROOT = Path(__file__).resolve().parents[2]
DATA_DIR = ROOT / "data"
TABLE_DIR = ROOT / "tables"

# NOTE: both directories are created as a side effect of importing this module.
DATA_DIR.mkdir(parents=True, exist_ok=True)
TABLE_DIR.mkdir(parents=True, exist_ok=True)

# Remote endpoints and the disease identifier passed to the Open Targets query.
OT_URL = "https://api.platform.opentargets.org/api/v4/graphql"
STRING_URL = "https://string-db.org/api/json/network"
CROHNS_EFO = "EFO_0000384"

# Dataset tag -> local .h5ad file. main() raises FileNotFoundError if any of
# these paths does not exist.
SCRNA_DATASETS = {
    "epithelial": DATA_DIR / "ti_epithelial_crohns_normal.h5ad",
    "immune": DATA_DIR / "ti_immune_crohns_normal.h5ad",
    "stromal": DATA_DIR / "ti_stromal_crohns_normal.h5ad",
}

# Circuit node label -> gene symbols used for the node's Open Targets and
# STRING lookups in main().
NODE_GENES: Dict[str, List[str]] = {
    "NOD2": ["NOD2"],
    "ATG16L1/IRGM": ["ATG16L1", "IRGM"],
    "XBP1": ["XBP1"],
    "IL23R": ["IL23R"],
    "MUC2": ["MUC2"],
}

# Module/signature column name -> gene symbols averaged per cell by
# module_activity(). Note that "il23_module" and "il23_signature" list the
# same genes.
MODULE_GENES = {
    "nod2_module": ["NOD2"],
    "autophagy_module": ["ATG16L1", "IRGM", "SQSTM1", "MAP1LC3B"],
    "upr_module": ["XBP1", "HSPA5", "DDIT3", "ERN1"],
    "il23_module": ["IL23R", "IL17A", "IL22", "RORC"],
    "mucus_module": ["MUC2", "AGR2", "SPDEF", "FCGBP"],
    "fibrosis_signature": ["COL1A1", "COL3A1", "ACTA2", "TAGLN", "TGFB1"],
    "tnf_signature": ["TNF", "TNFRSF1A", "NFKB1", "RELA"],
    "il23_signature": ["IL23R", "IL17A", "IL22", "RORC"],
}

# Directed (upstream, downstream) node pairs scored in the edge table; the
# last edge points back to the first node, so the edges form a cycle.
EDGE_ORDER: List[Tuple[str, str]] = [
    ("NOD2", "ATG16L1/IRGM"),
    ("ATG16L1/IRGM", "XBP1"),
    ("XBP1", "IL23R"),
    ("IL23R", "MUC2"),
    ("MUC2", "NOD2"),
]

# Edge label ("up->down") -> sign of correlation that counts as support; passed
# to directional_support() when scoring disease-state coupling.
EDGE_EXPECTED_SIGN = {
    "NOD2->ATG16L1/IRGM": "positive",
    "ATG16L1/IRGM->XBP1": "positive",
    "XBP1->IL23R": "positive",
    "IL23R->MUC2": "positive",
    "MUC2->NOD2": "negative",  # barrier loss feedback logic
}

# Node label -> MODULE_GENES key whose activity column represents that node
# when edge correlations are computed.
EDGE_MODULE_MAP = {
    "NOD2": "nod2_module",
    "ATG16L1/IRGM": "autophagy_module",
    "XBP1": "upr_module",
    "IL23R": "il23_module",
    "MUC2": "mucus_module",
}

# Edge label -> free-text falsification criterion, copied verbatim into the
# "falsification_test" column of the edge table.
EDGE_FALSIFICATION = {
    "NOD2->ATG16L1/IRGM": "If restoring NOD2 does not recover ATG16L1-dependent xenophagy in isogenic systems.",
    "ATG16L1/IRGM->XBP1": "If autophagy rescue fails to reduce ER-stress signatures in epithelial secretory compartments.",
    "XBP1->IL23R": "If epithelial stress correction does not attenuate IL-23-axis signatures under matched challenge.",
    "IL23R->MUC2": "If sustained IL-23-axis suppression fails to improve mucus-program signatures despite lower inflammatory load.",
    "MUC2->NOD2": "If reducing mucus-program integrity does not increase innate microbial-sensing signatures in matched systems.",
}

# Column names (and order) of the phenotype mapping table; each must match a
# key of the per-node records built in main().
PHENOTYPE_COLS = [
    "Ileal localization",
    "Paneth pathology",
    "Stricturing/fibrosis",
    "Anti-TNF response heterogeneity",
    "IL-23 response enrichment",
]


@dataclass
class OTRow:
    """One associated-target row from Open Targets, as built by fetch_opentargets_rows()."""

    symbol: str  # target.approvedSymbol from the API response
    ensembl_id: str  # target.id from the API response (presumably an Ensembl gene ID - confirm)
    overall: float  # the row's overall association score
    datatypes: Dict[str, float]  # datatypeScores mapped as id -> score
    rank: int  # 1-based position in the order the API returned the rows


def clamp(x: float, lo: float, hi: float) -> float:
    """Restrict ``x`` to the closed interval ``[lo, hi]`` and return it as a float.

    NaN is returned unchanged. The plain ``max(lo, min(hi, x))`` form maps NaN
    to ``hi`` (every comparison with NaN is false), which would turn a missing
    value into the maximum score.

    Args:
        x: Value to clamp.
        lo: Lower bound.
        hi: Upper bound.

    Returns:
        ``x`` limited to ``[lo, hi]``, or NaN when ``x`` is NaN.
    """
    value = float(x)
    if math.isnan(value):
        return value
    return float(max(lo, min(hi, value)))


def minmax_scale(series: pd.Series, max_score: float) -> pd.Series:
    """Linearly rescale ``series`` so its minimum maps to 0 and its maximum to ``max_score``.

    Degenerate inputs:
    - if the minimum or maximum is not finite (e.g. empty or all-NaN input),
      every entry becomes 0.0;
    - if minimum and maximum are (nearly) equal, every entry becomes
      ``max_score / 2``.

    The result keeps the index of the input.
    """
    values = series.astype(float)
    lowest, highest = float(values.min()), float(values.max())

    if not (np.isfinite(lowest) and np.isfinite(highest)):
        return pd.Series(np.zeros(len(values)), index=values.index)

    if math.isclose(lowest, highest):
        midpoint = max_score / 2.0
        return pd.Series(np.full(len(values), midpoint), index=values.index)

    span = highest - lowest
    return (values - lowest) / span * max_score


def safe_spearman(a: pd.Series, b: pd.Series) -> float:
    """Spearman correlation of ``a`` and ``b`` that never returns NaN.

    Only positions where both series are non-null are used. Returns 0.0 when
    fewer than three such positions exist or when the correlation itself is
    undefined.
    """
    paired = a.notna() & b.notna()
    if paired.sum() < 3:
        return 0.0

    rho = a[paired].corr(b[paired], method="spearman")
    if pd.isna(rho):
        return 0.0
    return float(rho)


def directional_support(corr: float, expected_sign: str) -> float:
    """Return the part of ``corr`` that agrees with the expected direction.

    With ``expected_sign == "negative"`` the correlation is negated first; any
    other value is treated as "positive". Correlations pointing the wrong way
    yield 0.0.
    """
    signed = -corr if expected_sign == "negative" else corr
    return max(0.0, signed)


def score_label(total: float) -> str:
    """Map a total edge score to its evidence label.

    ``>= 4.5`` is "Strong", ``>= 2.5`` is "Moderate", everything else is
    "Preliminary".
    """
    # Tiers are checked from the highest floor downwards; first match wins.
    tiers = ((4.5, "Strong"), (2.5, "Moderate"))
    for floor, label in tiers:
        if total >= floor:
            return label
    return "Preliminary"


def fetch_opentargets_rows(efo_id: str, page_size: int = 500, max_pages: int = 30) -> Dict[str, OTRow]:
    """Download the targets associated with a disease from Open Targets.

    Pages through ``disease.associatedTargets`` until a page is empty, a page
    is shorter than ``page_size``, or ``max_pages`` pages have been read.

    Args:
        efo_id: Disease identifier passed as the ``efoId`` query variable.
        page_size: Number of rows requested per page.
        max_pages: Upper bound on the number of pages requested.

    Returns:
        Mapping of approved gene symbol to its ``OTRow``. ``rank`` is the
        1-based position of the row in the order returned by the API. If a
        symbol appears more than once, the later row replaces the earlier one.

    Raises:
        requests.HTTPError: If a request returns an HTTP error status.
        RuntimeError: If a page contains no rows and the response reports
            GraphQL errors.
    """
    query = """
    query($efoId:String!,$index:Int!,$size:Int!){
      disease(efoId:$efoId){
        associatedTargets(page:{index:$index,size:$size}){
          rows{
            score
            target{approvedSymbol id}
            datatypeScores{id score}
          }
        }
      }
    }
    """

    out: Dict[str, OTRow] = {}
    rank = 1
    for page in range(max_pages):
        payload = {
            "query": query,
            "variables": {"efoId": efo_id, "index": page, "size": page_size},
        }
        resp = requests.post(OT_URL, json=payload, timeout=60)
        resp.raise_for_status()
        body = resp.json()

        # Any level of the response may be an explicit JSON null (for example
        # "disease": null), so a plain .get(key, {}) default is not enough.
        data = body.get("data") or {}
        disease = data.get("disease") or {}
        associated = disease.get("associatedTargets") or {}
        rows = associated.get("rows") or []
        if not rows:
            errors = body.get("errors")
            if errors:
                raise RuntimeError(f"Open Targets query for {efo_id} failed: {errors}")
            break

        for row in rows:
            symbol = row["target"]["approvedSymbol"]
            datatypes = {d["id"]: float(d["score"]) for d in row.get("datatypeScores") or []}
            out[symbol] = OTRow(
                symbol=symbol,
                ensembl_id=row["target"]["id"],
                overall=float(row["score"]),
                datatypes=datatypes,
                rank=rank,
            )
            rank += 1

        # A short page means the API has no further rows to return.
        if len(rows) < page_size:
            break

    return out


def save_ot_snapshot(rows: Dict[str, OTRow], out_path: Path) -> None:
    """Write the Open Targets rows to ``out_path`` as CSV, sorted by rank.

    Each row's datatype scores are stored as a JSON object with sorted keys in
    the ``datatype_scores_json`` column. An empty ``rows`` mapping produces a
    header-only file.

    Args:
        rows: Mapping of gene symbol to row; only the values are used.
        out_path: Destination CSV path.
    """
    # Fixing the columns up front keeps sort_values("rank") valid even when
    # there are no records at all.
    columns = ["symbol", "ensembl_id", "rank", "overall_score", "datatype_scores_json"]
    records = [
        {
            "symbol": row.symbol,
            "ensembl_id": row.ensembl_id,
            "rank": row.rank,
            "overall_score": row.overall,
            "datatype_scores_json": json.dumps(row.datatypes, sort_keys=True),
        }
        for row in rows.values()
    ]
    pd.DataFrame(records, columns=columns).sort_values("rank").to_csv(out_path, index=False)


def fetch_string_score(gene_a: str, gene_b: str, cache: Dict[Tuple[str, str], float]) -> float:
    """Return the STRING network score for a gene pair, using and filling ``cache``.

    The cache key is the alphabetically sorted pair, so the lookup is
    order-independent. A gene paired with itself scores 1.0 without a request.
    If the response has no row whose preferred names match the two genes, the
    highest score among the returned rows is used instead.

    This is best-effort: a failed request or an unusable response emits a
    ``RuntimeWarning`` and the pair is scored (and cached) as 0.0.

    Args:
        gene_a: First gene symbol.
        gene_b: Second gene symbol.
        cache: Mutable mapping updated in place with the result.

    Returns:
        The score for the pair, or 0.0 when none could be obtained.
    """
    import warnings

    key = tuple(sorted((gene_a, gene_b)))
    if key in cache:
        return cache[key]

    if gene_a == gene_b:
        cache[key] = 1.0
        return 1.0

    params = {
        "identifiers": f"{gene_a}%0d{gene_b}",
        "species": 9606,
        "required_score": 0,
    }
    score = 0.0
    try:
        resp = requests.get(STRING_URL, params=params, timeout=30)
        resp.raise_for_status()
        payload = resp.json()
        # Anything other than a list of objects (e.g. an error object) is
        # treated as "no interactions returned".
        rows = [r for r in payload if isinstance(r, dict)] if isinstance(payload, list) else []
        wanted = {gene_a, gene_b}
        for row in rows:
            if {row.get("preferredName_A"), row.get("preferredName_B")} == wanted:
                score = float(row.get("score", 0.0))
                break
        # No exact name match: fall back to the best score in the response.
        if score == 0.0 and rows:
            score = float(max(r.get("score", 0.0) for r in rows))
    except (requests.RequestException, ValueError, TypeError) as exc:
        warnings.warn(
            f"STRING lookup failed for {gene_a}/{gene_b}; using score 0.0 ({exc})",
            RuntimeWarning,
            stacklevel=2,
        )
        score = 0.0

    cache[key] = score
    return score


def module_activity(genes: Iterable[str], vectors: Dict[str, np.ndarray], n_obs: int) -> np.ndarray:
    """Average the per-cell vectors of the listed genes, ignoring NaNs.

    Genes without an entry in ``vectors`` are skipped. When none of the genes
    is available the result is an all-NaN array of length ``n_obs``.
    """
    present = [vectors[name] for name in genes if name in vectors]
    if present:
        stacked = np.vstack(present)
        return np.nanmean(stacked, axis=0)
    return np.full(n_obs, np.nan)


def read_gene_vectors(adata: ad.AnnData, genes: Iterable[str]) -> Tuple[Dict[str, np.ndarray], List[str]]:
    """Extract one log1p-transformed per-cell vector for each requested gene.

    Gene symbols are read from ``adata.var["gene_symbols"]`` if that column
    exists, otherwise from ``adata.var["feature_name"]``, otherwise from
    ``adata.var_names``. When a symbol matches several variables their values
    are averaged per cell. Negative values are clipped to 0 before ``log1p``.

    Returns:
        ``(vectors, missing)``: a mapping of gene symbol to its vector, and
        the requested symbols (in sorted order) that matched no variable.
    """
    var_columns = adata.var.columns
    for candidate in ("gene_symbols", "feature_name"):
        if candidate in var_columns:
            symbols = adata.var[candidate].astype(str).values
            break
    else:
        symbols = adata.var_names.astype(str).values

    vectors: Dict[str, np.ndarray] = {}
    missing: List[str] = []

    for gene in sorted(set(genes)):
        hits = np.flatnonzero(symbols == gene)
        if hits.size == 0:
            missing.append(gene)
            continue

        block = adata[:, hits].X
        # Sparse matrices return a matrix-like mean; asarray + ravel flattens
        # either form into a 1-D per-cell vector.
        row_means = block.mean(axis=1) if sp.issparse(block) else np.asarray(block).mean(axis=1)
        per_cell = np.asarray(row_means).ravel()
        vectors[gene] = np.log1p(np.clip(per_cell, a_min=0.0, a_max=None))

    return vectors, missing


def summarize_scrna_dataset(tag: str, path: Path, required_genes: Iterable[str]) -> Tuple[pd.DataFrame, Dict[str, int], List[str]]:
    """Summarise one .h5ad dataset as mean activity per (disease, cell type).

    Per-cell node and module activities are computed from the requested genes
    and then averaged within each (dataset, disease, cell_type) group. Groups
    with fewer than 50 cells are dropped.

    Args:
        tag: Label stored in the ``dataset`` column.
        path: Location of the .h5ad file.
        required_genes: Gene symbols to extract from the dataset.

    Returns:
        ``(table, disease_counts, missing)``: the grouped means with an
        ``n_cells`` column, the number of cells per disease label before
        filtering, and the requested genes not found in the dataset.
    """
    adata = ad.read_h5ad(path)
    vectors, missing = read_gene_vectors(adata, required_genes)
    total_cells = adata.n_obs

    def activity(gene_list: Iterable[str]) -> np.ndarray:
        return module_activity(gene_list, vectors, total_cells)

    # The MUC2 node is represented by the whole mucus gene set rather than
    # MUC2 alone.
    node_columns = {
        "NOD2": activity(["NOD2"]),
        "ATG16L1/IRGM": activity(["ATG16L1", "IRGM"]),
        "XBP1": activity(["XBP1"]),
        "IL23R": activity(["IL23R"]),
        "MUC2": activity(["MUC2", "AGR2", "SPDEF", "FCGBP"]),
    }
    module_columns = {name: activity(gene_list) for name, gene_list in MODULE_GENES.items()}

    # Fall back to the ontology-term columns when the plain label columns are absent.
    obs = adata.obs
    obs_columns = obs.columns
    disease_col = "disease" if "disease" in obs_columns else "disease_ontology_term_id"
    celltype_col = "cell_type" if "cell_type" in obs_columns else "cell_type_ontology_term_id"
    donor_col = next((name for name in ("donor_id", "sample_id") if name in obs_columns), None)

    df = pd.DataFrame(
        {
            "dataset": tag,
            "cell_type": obs[celltype_col].astype(str).values,
            "disease": obs[disease_col].astype(str).values,
        }
    )
    if donor_col is not None:
        df["donor_id"] = obs[donor_col].astype(str).values

    # Node columns first, then module columns, to keep the output column order.
    for column, values in {**node_columns, **module_columns}.items():
        df[column] = values

    group_keys = ["dataset", "disease", "cell_type"]
    grouped = df.groupby(group_keys, observed=True)
    means = grouped.mean(numeric_only=True)
    sizes = grouped.size().rename("n_cells")
    summary = means.join(sizes).reset_index()
    summary = summary[summary["n_cells"] >= 50].copy()

    disease_counts = df["disease"].value_counts().to_dict()
    return summary, disease_counts, missing


def _collect_required_genes() -> set:
    """Return every gene symbol referenced by MODULE_GENES or NODE_GENES."""
    required: set = set()
    for genes in MODULE_GENES.values():
        required.update(genes)
    for genes in NODE_GENES.values():
        required.update(genes)
    return required


def _crohn_normal_deltas(ct: pd.DataFrame, value_cols: List[str]) -> pd.DataFrame:
    """Return Crohn-minus-normal differences per (dataset, cell_type).

    Only groups that have a value for every column under both the
    "Crohn disease" and "normal" labels are kept. If either label is absent
    for any column, an empty frame with ``value_cols`` as columns is returned
    instead of raising ``KeyError``.
    """
    pivot = ct.pivot_table(index=["dataset", "cell_type"], columns="disease", values=value_cols, aggfunc="first")
    needed = [(col, label) for col in value_cols for label in ("Crohn disease", "normal")]
    if any(pair not in pivot.columns for pair in needed):
        return pd.DataFrame(index=pivot.index[:0], columns=value_cols, dtype=float)

    # Restrict to the two compared labels so that any other disease label in
    # the data cannot cause matched rows to be dropped.
    common = pivot[needed].dropna()
    delta = pd.DataFrame(index=common.index)
    for col in value_cols:
        delta[col] = common[(col, "Crohn disease")] - common[(col, "normal")]
    return delta


def _best_string_score(genes_a: List[str], genes_b: List[str], cache: Dict[Tuple[str, str], float]) -> float:
    """Return the highest STRING score over all gene pairs drawn from the two lists."""
    best = 0.0
    for ga in genes_a:
        for gb in genes_b:
            best = max(best, fetch_string_score(ga, gb, cache))
    return best


def _build_node_table(
    ot_rows: Dict[str, OTRow],
    crohn_ct: pd.DataFrame,
    delta: pd.DataFrame,
    missing_genes: set,
    string_cache: Dict[Tuple[str, str], float],
) -> Tuple[pd.DataFrame, Dict[str, float], Dict[str, float]]:
    """Score every node and return ``(node_df, raw_genetic, raw_literature)``.

    ``node_df`` is indexed by node and sorted by total leverage score,
    highest first. The two dictionaries hold each node's raw genetic and
    literature support for reuse by the edge table.
    """
    max_rank = max(r.rank for r in ot_rows.values())
    # With a single ranked target the percentile denominator would be zero.
    rank_span = max(max_rank - 1, 1)
    muc2_missing = "MUC2" in missing_genes

    node_records = []
    raw_genetic: Dict[str, float] = {}
    raw_lit: Dict[str, float] = {}

    for node, genes in NODE_GENES.items():
        ot_subset = [ot_rows[g] for g in genes if g in ot_rows]
        if not ot_subset:
            genetic_raw = 0.0
            lit_raw = 0.0
            overall_raw = 0.0
            rank_raw = float("nan")
        else:
            rank_pct = [1.0 - (r.rank - 1) / rank_span for r in ot_subset]
            genetic_assoc = [r.datatypes.get("genetic_association", 0.0) for r in ot_subset]
            # Genetic support blends the datatype score with the rank percentile.
            genetic_raw = float(np.mean([0.5 * ga + 0.5 * rp for ga, rp in zip(genetic_assoc, rank_pct)]))
            lit_raw = float(np.mean([max(r.datatypes.get("genetic_literature", 0.0), r.datatypes.get("literature", 0.0)) for r in ot_subset]))
            overall_raw = float(np.mean([r.overall for r in ot_subset]))
            rank_raw = float(np.mean([r.rank for r in ot_subset]))

        # Network convergence: mean of the best STRING score to each other node.
        partner_scores = [
            _best_string_score(genes, other_genes, string_cache)
            for other, other_genes in NODE_GENES.items()
            if other != node
        ]
        string_raw = float(np.mean(partner_scores)) if partner_scores else 0.0

        cell_activity_raw = float(crohn_ct[node].quantile(0.9))
        disease_shift_raw = float(abs(delta[node].mean())) if node in delta.columns else 0.0

        raw_genetic[node] = genetic_raw
        raw_lit[node] = lit_raw

        notes = []
        if node == "MUC2" and muc2_missing:
            notes.append("MUC2 sparse in some subsets; mucus-module genes used for node activity")
        if not notes:
            notes.append("No major missingness in selected sources")

        node_records.append(
            {
                "node": node,
                "raw_ot_rank": rank_raw,
                "raw_ot_genetic_support": genetic_raw,
                "raw_ot_literature_support": lit_raw,
                "raw_string_network_convergence": string_raw,
                "raw_crohn_cell_activity_q90": cell_activity_raw,
                "raw_crohn_normal_shift_abs": disease_shift_raw,
                "raw_ot_overall_association": overall_raw,
                "evidence_basis": "Open Targets Crohn target evidence + STRING functional coupling + CELLxGENE TI Crohn/normal scRNA",
                "uncertainty_note": "; ".join(notes),
            }
        )

    node_df = pd.DataFrame(node_records).set_index("node")

    def scaled(raw_col: str, top: float) -> pd.Series:
        # Raw values are multiplied by the component maximum and clamped to [0, top].
        return (node_df[raw_col] * top).apply(lambda x: clamp(x, 0, top))

    node_df["ot_genetic_support_score_0_3"] = scaled("raw_ot_genetic_support", 3.0)
    node_df["ot_literature_support_score_0_3"] = scaled("raw_ot_literature_support", 3.0)
    node_df["string_network_convergence_score_0_3"] = scaled("raw_string_network_convergence", 3.0)
    node_df["crohn_cell_activity_score_0_3"] = minmax_scale(node_df["raw_crohn_cell_activity_q90"], 3.0)
    node_df["ot_overall_association_score_0_2"] = scaled("raw_ot_overall_association", 2.0)

    node_unc = []
    for node in node_df.index:
        unc = 0.35
        if node == "MUC2" and muc2_missing:
            unc += 0.6
        node_unc.append(clamp(unc, 0.35, 2.0))
    node_df["uncertainty_0_2"] = node_unc

    score_cols = [
        "ot_genetic_support_score_0_3",
        "ot_literature_support_score_0_3",
        "string_network_convergence_score_0_3",
        "crohn_cell_activity_score_0_3",
        "ot_overall_association_score_0_2",
    ]
    node_df["total_leverage_score_0_14"] = node_df[score_cols].sum(axis=1)

    node_cols = score_cols + [
        "total_leverage_score_0_14",
        "uncertainty_0_2",
        "evidence_basis",
        "uncertainty_note",
        "raw_ot_rank",
        "raw_ot_genetic_support",
        "raw_ot_literature_support",
        "raw_string_network_convergence",
        "raw_crohn_cell_activity_q90",
        "raw_crohn_normal_shift_abs",
        "raw_ot_overall_association",
    ]
    node_df = node_df[node_cols].sort_values("total_leverage_score_0_14", ascending=False)
    return node_df, raw_genetic, raw_lit


def _build_edge_table(
    crohn_ct: pd.DataFrame,
    delta: pd.DataFrame,
    raw_genetic: Dict[str, float],
    raw_lit: Dict[str, float],
    string_cache: Dict[Tuple[str, str], float],
) -> pd.DataFrame:
    """Score every edge in EDGE_ORDER and return one row per edge."""
    edge_records = []

    for up, down in EDGE_ORDER:
        edge = f"{up}->{down}"
        expected = EDGE_EXPECTED_SIGN[edge]

        module_up = EDGE_MODULE_MAP[up]
        module_down = EDGE_MODULE_MAP[down]

        corr_crohn = safe_spearman(crohn_ct[module_up], crohn_ct[module_down])
        corr_delta = safe_spearman(delta[module_up], delta[module_down]) if not delta.empty else 0.0
        # Coupling averages the directional support seen in Crohn profiles and
        # in Crohn-minus-normal deltas.
        disease_coupling_raw = float(
            np.mean(
                [
                    directional_support(corr_crohn, expected),
                    directional_support(corr_delta, expected),
                ]
            )
        )

        # The weaker endpoint limits the genetic support of the pair.
        genetic_pair_raw = min(raw_genetic.get(up, 0.0), raw_genetic.get(down, 0.0))
        literature_pair_raw = float(np.mean([raw_lit.get(up, 0.0), raw_lit.get(down, 0.0)]))
        string_raw = _best_string_score(NODE_GENES[up], NODE_GENES[down], string_cache)

        genetics_score = clamp(genetic_pair_raw * 2.0, 0, 2)
        coupling_score = clamp(disease_coupling_raw * 2.0, 0, 2)
        string_score = clamp(string_raw * 2.0, 0, 2)
        lit_score = clamp(literature_pair_raw * 2.0, 0, 2)
        total = genetics_score + coupling_score + string_score + lit_score

        edge_records.append(
            {
                "edge": edge,
                "genetic_pair_support_score_0_2": genetics_score,
                "disease_state_coupling_score_0_2": coupling_score,
                "string_functional_coupling_score_0_2": string_score,
                "literature_pair_support_score_0_2": lit_score,
                "total_score_0_8": total,
                "evidence_label": score_label(total),
                "falsification_test": EDGE_FALSIFICATION[edge],
                "raw_genetic_pair": genetic_pair_raw,
                "raw_crohn_corr": corr_crohn,
                "raw_delta_corr": corr_delta,
                "raw_disease_coupling": disease_coupling_raw,
                "raw_string": string_raw,
                "raw_literature_pair": literature_pair_raw,
            }
        )

    return pd.DataFrame(edge_records)


def _phenotype_raw_metrics(crohn_ct: pd.DataFrame, delta: pd.DataFrame) -> pd.DataFrame:
    """Compute the unscaled phenotype metrics for every node (indexed by node)."""
    ct_idx = crohn_ct["cell_type"].astype(str)
    epi_mask = ct_idx.str.contains("enterocyte|goblet|paneth|stem|transit|progenitor|epithelial", case=False, regex=True)
    paneth_mask = ct_idx.str.contains("paneth", case=False, regex=True)

    # The Paneth subset of the deltas does not depend on the node, so it is
    # selected once up front.
    delta_flat = delta.reset_index()
    delta_paneth = delta_flat[delta_flat["cell_type"].astype(str).str.contains("paneth", case=False, regex=True)]

    pheno_raw_rows = []
    for node in NODE_GENES:
        node_profile = crohn_ct[node]

        epi_mean = float(node_profile[epi_mask].mean()) if epi_mask.any() else float("nan")
        non_epi_mean = float(node_profile[~epi_mask].mean()) if (~epi_mask).any() else float("nan")
        ileal_raw = epi_mean - non_epi_mean

        paneth_raw = float(node_profile[paneth_mask].mean()) if paneth_mask.any() else float("nan")
        if node in delta.columns and not delta_paneth.empty:
            paneth_raw += float(delta_paneth[node].mean())

        # Fibrosis / TNF / IL23 mapping using both Crohn profile and Crohn-normal deltas
        crohn_fib = directional_support(safe_spearman(node_profile, crohn_ct["fibrosis_signature"]), "positive")
        delta_fib = directional_support(safe_spearman(delta[node], delta["fibrosis_signature"]), "positive") if node in delta.columns else 0.0

        crohn_tnf = abs(safe_spearman(node_profile, crohn_ct["tnf_signature"]))
        delta_tnf = abs(safe_spearman(delta[node], delta["tnf_signature"])) if node in delta.columns else 0.0

        crohn_il23 = directional_support(safe_spearman(node_profile, crohn_ct["il23_signature"]), "positive")
        delta_il23 = directional_support(safe_spearman(delta[node], delta["il23_signature"]), "positive") if node in delta.columns else 0.0

        pheno_raw_rows.append(
            {
                "node": node,
                "Ileal localization": ileal_raw,
                "Paneth pathology": paneth_raw,
                "Stricturing/fibrosis": float(np.mean([crohn_fib, delta_fib])),
                "Anti-TNF response heterogeneity": float(np.mean([crohn_tnf, delta_tnf])),
                "IL-23 response enrichment": float(np.mean([crohn_il23, delta_il23])),
            }
        )

    return pd.DataFrame(pheno_raw_rows).set_index("node")


def main() -> int:
    """Run the full pipeline and write all tables and intermediate files.

    Stages: (1) download Open Targets associations, (2) summarise the
    single-cell datasets and compute Crohn-minus-normal deltas, (3) score
    nodes, (4) score edges, (5) map nodes to phenotypes, (6) save the STRING
    cache and the list of genes missing from the single-cell data.

    Returns:
        0 on success.

    Raises:
        FileNotFoundError: If any dataset in SCRNA_DATASETS is missing.
        RuntimeError: If Open Targets returns no rows, or no "Crohn disease"
            rows remain after filtering the single-cell summaries.
    """
    missing_inputs = [str(p) for p in SCRNA_DATASETS.values() if not p.exists()]
    if missing_inputs:
        raise FileNotFoundError("Missing required scRNA datasets: " + ", ".join(missing_inputs))

    required_genes = _collect_required_genes()

    # 1) Open Targets data
    ot_rows = fetch_opentargets_rows(CROHNS_EFO)
    if not ot_rows:
        # Without any ranked target the node scores cannot be computed.
        raise RuntimeError(f"Open Targets returned no associated targets for {CROHNS_EFO}.")
    save_ot_snapshot(ot_rows, DATA_DIR / "opentargets_crohns_associated_targets.csv")

    # 2) scRNA summaries from Crohn/normal TI datasets
    scrna_tables = []
    manifest = []
    missing_genes_union: set = set()

    for tag, path in SCRNA_DATASETS.items():
        table, disease_counts, missing = summarize_scrna_dataset(tag, path, required_genes)
        scrna_tables.append(table)
        missing_genes_union.update(missing)
        manifest.append(
            {
                "dataset_tag": tag,
                "path": str(path),
                "n_rows_after_filter": len(table),
                "cells_crohn": int(disease_counts.get("Crohn disease", 0)),
                "cells_normal": int(disease_counts.get("normal", 0)),
            }
        )

    ct = pd.concat(scrna_tables, ignore_index=True)
    ct.to_csv(DATA_DIR / "crohn_normal_celltype_means.csv", index=False)
    pd.DataFrame(manifest).to_csv(DATA_DIR / "scRNA_dataset_manifest.csv", index=False)

    crohn_ct = ct[ct["disease"] == "Crohn disease"].copy()
    if crohn_ct.empty:
        raise RuntimeError("No Crohn disease rows found after filtering in scRNA datasets.")

    # Matched Crohn-normal deltas by dataset+cell_type; node columns first,
    # then module/signature columns.
    value_cols = [*NODE_GENES, *MODULE_GENES]
    delta = _crohn_normal_deltas(ct, value_cols)
    delta.to_csv(DATA_DIR / "crohn_minus_normal_celltype_deltas.csv")

    # 3) Node table
    string_cache: Dict[Tuple[str, str], float] = {}
    node_df, node_raw_genetic, node_raw_lit = _build_node_table(
        ot_rows, crohn_ct, delta, missing_genes_union, string_cache
    )
    node_df.to_csv(TABLE_DIR / "node_rank_table.csv")

    # 4) Edge table
    edge_df = _build_edge_table(crohn_ct, delta, node_raw_genetic, node_raw_lit, string_cache)
    edge_df.to_csv(TABLE_DIR / "edge_evidence_scores.csv", index=False)

    # 5) Phenotype mapping from real module profiles
    pheno_raw = _phenotype_raw_metrics(crohn_ct, delta)
    pheno_raw.to_csv(TABLE_DIR / "phenotype_mapping_raw_metrics.csv")

    pheno_scaled = pd.DataFrame(index=pheno_raw.index)
    for col in PHENOTYPE_COLS:
        # Nullable Int64 keeps a metric that is NaN for some nodes as a blank
        # cell; a plain int cast would raise on NaN.
        pheno_scaled[col] = np.rint(minmax_scale(pheno_raw[col], 3.0)).astype("Int64")
    pheno_scaled.to_csv(TABLE_DIR / "phenotype_mapping_scores.csv", index_label="node")

    # 6) Save STRING cache + missing genes for reproducibility
    pd.DataFrame(
        [{"gene_a": a, "gene_b": b, "string_score": s} for (a, b), s in sorted(string_cache.items())]
    ).to_csv(DATA_DIR / "string_pair_scores.csv", index=False)

    with open(DATA_DIR / "missing_genes_in_scrna.json", "w", encoding="utf-8") as f:
        json.dump(sorted(missing_genes_union), f, indent=2)

    print(f"Wrote {TABLE_DIR / 'node_rank_table.csv'}")
    print(f"Wrote {TABLE_DIR / 'edge_evidence_scores.csv'}")
    print(f"Wrote {TABLE_DIR / 'phenotype_mapping_scores.csv'}")
    print(f"Wrote intermediates in {DATA_DIR}")
    return 0


# Script entry point: run the pipeline and use main()'s return value as the
# process exit status.
if __name__ == "__main__":
    raise SystemExit(main())
