#!/usr/bin/env python3
"""Prototype simulation to evaluate the vibration-training claims from md/idea.md."""

import math
import numpy as np
from modes import available_modes, describe_mode, profile_for_mode

# Model constants shared by every function below (each accepts them as `p`).
params = dict(
    # Modulus contributions: E = E0 + E1*phi (+ E2*psi), see effective_modulus.
    E0=2.0,
    E1=0.5,
    E2=0.4,
    # Prefactors of the phi build-up (kb) and decay (kr) rates; softening_deriv
    # scales them by A**2 and an Arrhenius factor exp(-E / (R*T)).
    kb=5.0e-4,
    kr=4.0e-5,
    # k0: relaxation rate of phi toward phi0 when A <= 0.
    # k0psi: decay rate of psi whenever the high-frequency drive is off.
    k0=1.0e-6,
    k0psi=4.0e-6,
    # Energies in the Arrhenius factors of kb and kr.
    Eb=10e3,
    Er=8e3,
    # Prefactors and energies of the psi rates ko (toward 1) and kc (toward 0),
    # used only while A > 0 and f > fc.
    ko=5.0e-2,
    kc=5.0e-3,
    Eo=5e3,
    Ec=8e3,
    # phi_inf logistic: exponent = alpha * A**beta * f**gamma, offset by phi_inf_shift.
    alpha=0.1,
    beta=1.2,
    gamma=0.8,
    phi_inf_shift=36.0,
    # Reference value of phi: relaxation target at A <= 0 and initial state of
    # several evaluations.
    phi0=0.85,
    # Cutoff frequency: kb is disabled and psi is driven only above it.
    fc=80.0,
    # R*T divides every activation energy; 8.314 matches the gas constant in
    # J/(mol*K), and T is presumably in kelvin (confirm units).
    R=8.314,
    T=298.0,
)


def phi_inf(A, f, p=params):
    """Return the steady-state phi targeted by a drive of amplitude A, frequency f.

    The target is the logistic ``1 / (1 + exp(exponent - phi_inf_shift))`` with
    ``exponent = alpha * A**beta * f**gamma``.  Non-positive A or f
    short-circuits to 0.5.

    The logistic is evaluated in a numerically stable form: large A or f
    saturate smoothly toward 0.0 instead of raising OverflowError from
    ``math.exp`` (which happens once its argument exceeds about 709).
    """
    if A <= 0 or f <= 0:
        return 0.5
    exponent = p["alpha"] * A ** p["beta"] * f ** p["gamma"]
    z = p["phi_inf_shift"] - exponent
    # 1/(1+exp(-z)) is safe for z >= 0.  For z < 0 use the algebraically equal
    # exp(z)/(1+exp(z)), whose exp argument is negative and cannot overflow.
    if z >= 0:
        return 1.0 / (1.0 + math.exp(-z))
    ez = math.exp(z)
    return ez / (1.0 + ez)


def softening_deriv(phi, A, f, p=params):
    """Return dphi/dt for the softening-only model.

    - ``A <= 0``: phi relaxes toward ``p["phi0"]`` at rate ``p["k0"]``.
    - ``A > 0``: ``kb * (phi_inf(A, f) - phi) - kr * phi``, where each rate is
      ``prefactor * A**2 * exp(-E / (R*T))``.  Above the cutoff frequency
      ``p["fc"]`` the build-up rate kb is zero, leaving only the kr decay.
    """
    if A <= 0:
        return p["k0"] * (p["phi0"] - phi)

    rt = p["R"] * p["T"]
    kr = p["kr"] * A**2 * math.exp(-p["Er"] / rt)
    if f > p["fc"]:
        # kb == 0 here, so the phi_inf target is irrelevant.  Skipping it avoids
        # evaluating its (overflow-prone) exponent for strong high-frequency drives.
        return -kr * phi
    kb = p["kb"] * A**2 * math.exp(-p["Eb"] / rt)
    return kb * (phi_inf(A, f, p) - phi) - kr * phi


def extended_deriv(y, A, f, p=params):
    """Return d[phi, psi]/dt for the two-channel model as a length-2 array.

    phi evolves exactly as in softening_deriv.  While A > 0 and f exceeds
    ``p["fc"]``, psi is pushed up by ``ko * (1 - psi)`` and down by
    ``kc * psi`` (each rate ``prefactor * A**2 * exp(-E / (R*T))``).  In every
    other regime psi decays at ``p["k0psi"]``.
    """
    phi, psi = y
    rates = [softening_deriv(phi, A, f, p)]
    if A > 0 and f > p["fc"]:
        rt = p["R"] * p["T"]
        amp_sq = A**2
        k_open = p["ko"] * amp_sq * math.exp(-p["Eo"] / rt)
        k_close = p["kc"] * amp_sq * math.exp(-p["Ec"] / rt)
        rates.append(k_open * (1.0 - psi) - k_close * psi)
    else:
        rates.append(-p["k0psi"] * psi)
    return np.array(rates)


