"""
PyPhysDisc Core Module  (v5)
============================
Co-evolutionary Symbolic Regression with adaptive Savitzky-Golay smoothing.

Changes in v5
-------------
- PDE support: build_cache_pde for spatiotemporal data with TWO window genes
  (temporal + spatial), enabling Burgers / KdV / reaction-diffusion discovery.
- _evaluate_single_pde for PDE individuals.
- PyPhysDiscPDEOptimizer: co-evolves expression + w_time + w_space.

All ODE functionality from v4 is preserved unchanged.

Author : Ali Tozar
License: MIT
"""

from __future__ import annotations

import operator
import random
import warnings
from dataclasses import dataclass
from typing import List, Optional, Sequence, Tuple

import numpy as np
from scipy.signal import savgol_filter

try:
    from deap import algorithms, base, creator, gp, tools
except ImportError as exc:
    raise ImportError("DEAP is required: pip install deap") from exc

# NOTE(review): this installs a global "ignore" filter, so every warning
# category is silenced for the whole process that imports this module —
# presumably meant to hide numeric warnings raised while evaluating candidate
# expressions; confirm that suppressing everything else is intended.
warnings.filterwarnings("ignore")


# ---------------------------------------------------------------------------
# Module-level callables (pickle safe)
# ---------------------------------------------------------------------------

def _ephemeral_const() -> float:
    """Draw one ephemeral constant uniformly from the interval [-2, 2]."""
    lower, upper = -2.0, 2.0
    return random.uniform(lower, upper)


def _safe_exp(x):
    """Exponential with the argument clipped to [-20, 20] to avoid overflow."""
    bounded = np.clip(x, -20.0, 20.0)
    return np.exp(bounded)


def _safe_div(a, b):
    """Element-wise ``a / b`` that yields 0.0 wherever ``|b| <= 1e-10``."""
    # The division is evaluated everywhere and masked afterwards, so the
    # divide/invalid floating-point warnings are suppressed for this scope.
    with np.errstate(divide="ignore", invalid="ignore"):
        denominator_ok = np.abs(b) > 1e-10
        quotient = a / b
        return np.where(denominator_ok, quotient, 0.0)


# ---------------------------------------------------------------------------
# Data helpers
# ---------------------------------------------------------------------------

def sanitize_window(w: int, n: int) -> int:
    """
    Turn a requested Savitzky-Golay window into an odd size of at least 5.

    The request is truncated to ``int``, raised to 5 if smaller and bumped to
    the next odd number if even.  It is then capped at the largest odd number
    strictly below ``n``; that cap is itself never allowed below 5, so the
    result can exceed ``n`` when ``n`` is smaller than 5.
    """
    width = max(5, int(w))
    if width % 2 == 0:
        width += 1

    # Largest odd value strictly below n.
    ceiling = n - 1
    if ceiling % 2 != 1:
        ceiling = n - 2

    return min(width, max(5, ceiling))


def r2_score(y_true: np.ndarray, y_pred: np.ndarray) -> float:
    """
    Coefficient of determination of ``y_pred`` against ``y_true``.

    Returns 0.0 when ``y_true`` is (numerically) constant, i.e. when its total
    sum of squares is below 1e-12.
    """
    centered = y_true - y_true.mean()
    total = float(np.sum(centered ** 2))
    if total < 1e-12:
        return 0.0
    residual = float(np.sum((y_true - y_pred) ** 2))
    return 1.0 - residual / total


def affine_scale_train(
    y_true_tr: np.ndarray, y_pred_tr: np.ndarray
) -> Tuple[float, float]:
    """
    Least-squares fit of ``y_true_tr ~ slope * y_pred_tr + intercept``.

    Returns ``(slope, intercept)``.  When the prediction is (numerically)
    constant, the slope is 0.0 and the intercept is the mean of the target.
    """
    mean_true = y_true_tr.mean()
    mean_pred = y_pred_tr.mean()

    pred_centered = y_pred_tr - mean_pred
    spread = float(np.dot(pred_centered, pred_centered))
    if spread < 1e-12:
        return 0.0, float(mean_true)

    true_centered = y_true_tr - mean_true
    slope = float(np.dot(pred_centered, true_centered) / spread)
    return slope, float(mean_true - slope * mean_pred)


