"""
Fano Engine — Core math extracted from Φ-BETA v12.4.1
Wraps the validated math for API consumption.
L0-L6 logic is IDENTICAL to phi_beta_v12_4_1.py — DO NOT MODIFY THE MATH.
"""

import os
import numpy as np
from scipy.spatial import KDTree
from Bio.PDB import PDBList, PDBParser
import warnings
import logging

# Import-time side effect: suppresses every Python warning for the whole
# process, not just warnings raised from this module.
warnings.filterwarnings("ignore")
# Only ERROR and above are let through on the "Bio" logger.
logging.getLogger("Bio").setLevel(logging.ERROR)

# Golden ratio (1 + sqrt(5)) / 2; used as log base and power base in
# compute_omega_field.
PHI = (1 + np.sqrt(5)) / 2
PI = np.pi
# Number of classes alpha values are binned into (class index k = 0..6).
FANO_CLASSES = 7
# Default for the `n` parameter of compute_omega_field / compute_fano_field,
# where it divides sin(i).
# NOTE(review): presumably a refractive index (4/3) — confirm with the
# original phi_beta_v12_4_1.py.
DEFAULT_N = 4 / 3
# A residue is flagged "stable" in compute_omega_field when its Sigma value
# is strictly below this threshold.
DUAL_LOCK_THRESHOLD = 6e-3

# Class index k (0..6, as produced by compute_omega_field) -> display label.
# Emitted per residue as "fano_class" and per class as "label" by
# compute_fano_field.
BIO_CLASS_MAP = {
    0: "L0 Hydrophilic Surface",
    1: "L1 Phospholipid Layer",
    2: "L2 Cholesterol Domain",
    3: "L3 Protein Pocket",
    4: "L4 Glycoprotein Surface",
    5: "L5 Lipid Raft",
    6: "L6 Hydrophobic Core",
}

# Class index k (0..6) -> short functional description; same keys as
# BIO_CLASS_MAP. Emitted as "bio_function" / "function" by compute_fano_field.
BIO_FUNCTION = {
    0: "Binding initiation",
    1: "Membrane anchoring",
    2: "Rigidity & packing",
    3: "Enzymatic catalysis",
    4: "Signal recognition",
    5: "Localization & transport",
    6: "Folding stabilization",
}

# func_type (as stored in CSA_TARGETS) -> class indices treated as the
# "predicted" classes when compute_fano_field computes predicted_enrichment.
# compute_fano_field falls back to [2, 3, 4] for a func_type not listed here.
EXPECTED_CLASSES = {
    "catalytic": [2, 3, 4],
    "binding": [0, 1, 4],
    "structural": [5, 6],
}

# Supported targets, keyed by PDB ID exactly as callers must pass it
# (lookups are case-sensitive; keys here are upper-case).
#   active_sites: residue sequence numbers compared against the residue
#                 numbers that fetch_and_parse_pdb extracts from the first
#                 chain of the first model.
#   func_type:    key into EXPECTED_CLASSES.
#   name:         display name returned to the caller.
# NOTE(review): the display names and residue numbers are not validated
# anywhere in this file — cross-check them against the PDB entries
# (e.g. confirm the labels for 1RTB and 1AMU).
CSA_TARGETS = {
    "1BTL": {"active_sites": [70, 73, 130, 166], "func_type": "catalytic", "name": "TEM-1 Beta-Lactamase"},
    "1ACB": {"active_sites": [42, 105, 195], "func_type": "catalytic", "name": "Chymotrypsin"},
    "2CMD": {"active_sites": [195], "func_type": "catalytic", "name": "Malate Dehydrogenase"},
    "1LYZ": {"active_sites": [35, 52], "func_type": "catalytic", "name": "Lysozyme"},
    "1RTB": {"active_sites": [2, 10, 71, 106], "func_type": "catalytic", "name": "HIV-1 Protease"},
    "4ENL": {"active_sites": [156, 211, 245, 295], "func_type": "catalytic", "name": "Enolase"},
    "1CA2": {"active_sites": [94, 96, 119, 199], "func_type": "catalytic", "name": "Carbonic Anhydrase II"},
    "1TIM": {"active_sites": [12, 95, 165], "func_type": "catalytic", "name": "Triosephosphate Isomerase"},
    "1BRS": {"active_sites": [73, 102], "func_type": "catalytic", "name": "Barnase"},
    "7RSA": {"active_sites": [12, 41, 119], "func_type": "catalytic", "name": "Ribonuclease A"},
    "1GPD": {"active_sites": [149], "func_type": "catalytic", "name": "GAPDH"},
    "1AMU": {"active_sites": [21, 165, 269], "func_type": "catalytic", "name": "Urease"},
    "1A22": {"active_sites": [92, 117, 170], "func_type": "binding", "name": "HGH Receptor"},
    "1MUW": {"active_sites": [225, 327, 432], "func_type": "catalytic", "name": "Acetylcholinesterase"},
    "1LDN": {"active_sites": [102, 109, 195, 199], "func_type": "catalytic", "name": "Lactate Dehydrogenase"},
    "1CSE": {"active_sites": [32, 64, 221], "func_type": "catalytic", "name": "Subtilisin Carlsberg"},
}

