#!/usr/bin/env python3
from pathlib import Path
import csv
import math
import matplotlib.pyplot as plt
import matplotlib.patches as patches
import matplotlib.patheffects as pe
import numpy as np

# Project layout: ROOT is the third ancestor of this script's resolved path.
# Figures are written to <ROOT>/figures (created on demand); input tables are
# read from <ROOT>/tables.
_SCRIPT_PATH = Path(__file__).resolve()
ROOT = _SCRIPT_PATH.parents[2]
FIG_DIR = ROOT.joinpath('figures')
TAB_DIR = ROOT.joinpath('tables')
FIG_DIR.mkdir(exist_ok=True, parents=True)

# Visual defaults applied to every figure produced below: fixed raster DPI,
# 10 pt base font, and no top/right axis spines.
_VISUAL_DEFAULTS = {
    'figure.dpi': 180,
    'font.size': 10,
    'axes.spines.top': False,
    'axes.spines.right': False,
}
plt.rcParams.update(_VISUAL_DEFAULTS)


def read_csv(path):
    """Read a CSV file with a header row and return its rows as a list of dicts.

    Each dict maps header names to the raw (string) cell values.

    The file is decoded as ``utf-8-sig`` so that a leading byte-order mark
    (common in spreadsheet exports) is stripped instead of being glued onto
    the first header name (which would turn ``'node'`` into ``'\\ufeffnode'``
    and break column lookups). Files without a BOM decode exactly as UTF-8.

    Args:
        path: Filesystem path (str or ``pathlib.Path``) of the CSV file.

    Returns:
        list[dict[str, str]]: One dict per data row, in file order.
    """
    # newline='' is required by the csv module so quoted embedded newlines
    # are handled correctly.
    with open(path, newline='', encoding='utf-8-sig') as f:
        return list(csv.DictReader(f))


# Figure 1: Circuit schematic
# Each node is (label, centre x, fill colour). All nodes sit on one row and are
# listed in the directed order of the hypothesised circuit.
circuit_nodes = [
    ('NOD2', 1.1, '#1f4e79'),
    ('ATG16L1/IRGM', 3.0, '#2e75b6'),
    ('XBP1', 4.9, '#5b9bd5'),
    ('IL23R', 6.8, '#f28e2b'),
    ('MUC2', 8.7, '#8c564b'),
]
row_y = 2.0
box_w, box_h = 1.42, 0.82
# Horizontal distance from a node centre to where a connecting arrow starts/ends,
# chosen to clear the rounded box (half-width 0.71 plus padding).
edge_gap = 0.82

fig, ax = plt.subplots(figsize=(10.5, 4.2))
ax.set_xlim(0, 10)
ax.set_ylim(0, 4)
ax.axis('off')

# Node boxes with white bold labels centred inside them.
for label, cx, fill in circuit_nodes:
    ax.add_patch(patches.FancyBboxPatch(
        (cx - box_w / 2, row_y - box_h / 2), box_w, box_h,
        boxstyle='round,pad=0.03,rounding_size=0.08',
        linewidth=1.5, edgecolor='black', facecolor=fill,
    ))
    ax.text(cx, row_y, label, ha='center', va='center', color='white',
            fontsize=10, fontweight='bold')

# Straight directed edges between consecutive nodes.
for (_, x_from, _), (_, x_to, _) in zip(circuit_nodes, circuit_nodes[1:]):
    ax.annotate(
        '',
        xy=(x_to - edge_gap, row_y),
        xytext=(x_from + edge_gap, row_y),
        arrowprops=dict(arrowstyle='-|>', lw=2.0, color='black'),
    )

# Curved feedback edge from the last node back to the first, drawn just below
# the boxes and bowing downwards (negative arc radius for a leftward arrow).
first_x = circuit_nodes[0][1]
last_x = circuit_nodes[-1][1]
ax.add_patch(patches.FancyArrowPatch(
    (last_x, row_y - 0.45), (first_x, row_y - 0.45),
    connectionstyle='arc3,rad=-0.36',
    arrowstyle='-|>', mutation_scale=14,
    linewidth=2.0, color='#444444',
))
ax.text(5.0, 0.68, 'Barrier-to-innate sensing feedback', ha='center', va='center', fontsize=9, color='#333333')

# Title and one-line description, placed in data coordinates above the row.
ax.text(5.0, 3.55, 'Figure 1. Minimal five-node causal circuit for ileal-predominant Crohn\'s disease',
        ha='center', va='center', fontsize=12, fontweight='bold')