# ===================================================================
# ODE CACHE
# ===================================================================

def build_cache(
    x_raw: np.ndarray,
    window_pool: Sequence[int],
    dt: float,
    poly_order: int = 3,
    split: Tuple[float, float, float] = (0.6, 0.2, 0.2),
    target_col: int = -1,
) -> dict:
    """
    Build the per-window lookup of smoothed states and target derivative.

    The series is cut into train/validation/test segments first and each
    segment is filtered on its own, so no filter window spans a split
    boundary.

    Parameters
    ----------
    x_raw       : (T,) or (T, D) raw state series; 1-D input is treated as
                  a single column.
    window_pool : requested Savitzky-Golay window sizes.
    dt          : sample spacing used for the derivative.
    poly_order  : polynomial order of the filter.
    split       : train/val/test fractions; the test segment is whatever
                  remains after the first two.
    target_col  : column whose first derivative is the regression target
                  (taken modulo D, so negative indices work).

    Returns
    -------
    cache : { requested_window -> { X_tr, y_tr, X_va, y_va, X_te, y_te } }
            where each ``X_*`` is a list of D smoothed columns and each
            ``y_*`` is the smoothed derivative of the target column.
    """
    if x_raw.ndim == 1:
        x_raw = x_raw[:, None]
    n_samples, n_states = x_raw.shape
    target = int(target_col) % n_states

    train_end = int(n_samples * split[0])
    val_end = train_end + int(n_samples * split[1])
    segments = (
        ("tr", x_raw[:train_end]),
        ("va", x_raw[train_end:val_end]),
        ("te", x_raw[val_end:]),
    )

    cache: dict = {}
    for requested in window_pool:
        per_window = {}
        for tag, segment in segments:
            # The effective window depends on the segment length, so it is
            # re-derived for every segment.
            width = sanitize_window(int(requested), len(segment))

            filtered = []
            for col in range(n_states):
                filtered.append(savgol_filter(segment[:, col], width, poly_order))
            stacked = np.column_stack(filtered)

            per_window["X_" + tag] = [stacked[:, col] for col in range(n_states)]
            per_window["y_" + tag] = savgol_filter(
                segment[:, target], width, poly_order, deriv=1, delta=dt
            )
        # Keyed by the requested (unsanitized) window value.
        cache[requested] = per_window
    return cache


# ===================================================================
# PDE CACHE  (NEW in v5)
# ===================================================================