# In-memory cache for parsed PDB structures.
# Keyed by the pdb_id string as passed to fetch_and_parse_pdb; values are the
# dicts that function returns. Entries are never evicted, and the same dict
# object is handed back on every cache hit.
_pdb_cache = {}


def compute_omega_field(alpha_array, n=DEFAULT_N, lambda_nm=400):
    """
    Full Fano field computation per v8.2.
    Class assigned by alpha-binning: k = floor(7*alpha) % 7.
    Sigma computed with class-specific refraction shift.
    IDENTICAL to phi_beta_v12_4_1.py — DO NOT MODIFY.

    Args:
        alpha_array: Per-residue alpha values; expected to be 1-D (the
            outputs are allocated with ``len`` and filled through boolean
            masks). Values are clipped to [1e-6, 1 - 1e-6] before use.
        n: Divisor applied to sin(i) when computing the angle ``r_``.
        lambda_nm: Enters only through ``eps_min = sqrt(lambda_nm * 1e-3)``,
            which scales the returned intensity.

    Returns:
        Tuple ``(k_idx, Sigma, I_rel, Dmin, stable)`` of arrays with one
        entry per input value:
            k_idx  — integer class index in 0..6,
            Sigma  — product ``eps_phi * eps_pi``,
            I_rel  — ``1 / (sqrt(a * (1 - a)) * sqrt(eps_min))``,
            Dmin   — the clipped quantity ``D = PI - 2*r_ - 2*i_``,
            stable — boolean, True where Sigma < DUAL_LOCK_THRESHOLD.
    """
    # Keep alpha strictly inside (0, 1) so arcsin and 1/sqrt(a*(1-a)) stay finite.
    a = np.clip(alpha_array, 1e-6, 1.0 - 1e-6)
    # After clipping, 7*a < 7, so floor() is already in 0..6; the modulo is a safeguard.
    k_idx = (np.floor(FANO_CLASSES * a).astype(int)) % FANO_CLASSES

    n_res = len(a)
    Sigma = np.zeros(n_res)
    I_rel = np.zeros(n_res)
    Dmin = np.zeros(n_res)
    stable = np.zeros(n_res, dtype=bool)

    # Process one class at a time because the angle offset and Gphi scale depend on k.
    for k in range(FANO_CLASSES):
        mask = k_idx == k
        if not np.any(mask):
            continue
        seg = a[mask]

        # alpha is treated as the sine of the angle i_.
        i_ = np.arcsin(seg)
        # arcsin(sin(i_) / n) plus a class-specific offset of k * PI / 21.
        r_ = np.arcsin(np.clip(np.sin(i_) / n, -0.999999, 0.999999)) + (k * PI / 21)
        # Floor D at 1e-6 so the logarithm below is defined.
        D = np.clip(PI - 2 * r_ - 2 * i_, 1e-6, None)

        # NOTE(review): algebraically PHI ** (r_k / PI) == D, so Gphi equals
        # D * (1 + k / 7) up to floating-point rounding and eps_phi is about
        # D * k / 7 (rounding noise only for k == 0). Left as-is per the
        # "do not modify" instruction — confirm this is intended.
        r_k = PI * np.log(D) / np.log(PHI)
        Gphi = (PHI ** (r_k / PI)) * (1 + k / FANO_CLASSES)
        eps_phi = np.abs(D - Gphi)
        # Distance of D * PI from the nearest integer.
        eps_pi = np.abs(D * PI - np.round(D * PI))
        S = eps_phi * eps_pi

        # Does not depend on k or seg; recomputed every iteration.
        eps_min = np.sqrt(lambda_nm / 1.0 * 1e-3)
        I = 1.0 / (np.sqrt(seg * (1 - seg)) * np.sqrt(eps_min))

        Sigma[mask] = S
        I_rel[mask] = I
        Dmin[mask] = D
        stable[mask] = S < DUAL_LOCK_THRESHOLD

    return k_idx, Sigma, I_rel, Dmin, stable