def rk4_step(y, deriv, dt, A, f, p):
    """Advance state ``y`` by one classical fourth-order Runge-Kutta step of size ``dt``.

    ``deriv(state, A, f, p)`` must return d(state)/dt.  A, f and p are passed
    unchanged to all four stage evaluations.
    """
    half = 0.5 * dt

    def rate(state):
        return deriv(state, A, f, p)

    s1 = rate(y)
    s2 = rate(y + half * s1)
    s3 = rate(y + half * s2)
    s4 = rate(y + dt * s3)
    weighted = s1 + 2.0 * s2 + 2.0 * s3 + s4
    return y + (dt / 6.0) * weighted


def simulate_segment(y0, duration, A, f, deriv, dt=0.2, p=params):
    """Integrate ``deriv`` over one constant-drive segment with RK4.

    The interval [0, duration] is split into ``ceil(duration / dt)`` equal
    steps (at least one).  The actual step therefore never exceeds ``dt``,
    and the last sample lands exactly on ``duration``.

    Returns:
        (t, y): ``t`` has shape (n + 1,) and ``y`` has shape
        (n + 1, len(atleast_1d(y0))), with ``y[0] == y0``.

    Raises:
        ValueError: if ``dt`` is not positive or ``duration`` is negative.
    """
    # Reject inputs that would otherwise yield a single backward or oversized
    # step, or a ZeroDivisionError, instead of a meaningful integration.
    if dt <= 0:
        raise ValueError(f"dt must be positive, got {dt!r}")
    if duration < 0:
        raise ValueError(f"duration must be non-negative, got {duration!r}")
    n = max(1, int(math.ceil(duration / dt)))
    t = np.linspace(0.0, duration, n + 1)
    y = np.empty((len(t), len(np.atleast_1d(y0))))
    y[0] = y0
    for i in range(n):
        y[i + 1] = rk4_step(y[i], deriv, t[i + 1] - t[i], A, f, p)
    return t, y


def simulate_profile(y0, profile, deriv, dt=0.2, p=params):
    """Run ``deriv`` through a sequence of constant-drive segments.

    ``profile`` is an iterable of ``(duration, A, f)`` tuples.  Each segment
    starts from the final state of the previous one.  Its duplicated first
    sample is dropped, so the stitched time axis has no repeated boundaries.

    Returns:
        (t, y): global time stamps starting at 0.0, and the states stacked
        one row per time stamp.
    """
    times = [0.0]
    states = [np.atleast_1d(y0).copy()]
    for duration, A, f in profile:
        seg_t, seg_y = simulate_segment(states[-1], duration, A, f, deriv, dt, p)
        if len(seg_t) > 1:
            seg_t, seg_y = seg_t[1:], seg_y[1:]
        # The segment's local clock starts where the stitched clock currently ends.
        start = times[-1]
        times.extend(start + seg_t)
        states.extend(seg_y)
    return np.array(times), np.vstack(states)


def effective_modulus(phi, psi=None, p=params):
    """Return ``E0 + E1*phi``, plus ``E2*psi`` when ``psi`` is given.

    ``psi=None`` is the softening-only model.  Passing psi adds the
    orientation-channel contribution of the two-channel model.
    """
    modulus = p["E0"] + p["E1"] * phi
    if psi is not None:
        modulus = modulus + p["E2"] * psi
    return modulus


def pulse_coefficients(A, f, p=params, *, apply_cutoff=True):
    """Return ``(kb_eff, kr_eff, phi_inf)`` for a drive of amplitude A and frequency f.

    kb_eff and kr_eff are the Arrhenius-scaled rates
    ``prefactor * A**2 * exp(-E / (R*T))`` used by softening_deriv.  Like
    softening_deriv, kb_eff is reported as 0.0 above the cutoff frequency
    ``p["fc"]``, so printed diagnostics match the rate the integrator actually
    uses.  Pass ``apply_cutoff=False`` to get the raw, uncut value.
    """
    rt = p["R"] * p["T"]
    kb_eff = p["kb"] * A**2 * math.exp(-p["Eb"] / rt)
    kr_eff = p["kr"] * A**2 * math.exp(-p["Er"] / rt)
    if apply_cutoff and f > p["fc"]:
        kb_eff = 0.0
    return kb_eff, kr_eff, phi_inf(A, f, p)