def build_cache_pde(
    u_raw: np.ndarray,
    wt_pool: Sequence[int],
    ws_pool: Sequence[int],
    dt: float,
    dx: float,
    poly_order: int = 3,
    split: Tuple[float, float, float] = (0.6, 0.2, 0.2),
) -> dict:
    """
    Pre-compute smoothed fields and spatial/temporal derivatives for a PDE.

    The field is split along the time axis first; every segment is then
    filtered independently for each (temporal, spatial) window pair.

    Parameters
    ----------
    u_raw   : (Nt, Nx) noisy scalar field.  Integer or boolean input is
              converted to float so smoothed values are not truncated.
    wt_pool : candidate temporal SG window sizes.
    ws_pool : candidate spatial SG window sizes.
    dt, dx  : grid spacings.
    poly_order : polynomial order of the SG filter.
    split   : time-direction train/val/test fractions; the test segment is
              whatever remains after the first two.

    Returns
    -------
    cache : { (wt, ws) -> { u_tr, u_x_tr, u_xx_tr, u_t_tr, ... } }
            keyed by the requested window pair; every array is flattened
            to 1-D.

    Raises
    ------
    ValueError : if ``u_raw`` is not two-dimensional.

    GP terminals: u, u_x, u_xx
    GP target:    u_t
    """
    u_raw = np.asarray(u_raw)
    if u_raw.ndim != 2:
        raise ValueError(
            f"u_raw must be a 2-D (Nt, Nx) array, got shape {u_raw.shape}"
        )
    # Work in floating point: an integer field would otherwise force the
    # smoothed values and derivatives back into integers.
    if u_raw.dtype.kind in "biu":
        u_raw = u_raw.astype(float)

    n_steps = u_raw.shape[0]
    n_tr = int(n_steps * split[0])
    n_va = int(n_steps * split[1])
    chunks = (
        ("tr", u_raw[:n_tr, :]),
        ("va", u_raw[n_tr: n_tr + n_va, :]),
        ("te", u_raw[n_tr + n_va:, :]),
    )

    cache = {}
    for wt in wt_pool:
        for ws in ws_pool:
            entry = {}
            for key, chunk in chunks:
                nt_c, nx_c = chunk.shape
                wt_s = sanitize_window(wt, nt_c)
                ws_s = sanitize_window(ws, nx_c)

                # Smooth in time (axis 0), then in space (axis 1).  The
                # filter is applied to the whole array along one axis
                # instead of looping over rows/columns in Python.
                u_smooth = savgol_filter(chunk, wt_s, poly_order, axis=0)
                u_smooth = savgol_filter(u_smooth, ws_s, poly_order, axis=1)

                # The temporal derivative is taken from the smoothed field,
                # not the raw chunk: differentiating raw data amplifies
                # noise by O(1/dt).
                u_t = savgol_filter(
                    u_smooth, wt_s, poly_order, deriv=1, delta=dt, axis=0
                )

                # Spatial derivatives via SG along the space axis.
                u_x = savgol_filter(
                    u_smooth, ws_s, poly_order, deriv=1, delta=dx, axis=1
                )
                u_xx = savgol_filter(
                    u_smooth, ws_s, poly_order, deriv=2, delta=dx, axis=1
                )

                # Flatten to 1D for GP
                entry[f"u_{key}"] = u_smooth.ravel()
                entry[f"u_x_{key}"] = u_x.ravel()
                entry[f"u_xx_{key}"] = u_xx.ravel()
                entry[f"u_t_{key}"] = u_t.ravel()

            cache[(wt, ws)] = entry
    return cache


# ===================================================================
# ODE FITNESS
# ===================================================================

def _evaluate_single(individual, cache: dict, pset) -> Tuple[float]:
    """
    Fitness for ODE individuals (lower is better).

    The expression is compiled and evaluated on the training states of the
    cache entry selected by ``individual.window``.  An affine rescaling is
    fitted on the training segment and applied to the validation prediction.
    The score is the validation MSE, divided by the validation target
    variance when that variance exceeds 1e-6.

    Any failure (exception, non-finite output, near-constant training
    prediction) yields the penalty fitness ``(1e9,)``.
    """
    penalty = (1e9,)
    try:
        entry = cache[individual.window]
        model = gp.compile(individual, pset=pset)
        target_tr = entry["y_tr"]
        target_va = entry["y_va"]

        pred_tr = model(*entry["X_tr"])
        if np.ndim(pred_tr) == 0:
            # A constant expression returns a scalar; broadcast it.
            pred_tr = np.full_like(target_tr, float(pred_tr))
        if not np.isfinite(pred_tr).all() or np.var(pred_tr) < 1e-12:
            return penalty

        slope, intercept = affine_scale_train(target_tr, pred_tr)

        pred_va = model(*entry["X_va"])
        if np.ndim(pred_va) == 0:
            pred_va = np.full_like(target_va, float(pred_va))
        if not np.isfinite(pred_va).all():
            return penalty

        residual = target_va - (slope * pred_va + intercept)
        mse = float(np.mean(residual ** 2))
        spread = float(np.var(target_va))
        if spread > 1e-6:
            return (mse / spread,)
        return (mse,)
    except Exception:
        return penalty


# ===================================================================
# PDE FITNESS  (NEW in v5)
# ===================================================================