def calculate_radial_alpha(coords, R=1.0):
    """Radial alpha with spatial scale parameter R.

    For each point, ``alpha = clip(1 - d / (d_max * R), 0, 1)`` where ``d`` is
    the point's distance from the centroid of ``coords`` and ``d_max`` is the
    largest such distance. Points at the centroid get alpha 1; with R = 1 the
    farthest point gets alpha 0.

    R > 1 compresses manifold (more core), R < 1 extends it (more surface).

    Args:
        coords: Array-like of shape (N, D) point coordinates.
        R: Scale factor multiplying ``d_max``. Expected to be positive;
            R == 0 is not guarded and divides by zero.

    Returns:
        1-D float array of length N. An empty array for empty input, and an
        array of 0.5 when all points coincide (``d_max < 1e-6``).
    """
    n_points = len(coords)
    if n_points == 0:
        # np.max raises ValueError on a zero-size array; return an empty
        # result instead, as calculate_local_fano_alpha does for empty input.
        return np.zeros(0)

    centroid = np.mean(coords, axis=0)
    distances = np.linalg.norm(coords - centroid, axis=1)
    max_dist = np.max(distances)
    if max_dist < 1e-6:
        # Degenerate cloud: every point sits on the centroid, so no radial
        # ordering exists; use the neutral midpoint.
        return np.ones(n_points) * 0.5
    rho = distances / (max_dist * R)
    return np.clip(1.0 - rho, 0.0, 1.0)


def calculate_local_fano_alpha(coords):
    """7-NN local manifold alpha — IDENTICAL to phi_beta_v12_4_1.py.

    For every point, the unit vectors towards the 7 entries that follow the
    first hit of an 8-nearest-neighbour query are averaged; alpha is
    ``1 - |mean unit vector|`` clipped to [0, 1].

    Args:
        coords: NumPy array of shape (N, D); it is indexed with integer
            arrays, so a plain list is not sufficient.

    Returns:
        1-D float array of length N. All entries are 0.5 when N < 8. A point
        whose neighbours all lie within 1e-6 of it keeps alpha 0.
    """
    n_points = len(coords)
    if n_points < 8:
        # Too few points for an 8-neighbour query; fall back to the midpoint.
        return np.ones(n_points) * 0.5

    # k=8 so that, after dropping the first hit (normally the point itself),
    # 7 neighbours remain.
    _, neighbour_idx = KDTree(coords).query(coords, k=8)

    alphas = np.zeros(n_points)
    for point_i, hits in enumerate(neighbour_idx):
        offsets = coords[hits[1:]] - coords[point_i]
        lengths = np.linalg.norm(offsets, axis=1)
        usable = lengths > 1e-6  # skip coincident points to avoid dividing by ~0
        if np.any(usable):
            directions = offsets[usable] / lengths[usable, None]
            resultant = np.linalg.norm(np.mean(directions, axis=0))
            alphas[point_i] = np.clip(1.0 - resultant, 0.0, 1.0)
    return alphas


