import numpy as np
from ..hierarchy.level import Level
from ..hierarchy.operators import connection_penalty, grad_connection_wrt_lower, grad_connection_wrt_upper
from ..gra_step.critic_revisor import gra_nullify
from ..stability.checker import check_stability

class HierarchicalOptimizer:
    """
    Full optimizer that minimizes the functional

        J = Σ_l α_l Φ(x_l) + Σ_l β_l C_l + γ Ψ(x_L)

    by alternating GRA nullification steps with gradient descent. It stops
    once the states stop changing and stability is reached.

    Φ is the per-level foam metric (``foam_metrics[l].compute``). C_l is the
    connection penalty between levels l and l+1 (``connection_penalty``).
    Ψ is the optional terminal penalty applied to the last level's state.
    """

    def __init__(self, levels, foam_metrics, alpha=None, beta=None, gamma=1.0,
                 A_configs=None, terminal_penalty_fn=None, lr=0.01):
        """
        Args:
            levels: sequence of ``Level`` objects. Each must expose mutable
                ``state`` (presumably a NumPy array; it is copied and updated
                in place) and ``context`` attributes.
            foam_metrics: either a list with one metric per level, or a single
                metric shared by all levels. Metrics must provide
                ``compute(state, context)`` and ``grad(state, context)``.
            alpha: per-level weights of Φ (length L). ``None`` or empty
                means all 1.0. NumPy arrays are accepted.
            beta: per-connection weights of C_l (length L-1). ``None`` or
                empty means all 1.0.
            gamma: weight of the terminal penalty Ψ.
            A_configs: per-connection configs forwarded to the connection
                operators (length L-1). ``None`` or empty means all ``None``.
            terminal_penalty_fn: optional callable Ψ(state) -> scalar,
                applied to the last level's state.
            lr: gradient-descent step size.
        """
        self.levels = levels  # list of Level
        self.foam_metrics = foam_metrics  # list of foam metrics, or one shared metric
        self.L = len(levels)
        self.alpha = self._or_default(alpha, [1.0] * self.L)
        self.beta = self._or_default(beta, [1.0] * (self.L - 1))
        self.gamma = gamma
        self.A_configs = self._or_default(A_configs, [None] * (self.L - 1))
        self.terminal_penalty_fn = terminal_penalty_fn  # Ψ(x_L), if any
        self.lr = lr

    @staticmethod
    def _or_default(values, default):
        """Return *values*, or *default* when *values* is None or empty."""
        # Explicit check instead of truthiness: bool() on a multi-element
        # NumPy array raises ValueError.
        if values is None or len(values) == 0:
            return default
        return values

    def _foam_metric(self, l):
        """Return the foam metric for level *l* (per-level list or shared metric)."""
        if isinstance(self.foam_metrics, list):
            return self.foam_metrics[l]
        return self.foam_metrics

    def _terminal_penalty_grad(self, state, eps=1e-6):
        """Central-difference gradient of the terminal penalty Ψ at *state*.

        Every element is perturbed individually via ``np.ndindex``, so states
        of any shape (including 0-d) are handled. Perturbations are applied
        in at least float64 precision so that ``eps`` is resolvable even for
        lower-precision or integer states. Each evaluation gets a fresh copy,
        so Ψ may safely mutate its argument.
        """
        work_dtype = np.result_type(state, np.float64)
        base = np.array(state, dtype=work_dtype)
        term_grad = np.zeros(base.shape, dtype=work_dtype)
        for idx in np.ndindex(base.shape):
            state_plus = base.copy()
            state_plus[idx] += eps
            state_minus = base.copy()
            state_minus[idx] -= eps
            f_plus = self.terminal_penalty_fn(state_plus)
            f_minus = self.terminal_penalty_fn(state_minus)
            term_grad[idx] = (f_plus - f_minus) / (2 * eps)
        return term_grad

    def compute_J(self):
        """Evaluate the functional J for the current level states.

        Returns:
            Σ α_l Φ(x_l) + Σ β_l C(x_l, x_{l+1}) + γ Ψ(x_L). The last term is
            included only when ``terminal_penalty_fn`` is set.
        """
        J_val = 0.0
        for l in range(self.L):
            level = self.levels[l]
            J_val += self.alpha[l] * self._foam_metric(l).compute(level.state, level.context)
        for l in range(self.L - 1):
            C = connection_penalty(self.levels[l].state, self.levels[l + 1].state, self.A_configs[l])
            J_val += self.beta[l] * C
        if self.terminal_penalty_fn is not None:
            J_val += self.gamma * self.terminal_penalty_fn(self.levels[-1].state)
        return J_val

    def step(self):
        """Perform one optimization step, updating every level's state.

        The step has two phases:
          1. GRA nullification on every level (``gra_nullify``), which
             replaces each state.
          2. One gradient-descent sweep over J, updating each state in place
             with ``state -= lr * grad``.
        """
        # 1. GRA nullification on every level (reduces Φ).
        for l in range(self.L):
            level = self.levels[l]
            level.state = gra_nullify(level.state, self._foam_metric(l), level.context)

        # 2. Gradient descent on J (simplified: sum of partial derivatives).
        # Levels are updated in order. The coupling term for level l therefore
        # already sees the freshly updated state of level l-1 (a
        # Gauss-Seidel-style sweep), while level l+1 is still the old state.
        last = self.L - 1
        for l in range(self.L):
            level = self.levels[l]
            grad = np.zeros_like(level.state)
            # Contribution of α_l Φ(x_l).
            grad += self.alpha[l] * self._foam_metric(l).grad(level.state, level.context)
            # Contributions of C_l (x_l is the lower side) and C_{l-1} (x_l is the upper side).
            if l < last:
                grad += self.beta[l] * grad_connection_wrt_lower(
                    level.state, self.levels[l + 1].state, self.A_configs[l])
            if l > 0:
                grad += self.beta[l - 1] * grad_connection_wrt_upper(
                    self.levels[l - 1].state, level.state, self.A_configs[l - 1])
            # Terminal penalty: gradient computed numerically.
            if self.terminal_penalty_fn is not None and l == last:
                grad += self.gamma * self._terminal_penalty_grad(level.state)
            level.state -= self.lr * grad

    def optimize(self, max_iter=1000, tol=1e-6, check_stability_every=10):
        """Run ``step`` until convergence or until ``max_iter`` iterations.

        Iteration stops early only when the largest absolute change of any
        state element during a step is below ``tol`` AND ``check_stability``
        reports the configuration as stable.

        Args:
            max_iter: maximum number of steps.
            tol: threshold on the maximum per-element state change.
            check_stability_every: NOTE(review): currently unused by this
                method; kept for API compatibility.

        Returns:
            The (mutated) list of levels.
        """
        for it in range(max_iter):
            old_states = [lvl.state.copy() for lvl in self.levels]
            self.step()
            # Stopping criterion: states barely changed and stability is reached.
            # `initial` / `default` keep empty states / no levels from raising.
            max_diff = max(
                (np.max(np.abs(old - lvl.state), initial=0.0)
                 for old, lvl in zip(old_states, self.levels)),
                default=0.0,
            )
            if max_diff < tol:
                if check_stability(self.levels, self.foam_metrics):
                    print(f"Stability reached at iteration {it}")
                    break
        return self.levels
