"""
simulation_code.py

Reproducibility code for:

Emergent Spacetime from Relational Coherence Dynamics:
A Minimal Information-Theoretic Framework with Numerical Evidence

Author:
Manabu Murashita
Independent Researcher, Osaka, Japan
ORCID: https://orcid.org/0009-0002-0860-2009

Version:
Preprint package v1.0

Purpose:
This script reproduces the numerical simulation framework and representative
figures described in the manuscript. It generates:

- Figure 2: Temporal evolution of the coherence field Phi_i
- Figure 3: Variance evolution of the coherence field
- Figure 4: Exploratory MDS embedding of the relational distance matrix,
            colored by final coherence values Phi_i
- Optional control: MDS embedding with shuffled final coherence colors

Important sign convention:
The numerical implementation uses the coupling term

    sum_j K_ij * (Phi_i - Phi_j)

so that the evolution equation implemented here is

    dPhi_i/dt = alpha * Phi_i - beta * Phi_i^3
                + kappa * sum_j K_ij * (Phi_i - Phi_j)

This sign convention is the one used for the figures and numerical values
reported in the preprint package.
"""

from pathlib import Path

import numpy as np
import matplotlib.pyplot as plt
from sklearn.manifold import MDS
from sklearn.metrics import pairwise_distances


def run_simulation(
    N: int = 100,
    alpha: float = 1.0,
    beta: float = 1.0,
    kappa: float = 0.1,
    dt: float = 0.01,
    steps: int = 500,
    seed: int = 0,
):
    """Run the relational coherence simulation.

    The field is advanced with the explicit update
    ``Phi <- Phi + dt * dPhi`` where

        dPhi_i = alpha * Phi_i - beta * Phi_i^3
                 + kappa * sum_j K_ij * (Phi_i - Phi_j)

    The initial field is drawn from a normal distribution (mean 0, standard
    deviation 0.1) and the weights from a uniform [0, 1) matrix that is then
    symmetrised and given a zero diagonal.

    Random numbers come from a private ``np.random.RandomState(seed)``. That
    generator produces the same stream as ``np.random.seed(seed)`` followed by
    module-level ``np.random`` draws, so results are unchanged, but the
    process-wide NumPy random state is no longer reseeded as a side effect.

    Parameters
    ----------
    N:
        Number of nodes.
    alpha:
        Linear local amplification coefficient.
    beta:
        Nonlinear saturation coefficient.
    kappa:
        Coupling strength.
    dt:
        Time step.
    steps:
        Number of update steps. Values below zero are treated as zero.
    seed:
        Random seed for reproducibility.

    Returns
    -------
    history:
        Array with shape (steps, N), containing Phi_i after each update step.
        The initial state (before the first update) is not included. With
        ``steps=0`` the array has shape (0, N).
    K:
        Symmetric relational weight matrix with a zero diagonal.
    """

    # Local legacy generator: same numbers as the global seed-based stream,
    # without mutating np.random's shared state for the rest of the program.
    rng = np.random.RandomState(seed)

    # Draw order matters for reproducibility: initial field first, then K.
    Phi = rng.normal(0, 0.1, N)

    K = rng.rand(N, N)
    K = (K + K.T) / 2
    np.fill_diagonal(K, 0.0)

    # Preallocate so the result always has the documented (steps, N) shape,
    # including the empty case.
    n_steps = max(steps, 0)
    history = np.empty((n_steps, N))

    for step in range(n_steps):
        # Sign convention used in the manuscript's numerical implementation:
        # coupling_i = sum_j K_ij * (Phi_i - Phi_j)
        # It is kept exactly as-is because the packaged figures depend on it.
        coupling = np.sum(K * (Phi[:, None] - Phi[None, :]), axis=1)
        dPhi = alpha * Phi - beta * Phi**3 + kappa * coupling
        Phi = Phi + dt * dPhi
        history[step] = Phi

    return history, K