def fetch_and_parse_pdb(pdb_id, workdir="/tmp"):
    """Download and parse PDB, caching the result.

    Only the first chain of the first model is read, and only residues that
    have a "CA" atom contribute an entry.

    Args:
        pdb_id: PDB identifier; also used verbatim (case-sensitive) as the
            cache key.
        workdir: Directory the PDB file is downloaded into. The downloaded
            file is removed again before returning, on success or failure.

    Returns:
        Dict with parallel lists ``"res_ids"`` (residue sequence numbers),
        ``"coords"`` (CA coordinates as lists of floats) and ``"b_factors"``
        (CA B-factors), or ``None`` when the download fails, the file cannot
        be parsed, or no CA atoms are found. Failures are logged and not
        cached, so a later call retries the download. A cache hit returns
        the same dict object that was stored.
    """
    if pdb_id in _pdb_cache:
        return _pdb_cache[pdb_id]

    log = logging.getLogger(__name__)
    filename = None
    try:
        try:
            filename = PDBList().retrieve_pdb_file(
                pdb_id, pdir=workdir, file_format="pdb"
            )
        except Exception as exc:
            # Best-effort API helper: report the failure as None, but leave a trace.
            log.warning("PDB download failed for %s: %s", pdb_id, exc)
            return None

        try:
            structure = PDBParser(QUIET=True).get_structure(pdb_id, filename)
        except Exception as exc:
            log.warning("PDB parse failed for %s: %s", pdb_id, exc)
            return None

        res_ids, coords, b_factors = [], [], []
        for model in structure:
            for chain in model:
                for res in chain:
                    if "CA" in res:
                        ca_atom = res["CA"]
                        res_ids.append(res.get_id()[1])
                        coords.append(ca_atom.coord.tolist())
                        b_factors.append(float(ca_atom.get_bfactor()))
                break  # first chain only
            break  # first model only

        if not coords:
            log.warning("No CA atoms found in first chain of %s", pdb_id)
            return None

        result = {
            "res_ids": res_ids,
            "coords": coords,
            "b_factors": b_factors,
        }
        _pdb_cache[pdb_id] = result
        return result
    finally:
        # Remove the downloaded file on every exit path; previously it was
        # left behind whenever parsing failed or no CA atoms were found.
        if filename:
            try:
                os.remove(filename)
            except OSError:
                # Already gone or not removable; cleanup is best-effort.
                pass


def _enrichment_stats(n_members, n_active_members, n_residues, total_active):
    """Return ``(fraction, enrichment)`` for a group of residues.

    ``fraction`` is the group's share of all residues; ``enrichment`` is the
    group's share of active residues divided by ``fraction``. Either value is
    0 when its denominator is zero.
    """
    frac = n_members / n_residues if n_residues else 0
    hit = n_active_members / total_active if total_active > 0 else 0
    enrich = hit / frac if frac > 0 else 0
    return frac, enrich


