
import numpy as np
import matplotlib.pyplot as plt
from scipy.integrate import solve_ivp
from scipy.optimize import least_squares

# =====================================================
# 1. Blasius numerical solution
# =====================================================
def blasius_ode(eta, y):
    """Right-hand side of 2 f''' + f f'' = 0 written as a first-order system.

    Parameters
    ----------
    eta : float
        Independent (similarity) coordinate. Not used because the system is
        autonomous, but ``solve_ivp`` always passes it.
    y : sequence of three floats
        Current state ``(f, f', f'')``.

    Returns
    -------
    list
        ``[f', f'', f''']`` with ``f''' = -0.5 * f * f''``.
    """
    f_val, f_prime, f_second = y
    f_third = -0.5 * f_val * f_second
    return [f_prime, f_second, f_third]

# Integrate the Blasius system from the wall with f(0) = 0, f'(0) = 0 and
# f''(0) = fpp0 (the standard shooting value for the 2 f''' + f f'' = 0 form),
# sampling the solution on a uniform eta grid.
eta = np.linspace(0, 10, 600)
fpp0 = 0.332057336215

sol = solve_ivp(
    blasius_ode,
    [eta[0], eta[-1]],  # span taken from the grid so t_eval always lies inside it
    [0, 0, fpp0],
    t_eval=eta,
    rtol=1e-9,
    atol=1e-11,
)
# solve_ivp does not raise on failure. It returns a truncated solution, and the
# boolean mask below would then no longer line up with ``eta``.
if not sol.success:
    raise RuntimeError(f"Blasius integration failed: {sol.message}")

f = sol.y[0]
fp = sol.y[1]
fpp = sol.y[2]
fppp = -0.5 * f * fpp  # f''' taken directly from the ODE right-hand side

# Keep only samples with f > 1e-6, which drops the wall point where f(0) = 0.
# NOTE(review): presumably excluded because the model's sqrt(f) and f**0.25
# terms are singular in slope at f = 0; confirm.
mask = f > 1e-6
f_data = f[mask]
eta_data = eta[mask]
fp_data = fp[mask]
fpp_data = fpp[mask]
fppp_data = fppp[mask]

# =====================================================
# 2. Model Components (modular, no change in math)
# =====================================================
def compute_BC(f, params):
    """Return the quadratic coefficients ``(A, B, C)`` evaluated at ``f``.

    ``A`` is ``params[0]`` unchanged. ``B`` and ``C`` each combine a scaled
    exponential with a linear term in ``f``::

        B = b * exp(g * f) + h * f + i
        C = j * exp(l * f) + m * f + n

    Exponent arguments are clipped to [-50, 50] so ``np.exp`` cannot overflow
    for extreme parameter values tried by the optimizer.

    Parameters
    ----------
    f : float or ndarray
        Blasius stream-function values.
    params : sequence
        At least ten entries ``A, b, c, g, h, i, j, l, m, n, ...``. The entry
        ``c`` (index 2) and anything past index 9 are ignored here.

    Returns
    -------
    tuple
        ``(A, B, C)``.
    """
    (A, b_amp, _, b_rate, b_slope, b_offset,
     c_amp, c_rate, c_slope, c_offset, *_) = params

    def _clipped_exp(rate):
        # Same clipping as applied to both exponentials.
        return np.exp(np.clip(rate * f, -50, 50))

    B = b_amp * _clipped_exp(b_rate) + b_slope * f + b_offset
    C = c_amp * _clipped_exp(c_rate) + c_slope * f + c_offset
    return A, B, C


def quadratic_root(A, B, C):
    """Return the ``-sqrt`` root of ``A x^2 + B x + C = 0`` and its discriminant.

    The discriminant is floored at 1e-10 rather than allowed to go negative,
    so a real value is always returned. For inputs with no real root the
    result is therefore an approximation, not a true root.

    Parameters
    ----------
    A, B, C : float or ndarray
        Quadratic coefficients; ``A`` must be non-zero.

    Returns
    -------
    tuple
        ``((-B - sqrt(disc)) / (2A), disc)`` with ``disc`` already floored.
    """
    floored_disc = np.maximum(B**2 - 4*A*C, 1e-10)
    numerator = -B - np.sqrt(floored_disc)
    return numerator / (2*A), floored_disc


def correction_terms(f, eta, params):
    """Additive correction applied on top of the quadratic root.

    Computes::

        o f^3 + p f^2 + q f + r sqrt(f) + s f^(1/4) + t + 0.332 * eta

    where ``o, p, q, r, s, t`` are ``params[10:16]``.

    Parameters
    ----------
    f : float or ndarray
        Stream-function values. The fractional powers need ``f >= 0``.
    eta : float or ndarray
        Similarity coordinate, broadcast-compatible with ``f``.
    params : sequence
        Exactly 16 entries; only the last six are read here.
    """
    o, p2, q, r, s, t = params[10:]
    # Accumulate left to right, in the same order as a single long sum.
    total = o * f**3 + p2 * f**2 + q * f
    total = total + r * np.sqrt(f)
    total = total + s * f**0.25
    total = total + t
    return total + 0.332 * eta