ax.text(5.0, 3.18,
        'Directional hypothesis: NOD2 -> ATG16L1/IRGM -> XBP1 -> IL23R -> MUC2, with positive feedback to innate sensing.',
        ha='center', va='center', fontsize=9)

fig.tight_layout()
for ext in ('png', 'svg'):
    fig.savefig(FIG_DIR / f'fig1_minimal_circuit.{ext}', bbox_inches='tight')
plt.close(fig)


# Figure 2: Node leverage ranking
# Bar fill per node, keyed by node name so each node keeps the colour it has in
# Figure 1 regardless of where it lands in the score ranking. (Previously a
# fixed list was applied by rank position, which silently recoloured nodes
# whenever the ranking changed.) Nodes not listed here are drawn in grey.
LEVERAGE_NODE_COLORS = {
    'NOD2': '#1f4e79',
    'ATG16L1/IRGM': '#2e75b6',
    'XBP1': '#5b9bd5',
    'IL23R': '#f28e2b',
    'MUC2': '#8c564b',
}
LEVERAGE_FALLBACK_COLOR = '#7f7f7f'


def _whisker_half_width(row):
    """Return the horizontal error-bar half-width for one ranking row.

    The tabulated ``uncertainty_0_2`` value is scaled by 0.8 and floored at
    0.25 so every bar shows a visible whisker. A missing *or blank* cell falls
    back to 0.5 (blank cells previously crashed on ``float('')``).
    """
    raw = (row.get('uncertainty_0_2') or '').strip()
    return max(0.25, (float(raw) if raw else 0.5) * 0.8)


rows = read_csv(TAB_DIR / 'node_rank_table.csv')
# Highest leverage first (stable sort keeps table order for ties).
rows = sorted(rows, key=lambda r: float(r['total_leverage_score_0_14']), reverse=True)
nodes = [r['node'] for r in rows]
totals = [float(r['total_leverage_score_0_14']) for r in rows]
unc = [_whisker_half_width(r) for r in rows]

fig, ax = plt.subplots(figsize=(8.2, 4.4))
y = np.arange(len(nodes))
colors = [LEVERAGE_NODE_COLORS.get(n, LEVERAGE_FALLBACK_COLOR) for n in nodes]
ax.barh(y, totals, xerr=unc, color=colors, edgecolor='black', alpha=0.9)
ax.set_yticks(y)
ax.set_yticklabels(nodes)
# Put the top-ranked node at the top of the chart.
ax.invert_yaxis()
# Leave room for end-of-whisker labels so numbers never collide with bars/error lines.
x_max = max((t + u for t, u in zip(totals, unc)), default=0.0) + 1.3
ax.set_xlim(0, x_max)
ax.set_xlabel('Node leverage score (0-14)')
ax.set_title('Figure 2. Node prioritization from genetics, pathway convergence, and translational anchor')

# Score label just past each whisker tip, on a translucent white backing.
for yi, score, err in zip(y, totals, unc):
    label_x = score + err + 0.2
    ax.text(
        label_x, yi, f'{score:.1f}',
        va='center', ha='left', fontsize=9, fontweight='bold',
        bbox=dict(facecolor='white', edgecolor='none', boxstyle='round,pad=0.12', alpha=0.85)
    )

ax.grid(axis='x', linestyle='--', alpha=0.35)
fig.tight_layout()
fig.savefig(FIG_DIR / 'fig2_node_leverage.png', bbox_inches='tight')
fig.savefig(FIG_DIR / 'fig2_node_leverage.svg', bbox_inches='tight')
plt.close(fig)


# Figure 3: Edge evidence heatmap
# One row per directed edge, one column per evidence subscore (each on a 0-2 scale).
erows = read_csv(TAB_DIR / 'edge_evidence_scores.csv')
edges = [r['edge'] for r in erows]
mat = np.array([
    [
        float(r['genetic_pair_support_score_0_2']),
        float(r['disease_state_coupling_score_0_2']),
        float(r['string_functional_coupling_score_0_2']),
        float(r['literature_pair_support_score_0_2'])
    ]
    for r in erows
])
cols = ['Genetic pair', 'Disease coupling', 'STRING coupling', 'Literature']


def _format_subscore(value):
    """Return a cell label: '2' for whole-number subscores, '1.5' otherwise.

    A plain ``f'{value:.0f}'`` rounds half-to-even, so fractional subscores
    would be misreported (0.5 -> '0', 1.5 -> '2').
    """
    value = float(value)
    return f'{int(value)}' if value.is_integer() else f'{value:.1f}'