def compute_mds_embedding(K, epsilon: float = 1e-5, random_state: int = 0):
    """Embed the nodes in 2D from inverse-weight relational distances.

    The distance matrix is ``D_ij = 1 / (K_ij + epsilon)`` with the diagonal
    reset to zero, and is passed to MDS as a precomputed dissimilarity.

    Parameters
    ----------
    K:
        Square relational weight matrix; its first dimension gives the
        number of nodes.
    epsilon:
        Offset added to every weight before inversion, which keeps the
        division finite where a weight is zero.
    random_state:
        Seed forwarded to the MDS estimator.

    Returns
    -------
    D:
        Relational distance matrix.
    coords:
        2D coordinates returned by the MDS fit.
    raw_stress:
        Sum of squared differences between D and the pairwise distances of
        the embedded coordinates, taken over every off-diagonal entry (so
        each unordered pair contributes twice).
    normalized_stress:
        Square root of raw_stress divided by the sum of squared D over the
        same off-diagonal entries.
    """

    n_nodes = K.shape[0]

    shifted_weights = K + epsilon
    distances = 1.0 / shifted_weights
    np.fill_diagonal(distances, 0.0)

    mds_options = {
        "n_components": 2,
        "dissimilarity": "precomputed",
        "random_state": random_state,
        "n_init": 8,
        "max_iter": 1000,
    }
    embedding = MDS(**mds_options).fit_transform(distances)

    # Compare target and embedded distances on off-diagonal entries only.
    off_diagonal = ~np.eye(n_nodes, dtype=bool)
    embedded_distances = pairwise_distances(embedding)

    raw_stress = np.sum(
        (distances[off_diagonal] - embedded_distances[off_diagonal]) ** 2
    )
    reference_scale = np.sum(distances[off_diagonal] ** 2)
    normalized_stress = np.sqrt(raw_stress / reference_scale)

    return distances, embedding, raw_stress, normalized_stress


def save_figure_2(history, output_dir: Path):
    """Write Figure 2, a heat map of the coherence field over time.

    ``history`` is transposed before plotting so that time steps run along
    the x-axis and node indices along the y-axis. The image is written to
    ``output_dir`` at 300 dpi and the figure is closed afterwards.

    Returns the path of the saved PNG.
    """

    target = output_dir / "Figure2_Coherence_Evolution_consistent.png"

    fig, ax = plt.subplots(figsize=(6, 4))
    heatmap = ax.imshow(history.T, aspect="auto", cmap="viridis")
    fig.colorbar(heatmap, ax=ax, label=r"$\Phi_i$")
    ax.set_xlabel("Time step")
    ax.set_ylabel("Node")
    ax.set_title("Temporal Evolution of Coherence Field")
    fig.tight_layout()
    fig.savefig(target, dpi=300)
    plt.close(fig)

    return target


def save_figure_3(history, output_dir: Path):
    """Write Figure 3, the across-node variance of the field per time step.

    The variance is taken along axis 1 of ``history`` (one value per row),
    plotted against the row index, and saved to ``output_dir`` at 300 dpi.
    The figure is closed afterwards.

    Returns a ``(path, variance)`` pair: the saved PNG path and the variance
    series that was plotted.
    """

    target = output_dir / "Figure3_Variance_Evolution_consistent.png"

    # One variance value per time step, computed across the nodes.
    per_step_variance = np.var(history, axis=1)

    fig, ax = plt.subplots(figsize=(6, 4))
    ax.plot(per_step_variance)
    ax.set_xlabel("Time step")
    ax.set_ylabel("Variance")
    ax.set_title("Variance of Coherence Field")
    fig.tight_layout()
    fig.savefig(target, dpi=300)
    plt.close(fig)

    return target, per_step_variance