def _evaluate_single_pde(individual, cache: dict, pset) -> Tuple[float]:
    """
    Fitness for PDE individuals: terminals are u, u_x, u_xx; target is u_t.

    The cache entry is selected by ``(individual.window_t,
    individual.window_s)``.  An affine rescaling is fitted on the training
    segment and applied to the validation prediction.  The score (lower is
    better) is the validation MSE, divided by the validation target variance
    when that variance exceeds 1e-6.

    Any failure (exception, non-finite output, near-constant training
    prediction) yields the penalty fitness ``(1e9,)``.
    """
    penalty = (1e9,)
    try:
        entry = cache[(individual.window_t, individual.window_s)]
        model = gp.compile(individual, pset=pset)
        target_tr = entry["u_t_tr"]
        target_va = entry["u_t_va"]

        pred_tr = model(entry["u_tr"], entry["u_x_tr"], entry["u_xx_tr"])
        if np.ndim(pred_tr) == 0:
            # A constant expression returns a scalar; broadcast it.
            pred_tr = np.full_like(target_tr, float(pred_tr))
        if not np.isfinite(pred_tr).all() or np.var(pred_tr) < 1e-12:
            return penalty

        slope, intercept = affine_scale_train(target_tr, pred_tr)

        pred_va = model(entry["u_va"], entry["u_x_va"], entry["u_xx_va"])
        if np.ndim(pred_va) == 0:
            pred_va = np.full_like(target_va, float(pred_va))
        if not np.isfinite(pred_va).all():
            return penalty

        residual = target_va - (slope * pred_va + intercept)
        mse = float(np.mean(residual ** 2))
        spread = float(np.var(target_va))
        if spread > 1e-6:
            return (mse / spread,)
        return (mse,)
    except Exception:
        return penalty


# ===================================================================
# Primitive set factory
# ===================================================================

def make_primitive_set(
    var_names: List[str],
    include_sin: bool = False,
    include_cos: bool = False,
    include_exp: bool = False,
    include_square: bool = True,
    include_div: bool = False,
) -> gp.PrimitiveSet:
    """
    Build the GP primitive set with one argument per entry of ``var_names``.

    add, sub, mul and neg are always present; square, sin, cos, the clipped
    exponential and the protected division are added according to the
    ``include_*`` flags.  One ephemeral constant generator named "const" is
    registered, and the default arguments are renamed to ``var_names``.
    """
    pset = gp.PrimitiveSet("MAIN", len(var_names))

    always_on = (
        (operator.add, 2),
        (operator.sub, 2),
        (operator.mul, 2),
        (operator.neg, 1),
    )
    for fn, arity in always_on:
        pset.addPrimitive(fn, arity)

    # Registration order is kept fixed: square, sin, cos, exp, div.
    optional = (
        (include_square, np.square, 1),
        (include_sin, np.sin, 1),
        (include_cos, np.cos, 1),
        (include_exp, _safe_exp, 1),
        (include_div, _safe_div, 2),
    )
    for enabled, fn, arity in optional:
        if enabled:
            pset.addPrimitive(fn, arity)

    pset.addEphemeralConstant("const", _ephemeral_const)

    for index, label in enumerate(var_names):
        pset.renameArguments(**{"ARG%d" % index: label})
    return pset


# ===================================================================
# Variation operators
# ===================================================================

def _make_individual(pcls, expr_fn, window_pool):
    """Create an ODE individual: a tree from ``expr_fn`` plus a random window gene."""
    candidate = tools.initIterate(pcls, expr_fn)
    chosen_window = random.choice(window_pool)
    candidate.window = chosen_window
    return candidate


def _make_individual_pde(pcls, expr_fn, wt_pool, ws_pool):
    """Create a PDE individual: a tree plus random temporal and spatial window genes."""
    candidate = tools.initIterate(pcls, expr_fn)
    # Temporal gene is drawn first, then the spatial one.
    temporal = random.choice(wt_pool)
    candidate.window_t = temporal
    spatial = random.choice(ws_pool)
    candidate.window_s = spatial
    return candidate