fig, ax = plt.subplots(figsize=(8.8, 4.8))
im = ax.imshow(mat, cmap='YlGnBu', aspect='auto', vmin=0, vmax=2)
ax.set_xticks(np.arange(len(cols)))
ax.set_xticklabels(cols, rotation=20, ha='right')
ax.set_yticks(np.arange(len(edges)))
ax.set_yticklabels(edges)
ax.set_title('Figure 3. Edge-level evidence matrix for the directed circuit')

for i in range(mat.shape[0]):
    for j in range(mat.shape[1]):
        val = mat[i, j]
        rgba = im.cmap(im.norm(val))
        # Perceived luminance for adaptive contrast.
        luminance = 0.2126 * rgba[0] + 0.7152 * rgba[1] + 0.0722 * rgba[2]
        txt_color = 'white' if luminance < 0.5 else 'black'
        outline = 'black' if txt_color == 'white' else 'white'
        t = ax.text(
            j, i, _format_subscore(val),
            ha='center', va='center', color=txt_color, fontsize=9, fontweight='bold'
        )
        # Thin opposite-color stroke keeps numerals visible across mid-tone cells.
        t.set_path_effects([pe.withStroke(linewidth=1.2, foreground=outline, alpha=0.9)])

cbar = fig.colorbar(im, ax=ax, fraction=0.03, pad=0.03)
cbar.set_label('Evidence subscore (0-2)')
fig.tight_layout()
fig.savefig(FIG_DIR / 'fig3_edge_evidence_heatmap.png', bbox_inches='tight')
fig.savefig(FIG_DIR / 'fig3_edge_evidence_heatmap.svg', bbox_inches='tight')
plt.close(fig)


# Figure 4: Clinical phenotype mapping
# Rows are nodes; every column other than 'node' is treated as a phenotype
# whose cell holds an ordinal support level on a 0-3 scale.
phenotype_csv = TAB_DIR / 'phenotype_mapping_scores.csv'
prows = read_csv(phenotype_csv)
if not prows:
    # Fail with a clear message instead of an IndexError on prows[0].
    raise ValueError(f'No data rows found in {phenotype_csv}')
node_list = [r['node'] for r in prows]
phenotypes = [k for k in prows[0].keys() if k != 'node']
score_map = np.array([[float(r[p]) for p in phenotypes] for r in prows], dtype=float)
# Break long phenotype names onto two lines so rotated tick labels stay compact.
phenotype_labels = [
    p.replace(' response ', ' response\n').replace(' enrichment', '\nenrichment')
    for p in phenotypes
]

fig, ax = plt.subplots(figsize=(9.2, 4.8))
im = ax.imshow(score_map, cmap='OrRd', aspect='auto', vmin=0, vmax=3)
ax.set_xticks(np.arange(len(phenotypes)))
ax.set_xticklabels(phenotype_labels, rotation=25, ha='right')
ax.set_yticks(np.arange(len(node_list)))
ax.set_yticklabels(node_list)
ax.set_title('Figure 4. Node-to-phenotype mapping from real single-cell module profiles (ordinal support)')

for (i, j), val in np.ndenumerate(score_map):
    # Same adaptive-contrast rule as Figure 3: black text was unreadable on the
    # dark red top end of OrRd, so pick white/black from the cell's luminance.
    rgba = im.cmap(im.norm(val))
    luminance = 0.2126 * rgba[0] + 0.7152 * rgba[1] + 0.0722 * rgba[2]
    txt_color = 'white' if luminance < 0.5 else 'black'
    outline = 'black' if txt_color == 'white' else 'white'
    # Whole numbers print as before; non-integers are no longer truncated by int().
    label = f'{int(val)}' if float(val).is_integer() else f'{val:.1f}'
    t = ax.text(j, i, label, ha='center', va='center', color=txt_color, fontsize=9)
    t.set_path_effects([pe.withStroke(linewidth=1.2, foreground=outline, alpha=0.9)])

cbar = fig.colorbar(im, ax=ax, fraction=0.03, pad=0.03)
cbar.set_label('Support level (0-3)')
fig.tight_layout()
fig.savefig(FIG_DIR / 'fig4_phenotype_mapping.png', bbox_inches='tight')
fig.savefig(FIG_DIR / 'fig4_phenotype_mapping.svg', bbox_inches='tight')
plt.close(fig)

print('Generated fig1-fig4 from table inputs.')