def evaluate_softening():
    """Run a train -> rest -> erase profile through the softening-only model.

    Segments: 600 s at A=10, f=50 (training); 3600 s with A=0 (rest); 300 s
    at A=30, f=200 (erase).  phi starts at ``params["phi0"]``, and the
    baseline modulus is taken at that same value.  The baseline therefore
    agrees with the rest-relaxation target in softening_deriv if phi0 changes.

    Returns a dict containing:
    - phi and E after each segment;
    - ``delta_train_pct``: the training change relative to baseline;
    - ``retention_pct``: the share of the training change still present after
      rest (0.0 when training changed nothing);
    - ``erase_recovery_pct``: the post-erase change relative to baseline.
    """
    profile = [
        (600.0, 10.0, 50.0),   # training: 10 min, 50 Hz, 10 m/s^2
        (3600.0, 0.0, 0.0),    # rest: 1 h
        (300.0, 30.0, 200.0),  # erase: 5 min, 200 Hz, 30 m/s^2
    ]
    phi_start = params["phi0"]

    def phi_only_deriv(y, A, f, p):
        # Wrap the scalar softening rate as a length-1 state derivative.
        return np.array([softening_deriv(y[0], A, f, p)])

    t, y = simulate_profile(np.array([phi_start]), profile, phi_only_deriv)

    # Sample phi at each segment end.  The ends are running sums of the
    # durations, located with argmin instead of an exact-match search that
    # would raise IndexError if no sample matched.
    segment_ends = np.cumsum([duration for duration, _, _ in profile])
    phi_train, phi_rest, phi_erase = (
        y[int(np.argmin(np.abs(t - end))), 0] for end in segment_ends
    )

    E_base = effective_modulus(phi_start)
    E_train = effective_modulus(phi_train)
    E_rest = effective_modulus(phi_rest)
    E_erase = effective_modulus(phi_erase)
    train_drop = E_train - E_base
    rest_drop = E_rest - E_base
    retention_pct = 100.0 * rest_drop / train_drop if train_drop != 0 else 0.0
    return {
        "phi_train": phi_train,
        "phi_rest": phi_rest,
        "phi_erase": phi_erase,
        "E_base": E_base,
        "E_train": E_train,
        "E_rest": E_rest,
        "E_erase": E_erase,
        "delta_train_pct": 100.0 * train_drop / E_base,
        "retention_pct": retention_pct,
        "erase_recovery_pct": 100.0 * (E_erase - E_base) / E_base,
    }


def evaluate_extended():
    """Run low-frequency training, then a high-frequency pulse, through the two-channel model.

    Training is 600 s at A=10, f=50.  The pulse is 300 s at A=5, f=150,
    started from the trained state.  The initial state is
    ``[params["phi0"], 0.0]``, and the baseline modulus is evaluated at that
    same state, consistent with evaluate_sequence and evaluate_mode.

    Returns phi, psi and E after each stage, the baseline E, and the
    percentage change of each stage relative to baseline.
    """
    profile_train = [(600.0, 10.0, 50.0)]
    profile_high = [(300.0, 5.0, 150.0)]
    phi_start = params["phi0"]

    _, y_train = simulate_profile(np.array([phi_start, 0.0]), profile_train, extended_deriv)
    _, y_high = simulate_profile(y_train[-1], profile_high, extended_deriv)
    phi_train = y_train[-1, 0]
    psi_train = y_train[-1, 1]
    phi_high = y_high[-1, 0]
    psi_high = y_high[-1, 1]
    E_soft = effective_modulus(phi_train, psi_train)
    E_high = effective_modulus(phi_high, psi_high)
    E_base = effective_modulus(phi_start, 0.0)
    return {
        "phi_train": phi_train,
        "psi_train": psi_train,
        "E_train": E_soft,
        "phi_high": phi_high,
        "psi_high": psi_high,
        "E_high": E_high,
        "E_base": E_base,
        "delta_train_pct": 100.0 * (E_soft - E_base) / E_base,
        "delta_high_pct": 100.0 * (E_high - E_base) / E_base,
    }


def evaluate_sequence(cycles=3):
    """Alternate a training segment and a high-frequency segment ``cycles`` times.

    Each cycle is 600 s at A=10, f=50 followed by 300 s at A=5, f=150.  The
    whole profile runs through the two-channel model from
    ``[params["phi0"], 0.0]``.

    Returns one dict per segment with a 1-based ``step``, the drive settings,
    the cumulative end ``time``, and phi, psi and E sampled at that time.
    """
    cycle = [(600.0, 10.0, 50.0), (300.0, 5.0, 150.0)]
    profile = cycle * cycles
    t, y = simulate_profile(np.array([params["phi0"], 0.0]), profile, extended_deriv)

    results = []
    end_time = 0.0
    for step, segment in enumerate(profile, start=1):
        duration, A, f = segment
        end_time += duration
        # Nearest stored sample to the segment's end time.
        phi_val, psi_val = y[int(np.argmin(np.abs(t - end_time)))]
        results.append(
            dict(
                step=step,
                duration=duration,
                A=A,
                f=f,
                time=end_time,
                phi=float(phi_val),
                psi=float(psi_val),
                E=float(effective_modulus(phi_val, psi_val)),
            )
        )
    return results