def _mate_coupled(ind1, ind2, window_cx_prob=0.5):
    """
    Cross two ODE individuals.

    One-point crossover is applied to the trees; afterwards, with probability
    ``window_cx_prob``, the two individuals exchange their window genes.
    Returns the pair ``(ind1, ind2)``.
    """
    gp.cxOnePoint(ind1, ind2)
    swap_windows = random.random() < window_cx_prob
    if swap_windows:
        first, second = ind1.window, ind2.window
        ind1.window = second
        ind2.window = first
    return ind1, ind2


def _mate_coupled_pde(ind1, ind2, window_cx_prob=0.5):
    """
    Cross two PDE individuals.

    One-point crossover is applied to the trees; afterwards, with probability
    ``window_cx_prob``, both window genes (temporal and spatial) are exchanged
    together.  Returns the pair ``(ind1, ind2)``.
    """
    gp.cxOnePoint(ind1, ind2)
    swap_windows = random.random() < window_cx_prob
    if swap_windows:
        for gene in ("window_t", "window_s"):
            first, second = getattr(ind1, gene), getattr(ind2, gene)
            setattr(ind1, gene, second)
            setattr(ind2, gene, first)
    return ind1, ind2


def _mutate_coupled(ind, expr, pset, window_pool, window_mut_prob=0.2):
    """
    Mutate an ODE individual in exactly one of two ways.

    With probability ``window_mut_prob`` the window gene is redrawn from
    ``window_pool``; otherwise the tree undergoes uniform subtree mutation.
    Returns the one-element tuple ``(ind,)``.
    """
    redraw_window = random.random() < window_mut_prob
    if not redraw_window:
        gp.mutUniform(ind, expr, pset)
        return (ind,)
    ind.window = random.choice(window_pool)
    return (ind,)


def _mutate_coupled_pde(ind, expr, pset, wt_pool, ws_pool, window_mut_prob=0.2):
    """
    Mutate a PDE individual in exactly one of three ways.

    A single uniform draw decides the branch: below ``window_mut_prob`` the
    temporal window is redrawn, below ``2 * window_mut_prob`` the spatial
    window is redrawn, otherwise the tree undergoes uniform subtree mutation.
    Each window gene therefore mutates with probability ``window_mut_prob``.
    Returns the one-element tuple ``(ind,)``.
    """
    draw = random.random()
    if draw < window_mut_prob:
        gene, pool = "window_t", wt_pool
    elif draw < 2 * window_mut_prob:
        gene, pool = "window_s", ws_pool
    else:
        gp.mutUniform(ind, expr, pset)
        return (ind,)
    setattr(ind, gene, random.choice(pool))
    return (ind,)


# ===================================================================
# Config
# ===================================================================

@dataclass
class PyPhysDiscConfig:
    """Hyper-parameters shared by the ODE and PDE optimizers."""

    # Number of individuals in the initial population.
    pop_size: int = 500
    # Number of generations passed to the evolutionary loop.
    n_gen: int = 30
    # Crossover probability (cxpb) of the evolutionary loop.
    cx_prob: float = 0.6
    # Mutation probability (mutpb) of the evolutionary loop.
    mut_prob: float = 0.3
    # Probability that a mutation redraws a window gene instead of altering
    # the tree (applied per window gene in the PDE mutation operator).
    window_mut_prob: float = 0.2
    # Probability that a crossover also swaps the parents' window genes.
    window_cx_prob: float = 0.5
    # Fitness tournament size of the double-tournament selection.
    tournament_fitness_size: int = 7
    # Parsimony (size) tournament parameter of the double-tournament selection.
    parsimony_size: float = 1.3
    # Minimum / maximum depth of the randomly generated initial trees.
    init_min_depth: int = 1
    init_max_depth: int = 4
    # Seed applied to both `random` and `numpy.random` at the start of fit().
    seed: int = 0


# ===================================================================
# DEAP creator (once per process)
# ===================================================================
# Set once the DEAP creator classes below have been registered.
_CREATOR_INITIALIZED = False