def compute_fano_field(pdb_id, n=DEFAULT_N, R=1.0, lambda_nm=400):
    """Full computation pipeline for the interactive tool.

    Looks up ``pdb_id`` in CSA_TARGETS, fetches the structure's CA trace,
    computes radial and local alpha, runs compute_omega_field on the radial
    alpha, and summarises how active-site residues distribute over the
    classes.

    Args:
        pdb_id: Key into CSA_TARGETS (case-sensitive).
        n: Forwarded to compute_omega_field.
        R: Forwarded to calculate_radial_alpha.
        lambda_nm: Forwarded to compute_omega_field.

    Returns:
        ``None`` if ``pdb_id`` is not a known target or the structure could
        not be fetched/parsed. Otherwise a dict with the echoed parameters,
        summary statistics, ``"residues"`` (one dict per residue),
        ``"class_stats"`` (keyed by the class index as a string) and the
        target's ``"active_sites"``. The ``"predicted_classes"`` and
        ``"active_sites"`` lists are copies, so mutating the result cannot
        alter the module-level configuration.
    """
    target = CSA_TARGETS.get(pdb_id)
    if not target:
        return None

    parsed = fetch_and_parse_pdb(pdb_id)
    if not parsed:
        return None

    coords = np.array(parsed["coords"])
    res_ids = parsed["res_ids"]
    b_factors = parsed["b_factors"]
    active_sites = target["active_sites"]
    func_type = target["func_type"]
    # Set lookup instead of scanning the list once per residue.
    active_lookup = set(active_sites)

    radial_alpha = calculate_radial_alpha(coords, R=R)
    local_alpha = calculate_local_fano_alpha(coords)
    k_idx, sigma, intensity, dmin, stable = compute_omega_field(
        radial_alpha, n=n, lambda_nm=lambda_nm
    )

    residues = []
    for i, raw_id in enumerate(res_ids):
        res_id = int(raw_id)
        k = int(k_idx[i])
        residues.append({
            "res_id": res_id,
            "radial_alpha": round(float(radial_alpha[i]), 4),
            "local_alpha": round(float(local_alpha[i]), 4),
            "fano_k": k,
            "fano_class": BIO_CLASS_MAP[k],
            "bio_function": BIO_FUNCTION[k],
            # sigma is deliberately left unrounded.
            "sigma": float(sigma[i]),
            "dmin": round(float(dmin[i]), 4),
            "intensity": round(float(intensity[i]), 4),
            "stable": bool(stable[i]),
            "b_factor": round(float(b_factors[i]), 2),
            "is_active": res_id in active_lookup,
        })

    # Tally per-class membership in one pass over the residues.
    member_counts = [0] * FANO_CLASSES
    active_counts = [0] * FANO_CLASSES
    stable_counts = [0] * FANO_CLASSES
    for r in residues:
        k = r["fano_k"]
        member_counts[k] += 1
        if r["is_active"]:
            active_counts[k] += 1
        if r["stable"]:
            stable_counts[k] += 1

    n_residues = len(residues)
    total_active = sum(active_counts)

    class_stats = {}
    for k in range(FANO_CLASSES):
        frac, enrich = _enrichment_stats(
            member_counts[k], active_counts[k], n_residues, total_active
        )
        class_stats[str(k)] = {
            "label": BIO_CLASS_MAP[k],
            "function": BIO_FUNCTION[k],
            "count": member_counts[k],
            "fraction": round(frac, 4),
            "active_count": active_counts[k],
            "enrichment": round(enrich, 2),
            "stable_count": stable_counts[k],
        }

    # Enrichment over the union of classes expected for this functional type.
    expected_ks = EXPECTED_CLASSES.get(func_type, [2, 3, 4])
    predicted = set(expected_ks)
    n_pred = sum(member_counts[k] for k in range(FANO_CLASSES) if k in predicted)
    n_pred_active = sum(
        active_counts[k] for k in range(FANO_CLASSES) if k in predicted
    )
    _, pred_enrichment = _enrichment_stats(
        n_pred, n_pred_active, n_residues, total_active
    )

    # Ties resolve to the lowest class index (max() keeps the first maximum);
    # the comparison uses the already-rounded enrichment values.
    best_k = max(
        range(FANO_CLASSES), key=lambda k: class_stats[str(k)]["enrichment"]
    )

    return {
        "pdb_id": pdb_id,
        "name": target["name"],
        "func_type": func_type,
        "n": n,
        "R": R,
        "lambda_nm": lambda_nm,
        "total_residues": n_residues,
        "total_active": total_active,
        "n_stable": sum(stable_counts),
        # Copy so callers cannot mutate EXPECTED_CLASSES through the result.
        "predicted_classes": list(expected_ks),
        "predicted_enrichment": round(pred_enrichment, 2),
        "best_class": best_k,
        "best_class_name": BIO_CLASS_MAP[best_k],
        "best_enrichment": round(class_stats[str(best_k)]["enrichment"], 2),
        "mean_radial_alpha": round(float(np.mean(radial_alpha)), 4),
        "mean_local_alpha": round(float(np.mean(local_alpha)), 4),
        "residues": residues,
        "class_stats": class_stats,
        # Copy so callers cannot mutate CSA_TARGETS through the result.
        "active_sites": list(active_sites),
    }