def save_figure_4(coords, final_phi, output_dir: Path):
    """Write Figure 4, the 2D embedding coloured by final coherence.

    The first two columns of ``coords`` give the x and y positions of each
    point and ``final_phi`` supplies the colour value for each point. The
    image is saved to ``output_dir`` at 300 dpi and the figure is closed
    afterwards.

    Returns the path of the saved PNG.
    """

    target = output_dir / "Figure4_MDS_colored_by_Phi_exploratory.png"

    xs = coords[:, 0]
    ys = coords[:, 1]

    fig, ax = plt.subplots(figsize=(6, 5.5))
    points = ax.scatter(xs, ys, c=final_phi, cmap="viridis", s=40)
    ax.set_xlabel("MDS dimension 1")
    ax.set_ylabel("MDS dimension 2")
    ax.set_title("Relational Distance Embedding Colored by Final Coherence")
    colorbar = fig.colorbar(points, ax=ax)
    colorbar.set_label(r"Final coherence $\Phi_i$")
    fig.tight_layout()
    fig.savefig(target, dpi=300)
    plt.close(fig)

    return target


def save_shuffled_control(coords, final_phi, output_dir: Path, shuffle_seed: int = 12345):
    """Write the shuffled-colour control version of the embedding figure.

    Point positions come from ``coords`` unchanged. The colour values are a
    random permutation of ``final_phi`` drawn from a generator seeded with
    ``shuffle_seed``, so the control is repeatable. The image is saved to
    ``output_dir`` at 300 dpi and the figure is closed afterwards.

    Returns the path of the saved PNG.
    """

    target = output_dir / "Figure4_MDS_shuffled_Phi_control.png"

    # Permute only the colours; the input array itself is left untouched.
    permuted_phi = np.random.default_rng(shuffle_seed).permutation(final_phi)

    xs = coords[:, 0]
    ys = coords[:, 1]

    fig, ax = plt.subplots(figsize=(6, 5.5))
    points = ax.scatter(xs, ys, c=permuted_phi, cmap="viridis", s=40)
    ax.set_xlabel("MDS dimension 1")
    ax.set_ylabel("MDS dimension 2")
    ax.set_title("Relational Distance Embedding with Shuffled Coherence Colors")
    colorbar = fig.colorbar(points, ax=ax)
    colorbar.set_label(r"Shuffled coherence $\Phi_i$")
    fig.tight_layout()
    fig.savefig(target, dpi=300)
    plt.close(fig)

    return target


def main():
    """Run the default simulation, save all figures, and print a summary.

    Outputs are written to an ``outputs`` directory relative to the current
    working directory, which is created if it does not already exist. The
    summary lists the saved figure paths, the raw and normalized MDS stress,
    the range of the final field, and the final variance.
    """

    out_dir = Path("outputs")
    out_dir.mkdir(exist_ok=True)

    history, weights = run_simulation()
    final_phi = history[-1]

    # The distance matrix itself is not needed here, only the embedding
    # and its stress values.
    _distances, coords, raw_stress, normalized_stress = compute_mds_embedding(weights)

    fig2_path = save_figure_2(history, out_dir)
    fig3_path, variance = save_figure_3(history, out_dir)
    fig4_path = save_figure_4(coords, final_phi, out_dir)
    control_path = save_shuffled_control(coords, final_phi, out_dir)

    report = [
        f"Figure 2 saved to: {fig2_path}",
        f"Figure 3 saved to: {fig3_path}",
        f"Figure 4 saved to: {fig4_path}",
        f"Shuffled-control figure saved to: {control_path}",
        f"Raw MDS stress: {raw_stress:.4f}",
        f"Normalized MDS stress: {normalized_stress:.4f}",
        f"Final Phi min/max: {final_phi.min():.4f} / {final_phi.max():.4f}",
        f"Final variance: {variance[-1]:.4f}",
    ]
    for line in report:
        print(line)


# Run the full pipeline only when this file is executed as a script;
# importing it as a module does not trigger main().
if __name__ == "__main__":
    main()