def _ensure_creator():
    """
    Register the DEAP ``FitnessMin`` and ``Individual`` classes exactly once.

    ``Individual`` is a primitive tree with a minimising fitness and three
    window attributes (``window``, ``window_t``, ``window_s``) defaulting to
    None, so the same class serves both the ODE and the PDE optimizer.
    Later calls are no-ops.
    """
    global _CREATOR_INITIALIZED
    if not _CREATOR_INITIALIZED:
        creator.create("FitnessMin", base.Fitness, weights=(-1.0,))
        individual_attrs = {
            "fitness": creator.FitnessMin,
            "window": None,
            "window_t": None,
            "window_s": None,
        }
        creator.create("Individual", gp.PrimitiveTree, **individual_attrs)
        _CREATOR_INITIALIZED = True


# ===================================================================
# ODE Optimizer
# ===================================================================

class PyPhysDiscOptimizer:
    """
    Co-evolutionary GP for ODE discovery.

    Each individual carries:
      * symbolic expression tree (terminals: ``var_names``)
      * one smoothing window gene (``window``), used as the key into ``cache``

    Parameters
    ----------
    var_names   : terminal names; used to build the default primitive set.
    window_pool : candidate window values the window gene is drawn from.
    cache       : mapping from window value to the pre-computed
                  train/val/test arrays (as produced by ``build_cache``).
    pset        : optional primitive set; built from ``var_names`` when None.
    config      : optional ``PyPhysDiscConfig``; defaults are used when None.
    """

    def __init__(
        self,
        var_names: List[str],
        window_pool: Sequence[int],
        cache: dict,
        pset=None,
        config=None,
    ):
        _ensure_creator()
        self.var_names = list(var_names)
        self.window_pool = list(window_pool)
        self.cache = cache
        self.config = config if config is not None else PyPhysDiscConfig()
        self.pset = (
            pset if pset is not None else make_primitive_set(self.var_names)
        )
        # HallOfFame refuses candidates that are "similar" to a stored one.
        # The default comparison looks at the tree only, so a better-scoring
        # individual with the same expression but a different window would
        # never replace the stored entry.  Compare the window gene as well.
        self.hof = tools.HallOfFame(1, similar=self._same_candidate)
        self.logbook: list = []

        cfg = self.config
        wp = self.window_pool
        tb = base.Toolbox()
        self.toolbox = tb

        tb.register("expr", gp.genHalfAndHalf, pset=self.pset,
                     min_=cfg.init_min_depth, max_=cfg.init_max_depth)
        tb.register("individual", _make_individual,
                     creator.Individual, tb.expr, wp)
        tb.register("population", tools.initRepeat, list, tb.individual)
        tb.register("compile", gp.compile, pset=self.pset)
        tb.register("mate", _mate_coupled, window_cx_prob=cfg.window_cx_prob)
        tb.register("mutate", _mutate_coupled,
                     expr=tb.expr, pset=self.pset,
                     window_pool=wp, window_mut_prob=cfg.window_mut_prob)
        tb.register("select", tools.selDoubleTournament,
                     fitness_size=cfg.tournament_fitness_size,
                     parsimony_size=cfg.parsimony_size,
                     fitness_first=True)

    @staticmethod
    def _same_candidate(a, b) -> bool:
        """Return True when two individuals share both window gene and tree."""
        return a.window == b.window and operator.eq(a, b)

    def fit(self, verbose=False):
        """
        Run the evolutionary search and return ``self``.

        Both ``random`` and ``numpy.random`` are seeded with ``config.seed``
        first.  The best individual ends up in ``self.hof`` and the run
        statistics (min / mean fitness per generation) in ``self.logbook``.
        """
        cfg = self.config
        random.seed(cfg.seed)
        np.random.seed(cfg.seed)

        # Registered as a partial of a module-level function (not a lambda)
        # so the evaluation callable itself stays picklable.
        self.toolbox.register("evaluate", _evaluate_single,
                              cache=self.cache, pset=self.pset)

        stats = tools.Statistics(
            lambda ind: ind.fitness.values[0] if ind.fitness.valid else 1e9
        )
        stats.register("min", np.min)
        stats.register("mean", np.mean)

        pop = self.toolbox.population(n=cfg.pop_size)
        pop, log = algorithms.eaSimple(
            pop, self.toolbox,
            cxpb=cfg.cx_prob, mutpb=cfg.mut_prob, ngen=cfg.n_gen,
            stats=stats, halloffame=self.hof, verbose=verbose,
        )
        self.logbook = log
        return self

    def predict_test(self):
        """
        Score the best individual on the test segment.

        The affine rescaling is re-fitted on the training segment and applied
        to the test prediction.  Returns ``(y_te, y_pred, r2)``.  Requires a
        previous ``fit()``; otherwise the hall of fame is empty and indexing
        it fails.
        """
        best = self.hof[0]
        w = best.window
        data = self.cache[w]
        func = gp.compile(best, pset=self.pset)

        y_hat_tr = func(*data["X_tr"])
        if np.ndim(y_hat_tr) == 0:
            # A constant expression returns a scalar; broadcast it.
            y_hat_tr = np.full_like(data["y_tr"], float(y_hat_tr))
        slope, intercept = affine_scale_train(data["y_tr"], y_hat_tr)

        y_hat_te = func(*data["X_te"])
        if np.ndim(y_hat_te) == 0:
            y_hat_te = np.full_like(data["y_te"], float(y_hat_te))
        y_pred = slope * y_hat_te + intercept
        r2 = r2_score(data["y_te"], y_pred)
        return data["y_te"], y_pred, r2

    @property
    def best_expression(self):
        """String form of the best expression tree."""
        return str(self.hof[0])

    @property
    def best_window(self):
        """Window gene of the best individual."""
        return self.hof[0].window

    @property
    def best_size(self):
        """Number of nodes in the best expression tree."""
        return len(self.hof[0])