def model_fprime(f, eta, params):
    """Evaluate the parametric model for f' as a function of f and eta.

    Steps:
      1. Build quadratic coefficients ``(A, B, C)`` from ``f`` (``compute_BC``).
      2. Take the root ``f_quad`` of ``A x^2 + B x + C`` (``quadratic_root``).
      3. Add the polynomial / fractional-power / eta corrections
         (``correction_terms``).
      4. Map the result through ``1 - exp(-(f_raw + c))``, with ``c = params[2]``.

    Parameters
    ----------
    f, eta : ndarray
        Stream-function values and matching similarity coordinates.
    params : sequence
        The 16 model parameters ``A, b, c, g, h, i, j, l, m, n, o, p, q, r, s, t``.

    Returns
    -------
    tuple
        ``(fp_model, f_quad, B, C, disc)``. ``disc`` is the discriminant as
        returned by ``quadratic_root``, i.e. after its lower clamp.
    """
    A, B, C = compute_BC(f, params)
    f_quad, disc = quadratic_root(A, B, C)

    f_raw = f_quad + correction_terms(f, eta, params)

    # Only the shift ``c`` is needed here. A and the B/C parameters are
    # consumed inside compute_BC, so there is no need to unpack them again.
    c = params[2]
    fp_model = 1 - np.exp(-(f_raw + c))

    return fp_model, f_quad, B, C, disc


# =====================================================
# 3. Residual function
# =====================================================
def residuals(params):
    """Residual vector for ``least_squares``: model f' minus Blasius f'.

    Evaluated at every retained sample, using the module-level arrays
    ``f_data``, ``eta_data`` and ``fp_data``.
    """
    model_values = model_fprime(f_data, eta_data, params)[0]
    return model_values - fp_data


# =====================================================
# 4. Initial guess and bounds
# =====================================================
# One row per parameter, in model order A, b, c, g, h, i, j, l, m, n, o, p, q,
# r, s, t: (initial guess, lower bound, upper bound). Keeping them together
# keeps each guess aligned with its bounds. A's upper bound is strictly
# negative so the 2*A divisor in quadratic_root never reaches zero.
_PARAM_TABLE = [
    (-0.28, -1.0, -0.01),  # A
    (-2.5, -10, 10),       # b
    (-0.03, -1, 1),        # c
    (-0.4, -5, 5),         # g
    (3.1, -10, 10),        # h
    (0.05, -5, 5),         # i
    (-0.3, -10, 10),       # j
    (0.34, -5, 5),         # l
    (0.40, -5, 5),         # m
    (0.0, -5, 5),          # n
    (0.0, -1, 1),          # o
    (0.0, -1, 1),          # p
    (0.0, -1, 1),          # q
    (0.0, -1, 1),          # r
    (0.0, -1, 1),          # s
    (0.0, -1, 1),          # t
]

p0 = [initial for initial, _, _ in _PARAM_TABLE]

bounds = (
    [lower for _, lower, _ in _PARAM_TABLE],
    [upper for _, _, upper in _PARAM_TABLE],
)

# =====================================================
# 5. Optimization
# =====================================================
# Bounded trust-region-reflective fit of the 16 model parameters to the
# Blasius f' samples, with tight tolerances and a generous evaluation budget.
result = least_squares(
    residuals,
    p0,
    bounds=bounds,
    method="trf",
    ftol=1e-12,
    xtol=1e-12,
    gtol=1e-12,
    max_nfev=50000,
)

# least_squares does not raise when it stops without meeting a convergence
# criterion (status 0: max_nfev reached). Report it instead of silently using
# the last iterate as if it were converged.
if not result.success:
    print(
        f"WARNING: least_squares did not converge "
        f"(status={result.status}): {result.message}"
    )

params_opt = result.x
A = params_opt[0]  # quadratic coefficient, reused by the diagnostics

# =====================================================
# 6. Diagnostics
# =====================================================
fp_fit, f_quad, B, C, disc = model_fprime(f_data, eta_data, params_opt)

# Pointwise misfit of the fitted model and its root-mean-square.
error = fp_fit - fp_data
rmse = np.sqrt(np.mean(error**2))

# How far the fitted f' is from satisfying A f'^2 + B f' + C = 0 with the
# fitted coefficients.
quad_res = A * fp_fit**2 + B * fp_fit + C

# Independent check of the integrated Blasius solution. ``fppp_data`` was
# computed as -0.5 * f * f'', so using it here would make 2 f''' + f f''
# cancel to zero by construction, and the check could never fail.
# Differentiating the integrated f'' numerically avoids that. The residual then
# reflects integration error plus O(h^2) finite-difference truncation error.
fppp_numeric = np.gradient(fpp_data, eta_data, edge_order=2)
blasius_res = 2 * fppp_numeric + f_data * fpp_data

print("\n=== Optimized Parameters ===")
names = ["A","b","c","g","h","i","j","l","m","n","o","p","q","r","s","t"]
for name, val in zip(names, params_opt):
    print(f"{name} = {val:.6e}")

print("\nRMSE =", rmse)
print("Max quadratic residual:", np.max(np.abs(quad_res)))
print("Max Blasius residual:", np.max(np.abs(blasius_res)))

# =====================================================
# 7. Plots
# =====================================================
# Overlay of the integrated Blasius f' and the fitted model, both against f.
plt.figure()
plt.plot(f_data, fp_data, label="Blasius f'")
plt.plot(f_data, fp_fit, '--', label="Model f'")
plt.xlabel("f")
plt.ylabel("f'")
plt.legend()
plt.grid()
plt.show()

# Three single-series diagnostic plots that share the same layout:
# (values, y-axis label, title).
for y_values, y_label, plot_title in (
    (error, "Error", "Fit Error"),
    (quad_res, "A f'^2 + B f' + C", "Quadratic Residual"),
    (blasius_res, "2 f''' + f f''", "Blasius Identity Residual"),
):
    plt.figure()
    plt.plot(f_data, y_values)
    plt.xlabel("f")
    plt.ylabel(y_label)
    plt.title(plot_title)
    plt.grid()
    plt.show()

# Magnitude of the fit error on a logarithmic axis, with minor grid lines.
plt.figure()
plt.plot(f_data, np.abs(error))
plt.yscale('log')
plt.xlabel("f")
plt.ylabel("|error|")
plt.title("Absolute Error (log scale)")
plt.grid(True, which="both", ls="--")
plt.show()