def evaluate_mode(name: str):
    """Simulate the named mode's profile and summarize the final state.

    The profile comes from ``profile_for_mode(name)`` and runs through the
    two-channel model from ``[params["phi0"], 0.0]``.  The result holds the
    final phi, psi and modulus as plain floats, the summed segment durations,
    and the profile itself.
    """
    profile = profile_for_mode(name)
    _, states = simulate_profile(np.array([params["phi0"], 0.0]), profile, extended_deriv)
    phi_end, psi_end = states[-1]
    modulus_end = effective_modulus(phi_end, psi_end)
    total_duration = sum(segment[0] for segment in profile)
    summary = {"name": name}
    summary["phi_final"] = float(phi_end)
    summary["psi_final"] = float(psi_end)
    summary["E_final"] = float(modulus_end)
    summary["duration"] = total_duration
    summary["profile"] = profile
    return summary


def _print_softening_section():
    """Print pulse coefficients and the train/rest/erase results of the softening-only model."""
    print("=== Softening-only model ===")
    kb_train, kr_train, phi_inf_train = pulse_coefficients(10.0, 50.0)
    kb_erase, kr_erase, phi_inf_erase = pulse_coefficients(30.0, 200.0)
    print(f"training pulse: kb_eff={kb_train:.3e}, kr_eff={kr_train:.3e}, phi_inf={phi_inf_train:.4f}")
    print(f"erase pulse: kb_eff={kb_erase:.3e}, kr_eff={kr_erase:.3e}, phi_inf={phi_inf_erase:.4f}")
    s = evaluate_softening()
    print(f"baseline modulus E0 = {s['E_base']:.4f} MPa")
    print(f"after 10 min training: phi = {s['phi_train']:.4f}, E = {s['E_train']:.4f} MPa")
    print(f"-> relative change = {s['delta_train_pct']:+.2f} %")
    print(f"after 1 h rest: phi = {s['phi_rest']:.4f}, E = {s['E_rest']:.4f} MPa")
    print(f"-> retention vs trained state = {s['retention_pct']:.2f} %")
    print(f"after erase pulse: phi = {s['phi_erase']:.4f}, E = {s['E_erase']:.4f} MPa")
    print(f"-> recovery vs baseline = {s['erase_recovery_pct']:+.2f} %")


def _print_extended_section():
    """Print the training and high-frequency results of the two-channel model."""
    print("=== Extended model with orientation channel ===")
    e = evaluate_extended()
    print(f"baseline E0 = {e['E_base']:.4f} MPa")
    print(f"after low-frequency training: phi = {e['phi_train']:.4f}, psi = {e['psi_train']:.4f}, E = {e['E_train']:.4f} MPa")
    print(f"-> relative change = {e['delta_train_pct']:+.2f} %")
    print(f"after high-frequency pulse: phi = {e['phi_high']:.4f}, psi = {e['psi_high']:.4f}, E = {e['E_high']:.4f} MPa")
    print(f"-> relative change = {e['delta_high_pct']:+.2f} %")


def _print_sequence_section():
    """Print the per-step state of the three-cycle low/high sequence."""
    print("=== Three-cycle low/high sequence ===")
    for entry in evaluate_sequence(3):
        # Classify by the model's cutoff: phi build-up is active only for
        # f <= fc, and psi is driven only for f > fc.  This replaces a brittle
        # exact match on one training frequency.
        step_type = 'train' if entry['f'] <= params['fc'] else 'high'
        print(f"step {entry['step']:d} ({step_type}, A={entry['A']:.1f}, f={entry['f']:.1f}): "
              f"t={entry['time']:.0f}s, phi={entry['phi']:.4f}, psi={entry['psi']:.4f}, E={entry['E']:.4f} MPa")


def _print_mode_section():
    """Print each mode's description, then each mode's simulated final state."""
    print("=== Mode library summary ===")
    for name in available_modes():
        print(f"- {describe_mode(name)}")
    print()
    for name in available_modes():
        m = evaluate_mode(name)
        print(f"mode {m['name']}: duration={m['duration']:.0f}s, phi={m['phi_final']:.4f}, psi={m['psi_final']:.4f}, E={m['E_final']:.4f} MPa")


def print_results():
    """Print every evaluation in order: softening-only, extended, sequence, then modes."""
    _print_softening_section()
    print()
    _print_extended_section()
    print()
    _print_sequence_section()
    print()
    _print_mode_section()


if __name__ == "__main__":
    # Print all evaluations only when executed as a script, not on import.
    print_results()