# ===================================================================
# PDE Optimizer  (NEW in v5)
# ===================================================================

class PyPhysDiscPDEOptimizer:
    """
    Co-evolutionary GP for PDE discovery.

    Each individual carries:
      * symbolic expression tree (terminals: u, u_x, u_xx)
      * temporal window gene  (w_t)
      * spatial window gene   (w_s)

    Parameters
    ----------
    wt_pool : candidate temporal window values.
    ws_pool : candidate spatial window values.
    cache   : mapping from ``(w_t, w_s)`` to the pre-computed train/val/test
              arrays (as produced by ``build_cache_pde``).
    pset    : optional primitive set; built for u, u_x, u_xx when None.
    config  : optional ``PyPhysDiscConfig``; defaults are used when None.
    """

    def __init__(
        self,
        wt_pool: Sequence[int],
        ws_pool: Sequence[int],
        cache: dict,
        pset=None,
        config=None,
    ):
        _ensure_creator()
        self.wt_pool = list(wt_pool)
        self.ws_pool = list(ws_pool)
        self.cache = cache
        self.config = config if config is not None else PyPhysDiscConfig()
        self.pset = (
            pset if pset is not None
            else make_primitive_set(["u", "u_x", "u_xx"])
        )
        # HallOfFame refuses candidates that are "similar" to a stored one.
        # The default comparison looks at the tree only, so a better-scoring
        # individual with the same expression but different windows would
        # never replace the stored entry.  Compare both window genes as well.
        self.hof = tools.HallOfFame(1, similar=self._same_candidate)
        self.logbook: list = []

        cfg = self.config
        tb = base.Toolbox()
        self.toolbox = tb

        tb.register("expr", gp.genHalfAndHalf, pset=self.pset,
                     min_=cfg.init_min_depth, max_=cfg.init_max_depth)
        tb.register("individual", _make_individual_pde,
                     creator.Individual, tb.expr, self.wt_pool, self.ws_pool)
        tb.register("population", tools.initRepeat, list, tb.individual)
        tb.register("compile", gp.compile, pset=self.pset)
        tb.register("mate", _mate_coupled_pde,
                     window_cx_prob=cfg.window_cx_prob)
        tb.register("mutate", _mutate_coupled_pde,
                     expr=tb.expr, pset=self.pset,
                     wt_pool=self.wt_pool, ws_pool=self.ws_pool,
                     window_mut_prob=cfg.window_mut_prob)
        tb.register("select", tools.selDoubleTournament,
                     fitness_size=cfg.tournament_fitness_size,
                     parsimony_size=cfg.parsimony_size,
                     fitness_first=True)

    @staticmethod
    def _same_candidate(a, b) -> bool:
        """Return True when two individuals share both window genes and tree."""
        return (
            a.window_t == b.window_t
            and a.window_s == b.window_s
            and operator.eq(a, b)
        )

    def fit(self, verbose=False):
        """
        Run the evolutionary search and return ``self``.

        Both ``random`` and ``numpy.random`` are seeded with ``config.seed``
        first.  The best individual ends up in ``self.hof`` and the run
        statistics (min / mean fitness per generation) in ``self.logbook``.
        """
        cfg = self.config
        random.seed(cfg.seed)
        np.random.seed(cfg.seed)

        # Registered as a partial of a module-level function (not a lambda)
        # so the evaluation callable itself stays picklable.
        self.toolbox.register("evaluate", _evaluate_single_pde,
                              cache=self.cache, pset=self.pset)

        stats = tools.Statistics(
            lambda ind: ind.fitness.values[0] if ind.fitness.valid else 1e9
        )
        stats.register("min", np.min)
        stats.register("mean", np.mean)

        pop = self.toolbox.population(n=cfg.pop_size)
        pop, log = algorithms.eaSimple(
            pop, self.toolbox,
            cxpb=cfg.cx_prob, mutpb=cfg.mut_prob, ngen=cfg.n_gen,
            stats=stats, halloffame=self.hof, verbose=verbose,
        )
        self.logbook = log
        return self

    def predict_test(self):
        """
        Score the best individual on the test segment.

        The affine rescaling is re-fitted on the training segment and applied
        to the test prediction.  Returns ``(u_t_te, y_pred, r2)``.  Requires a
        previous ``fit()``; otherwise the hall of fame is empty and indexing
        it fails.
        """
        best = self.hof[0]
        wt, ws = best.window_t, best.window_s
        data = self.cache[(wt, ws)]
        func = gp.compile(best, pset=self.pset)

        y_hat_tr = func(data["u_tr"], data["u_x_tr"], data["u_xx_tr"])
        if np.ndim(y_hat_tr) == 0:
            # A constant expression returns a scalar; broadcast it.
            y_hat_tr = np.full_like(data["u_t_tr"], float(y_hat_tr))
        slope, intercept = affine_scale_train(data["u_t_tr"], y_hat_tr)

        y_hat_te = func(data["u_te"], data["u_x_te"], data["u_xx_te"])
        if np.ndim(y_hat_te) == 0:
            y_hat_te = np.full_like(data["u_t_te"], float(y_hat_te))
        y_pred = slope * y_hat_te + intercept
        r2 = r2_score(data["u_t_te"], y_pred)
        return data["u_t_te"], y_pred, r2

    @property
    def best_expression(self):
        """String form of the best expression tree."""
        return str(self.hof[0])

    @property
    def best_window_t(self):
        """Temporal window gene of the best individual."""
        return self.hof[0].window_t

    @property
    def best_window_s(self):
        """Spatial window gene of the best individual."""
        return self.hof[0].window_s

    @property
    def best_size(self):
        """Number of nodes in the best expression tree."""
        return len(self.hof[0])


# ===================================================================
# Version info
# ===================================================================

def get_version_info() -> dict:
    """
    Report the versions of the interpreter and the numeric stack.

    Always contains "python" and "numpy"; "scipy" and "deap" are added only
    when the package can be imported and exposes ``__version__``.
    """
    import sys

    info = {"python": sys.version, "numpy": np.__version__}
    for package in ("scipy", "deap"):
        try:
            info[package] = __import__(package).__version__
        except Exception:
            # Best effort: a missing package is simply left out.
            continue
    return info