import torch
import torch.nn as nn
import torch.nn.functional as F
import math
import numpy as np

class MLP(nn.Module):
    """Fully connected feed-forward network wrapped in an ``nn.Sequential``.

    Each entry of ``widths`` contributes one hidden stage made of
    ``Linear`` -> optional ``LayerNorm`` -> activation.  A final ``Linear``
    maps to ``out_dim`` and may be followed by an output activation.

    Args:
        in_dim: Size of the last axis of the input.
        out_dim: Size of the last axis of the output.
        widths: Iterable of hidden layer sizes (may be empty).
        act: Hidden activation name: ``"tanh"``, ``"relu"`` or ``"gelu"``.
            An unknown name raises ``KeyError`` when the first hidden stage
            is built.
        out_act: ``"tanh"`` or ``"sigmoid"`` to append that activation after
            the last linear layer; any other value leaves the output linear.
        layer_norm: If truthy, insert ``LayerNorm`` after every hidden
            ``Linear``.
    """

    def __init__(self, in_dim, out_dim, widths, act="tanh", out_act=None, layer_norm=False):
        super().__init__()
        hidden_activations = {"tanh": nn.Tanh(), "relu": nn.ReLU(), "gelu": nn.GELU()}
        output_activations = {"tanh": nn.Tanh, "sigmoid": nn.Sigmoid}

        # Consecutive pairs of `dims` are the (fan_in, fan_out) of each hidden Linear.
        dims = [in_dim, *widths]
        stack = []
        for fan_in, fan_out in zip(dims[:-1], dims[1:]):
            stack.append(nn.Linear(fan_in, fan_out))
            if layer_norm:
                stack.append(nn.LayerNorm(fan_out))
            # The same (stateless) activation module is reused for every stage.
            stack.append(hidden_activations[act])
        stack.append(nn.Linear(dims[-1], out_dim))

        head_cls = output_activations.get(out_act)
        if head_cls is not None:
            stack.append(head_cls())
        self.net = nn.Sequential(*stack)

    def forward(self, x):
        """Apply the layer stack to ``x`` and return the result."""
        return self.net(x)

class PolicyNet(nn.Module):
    """Policy head that parameterises an independent Beta distribution per action dim.

    The backbone emits ``2 * act_dim`` values per observation; the first half
    becomes the ``alpha`` concentrations and the second half the ``beta``
    concentrations.

    Args:
        obs_dim: Size of the last axis of the observation tensor.
        act_dim: Number of action components.
        hidden: Hidden layer widths passed to the backbone ``MLP``.
        min_conc: Constant added to each softplus output (lower bound on the
            concentrations).
        max_conc: Upper clamp applied to the concentrations.
    """

    def __init__(self, obs_dim, act_dim, hidden=(256,256), min_conc=1e-3, max_conc=50.0):
        super().__init__()
        self.obs_dim = obs_dim
        self.act_dim = act_dim
        self.backbone = MLP(obs_dim, 2*act_dim, hidden, act="gelu", out_act=None, layer_norm=False)
        self.min_conc = min_conc
        self.max_conc = max_conc

    def forward(self, obs):
        """Return the ``(alpha, beta)`` concentration tensors for ``obs``.

        Both outputs have the leading shape of ``obs`` with the last axis
        equal to half of the backbone output width.
        """
        alpha_raw, beta_raw = torch.chunk(self.backbone(obs), 2, dim=-1)
        # softplus keeps the concentrations positive; min_conc keeps them away from zero.
        alpha = torch.clamp(F.softplus(alpha_raw) + self.min_conc, max=self.max_conc)
        beta = torch.clamp(F.softplus(beta_raw) + self.min_conc, max=self.max_conc)
        return alpha, beta

    def dist(self, obs):
        """Return the ``torch.distributions.Beta`` built from ``forward(obs)``."""
        alpha, beta = self.forward(obs)
        return torch.distributions.Beta(alpha, beta)

    def act_mean(self, obs):
        """Return the Beta mean ``alpha / (alpha + beta)`` for ``obs``."""
        alpha, beta = self.forward(obs)
        return alpha / (alpha + beta)

    def select_action(self, obs, deterministic=False):
        """Pick an action for ``obs``.

        Args:
            obs: Observation tensor whose last axis is the feature axis.
            deterministic: If true, return the distribution mean and
                zero-filled placeholders for log-probability and entropy.

        Returns:
            Tuple ``(action, logp, ent)``.  ``logp`` and ``ent`` are summed
            over the action axis, so their shape is the action shape without
            its last axis.
        """
        if deterministic:
            a = self.act_mean(obs)
            # Shape the placeholders like the stochastic branch (action shape minus
            # the last axis) so both branches agree for unbatched or N-D inputs.
            reduced_shape = a.shape[:-1]
            logp = a.new_zeros(reduced_shape)
            ent = a.new_zeros(reduced_shape)
            return a, logp, ent
        d = self.dist(obs)
        a = d.rsample()
        logp = d.log_prob(a).sum(-1)
        ent = d.entropy().sum(-1)
        return a, logp, ent

class ValueNet(nn.Module):
    """State-value network: an ``MLP`` backbone with a single output unit.

    Args:
        obs_dim: Size of the last axis of the observation tensor.
        hidden: Hidden layer widths passed to the backbone ``MLP``.
    """

    def __init__(self, obs_dim, hidden=(256,256)):
        super().__init__()
        self.backbone = MLP(
            in_dim=obs_dim,
            out_dim=1,
            widths=hidden,
            act="gelu",
            out_act=None,
            layer_norm=False,
        )

    def forward(self, obs):
        """Return the backbone output with its trailing axis squeezed away."""
        values = self.backbone(obs)
        return torch.squeeze(values, dim=-1)

class SafetyProjector:
    """Stateful filter that maps a raw 4-component action to a constrained one.

    The projector remembers the previously emitted action and limits how far
    each component may move per call.  Components 0-2 use ``spool_rate``;
    component 3 is treated as a pump pressure target normalised by
    ``rail_hi`` and uses ``pump_rate_mpa / rail_hi``.

    Settings are read from ``cfg["safety"]`` (all optional):
    ``rail_lo_mpa``, ``rail_hi_mpa``, ``acc_lo_mpa``, ``acc_hi_mpa``,
    ``spool_rate_per_step``, ``pump_rate_mpa_per_step``,
    ``pressure_margin_mpa`` and ``soc_margin``.  ``cfg["device"]`` selects
    the device of the stored previous action.

    Args:
        cfg: Optional configuration dict; ``None`` means all defaults.
    """

    def __init__(self, cfg=None):
        self.cfg = cfg or {}
        safety = self.cfg.get("safety", {})
        self.rail_lo = float(safety.get("rail_lo_mpa", 5.0))
        self.rail_hi = float(safety.get("rail_hi_mpa", 35.0))
        self.acc_lo = float(safety.get("acc_lo_mpa", 12.0))
        self.acc_hi = float(safety.get("acc_hi_mpa", 28.0))
        self.spool_rate = float(safety.get("spool_rate_per_step", 0.02))
        self.pump_rate_mpa = float(safety.get("pump_rate_mpa_per_step", 0.5))
        self.margin_p = float(safety.get("pressure_margin_mpa", 0.2))
        # NOTE(review): margin_soc is stored but never read in this class.
        self.margin_soc = float(safety.get("soc_margin", 0.05))
        self.device = torch.device(self.cfg.get("device", "cpu"))
        # Create the rate-limit state on the configured device from the start.
        self.prev_action = torch.zeros(4, device=self.device)

    def reset(self):
        """Forget the previous action (set it to zeros on ``self.device``)."""
        self.prev_action = torch.zeros(4, device=self.device)

    def _acc_p_from_soc(self, soc):
        """Linearly map ``soc`` (clamped to [0, 1]) onto ``[acc_lo, acc_hi]``."""
        return (self.acc_lo + (self.acc_hi - self.acc_lo) * torch.clamp(soc, 0.0, 1.0))

    def project(self, raw_action, belief=None):
        """Project ``raw_action`` and remember the result for the next call.

        Steps, in order: clamp to [0, 1]; rate-limit against the previous
        action; clamp component 3 so that ``a[..., 3] * rail_hi`` lies in
        ``[rail_lo, rail_hi]``; optionally nudge components 0 and 3 from the
        belief means.

        Args:
            raw_action: Tensor whose last axis has 4 components.
            belief: Optional dict with a ``"mean"`` mapping.  Its ``"soc"``
                (default 0.5) and ``"P_pump"`` (default 20.0) entries are
                used as truth values in ``if`` tests, so they must be
                scalars or single-element tensors.

        Returns:
            The projected action, same shape as ``raw_action``.
        """
        a = torch.clamp(raw_action, 0.0, 1.0)
        if self.prev_action.shape != a.shape:
            # A shape change (e.g. first batched call) restarts the rate limiter.
            self.prev_action = torch.zeros_like(a)
        elif self.prev_action.device != a.device or self.prev_action.dtype != a.dtype:
            # Same shape but different device/dtype: keep the state, move it
            # next to the input so the arithmetic below cannot mismatch.
            self.prev_action = self.prev_action.to(device=a.device, dtype=a.dtype)

        pump_step = self.pump_rate_mpa / self.rail_hi
        max_step = torch.tensor(
            [self.spool_rate, self.spool_rate, self.spool_rate, pump_step],
            device=a.device,
            dtype=a.dtype,
        )
        a = torch.max(torch.min(a, self.prev_action + max_step), self.prev_action - max_step)

        # Keep the de-normalised pump target inside the rail limits.
        pump_target_mpa = torch.clamp(a[..., 3] * self.rail_hi, self.rail_lo, self.rail_hi)
        a = torch.cat([a[..., :3], (pump_target_mpa / self.rail_hi).unsqueeze(-1)], dim=-1)

        if belief is not None and isinstance(belief, dict) and "mean" in belief:
            m = belief["mean"]
            soc = torch.as_tensor(m.get("soc", 0.5), device=a.device, dtype=a.dtype)
            P = torch.as_tensor(m.get("P_pump", 20.0), device=a.device, dtype=a.dtype)
            acc_p = self._acc_p_from_soc(soc)
            near_top = acc_p > (self.acc_hi - self.margin_p)
            near_bot = acc_p < (self.acc_lo + self.margin_p)

            # One copy is enough before the in-place edits below.
            a = a.clone()
            # NOTE(review): these nudges clamp to [0, 1], not to the rail range
            # enforced above, so component 3 can end below rail_lo / rail_hi.
            if near_top:
                a[..., 0] = torch.clamp(a[..., 0] - 0.05, 0.0, 1.0)
            if near_bot:
                a[..., 3] = torch.clamp(a[..., 3] + pump_step, 0.0, 1.0)
            if P > (self.rail_hi - self.margin_p):
                a[..., 3] = torch.clamp(a[..., 3] - pump_step, 0.0, 1.0)
            if P < (self.rail_lo + self.margin_p):
                a[..., 3] = torch.clamp(a[..., 3] + pump_step, 0.0, 1.0)

        self.prev_action = a.detach()
        return a

class ActorCritic:
    """Actor-critic agent with a Beta policy, a value network and a safety projector.

    Configuration (all optional) is read from ``cfg``:
    ``device``; ``net.actor_hidden`` / ``net.critic_hidden``;
    ``optim.actor_lr`` / ``optim.critic_lr``; ``loss.ent_coef`` /
    ``loss.clip_grad``; ``rl.gamma`` / ``rl.gae_lambda``.  The same ``cfg``
    is handed to ``SafetyProjector``.

    Args:
        obs_dim: Observation vector length.
        act_dim: Number of action components.
        cfg: Optional configuration dict; ``None`` means all defaults.
    """

    # Order of the belief-mean entries that make up the observation vector.
    _BELIEF_KEYS = ("soil", "acc_offset", "pump_eta", "valve_gain_scale", "cyl_leak_scale",
                    "pb_pump", "oil_T", "mu", "soc", "P_pump")
    # Distance kept from 0 and 1 when evaluating the Beta log-density.
    _ACT_EPS = 1e-6

    def __init__(self, obs_dim, act_dim, cfg=None):
        self.cfg = cfg or {}
        net_cfg = self.cfg.get("net", {})
        optim_cfg = self.cfg.get("optim", {})
        loss_cfg = self.cfg.get("loss", {})
        rl_cfg = self.cfg.get("rl", {})

        self.device = torch.device(self.cfg.get("device", "cpu"))
        self.obs_dim = obs_dim
        self.act_dim = act_dim
        self.actor = PolicyNet(obs_dim, act_dim, hidden=tuple(net_cfg.get("actor_hidden", [256,256]))).to(self.device)
        self.critic = ValueNet(obs_dim, hidden=tuple(net_cfg.get("critic_hidden", [256,256]))).to(self.device)
        self.opt_actor = torch.optim.Adam(self.actor.parameters(), lr=float(optim_cfg.get("actor_lr", 3e-4)))
        self.opt_critic = torch.optim.Adam(self.critic.parameters(), lr=float(optim_cfg.get("critic_lr", 3e-4)))
        self.ent_coef = float(loss_cfg.get("ent_coef", 0.001))
        self.clip_grad = float(loss_cfg.get("clip_grad", 1.0))
        self.gamma = float(rl_cfg.get("gamma", 0.99))
        self.lam = float(rl_cfg.get("gae_lambda", 0.95))
        self.safety = SafetyProjector(self.cfg)
        self.safety.device = self.device

    def _to_tensor(self, value):
        """Convert ``value`` to a float32 tensor on ``self.device``."""
        return torch.as_tensor(value, dtype=torch.float32, device=self.device)

    def encode_belief(self, belief):
        """Turn a belief dict into a 2-D float32 observation tensor.

        If ``belief`` has an ``"obs_vec"`` entry it is used directly (a 1-D
        vector gets a leading batch axis).  Otherwise the entries of
        ``belief["mean"]`` named in ``_BELIEF_KEYS`` are read in that order,
        with 0.0 for missing keys, giving a ``(1, 10)`` tensor.
        """
        if isinstance(belief, dict) and "obs_vec" in belief:
            v = self._to_tensor(belief["obs_vec"])
            return v.unsqueeze(0) if v.ndim == 1 else v
        m = belief["mean"]
        vec = [float(m.get(k, 0.0)) for k in self._BELIEF_KEYS]
        return torch.tensor(vec, dtype=torch.float32, device=self.device).unsqueeze(0)

    def select_action(self, belief, deterministic=False):
        """Sample (or take the mean of) the policy and pass it through the projector.

        Returns:
            Tuple ``(projected_action, logp, ent)``, none of which carry
            gradients.
        """
        # Nothing here is differentiated, so skip building the autograd graph.
        with torch.no_grad():
            obs = self.encode_belief(belief)
            a_raw, logp, ent = self.actor.select_action(obs, deterministic=deterministic)
            # NOTE(review): logp belongs to the raw sample, while the returned
            # action is the projected one — confirm callers expect that pairing.
            a_proj = self.safety.project(a_raw, belief=belief)
        return a_proj.detach(), logp.detach(), ent.detach()

    def value(self, belief):
        """Return the critic's value for ``belief`` without gradients."""
        with torch.no_grad():
            return self.critic(self.encode_belief(belief)).detach()

    def _gae(self, rewards, values, dones, next_values):
        """Compute generalized advantage estimates by a backward pass over time.

        Args:
            rewards, values, dones, next_values: Tensors indexed by time on
                axis 0.  ``dones[t]`` equal to 1 cuts both the bootstrap and
                the accumulated advantage at step ``t``.

        Returns:
            Tuple ``(adv, ret)`` with ``ret = adv + values``.
        """
        T = rewards.shape[0]
        adv = torch.zeros_like(rewards)
        gae = 0.0
        for t in reversed(range(T)):
            mask = 1.0 - dones[t]
            delta = rewards[t] + self.gamma * next_values[t] * mask - values[t]
            gae = delta + self.gamma * self.lam * mask * gae
            adv[t] = gae
        ret = adv + values
        return adv, ret

    def losses(self, batch):
        """Run one actor update and one critic update on ``batch``.

        ``batch`` must contain ``"obs"`` and ``"act"``.  If both ``"rew"``
        and ``"done"`` are present, advantages and returns are computed with
        ``_gae`` (using ``"next_obs"`` when it has one row per observation,
        otherwise the critic values shifted by one step plus the value of the
        last ``next_obs`` row).  Otherwise ``"adv"`` and ``"ret"`` are
        required.  ``"logp"`` is optional; without it the importance ratio is
        1 with a live gradient.

        Returns:
            Dict with ``loss_total``, ``loss_actor``, ``loss_critic`` and
            ``entropy`` as Python floats.
        """
        obs = self._to_tensor(batch["obs"])
        act = self._to_tensor(batch["act"])
        rew = self._to_tensor(batch["rew"]) if "rew" in batch else None
        done = self._to_tensor(batch["done"]) if "done" in batch else None
        old_logp = self._to_tensor(batch["logp"]) if "logp" in batch else None

        if rew is not None and done is not None:
            # Targets are constants for the update, so compute them without autograd.
            with torch.no_grad():
                val = self.critic(obs)
                next_obs = self._to_tensor(batch.get("next_obs", obs[-1:]))
                next_val_all = self.critic(next_obs)
                if next_val_all.shape[0] == obs.shape[0]:
                    next_val = next_val_all
                else:
                    # V(obs[1:]) is just the already computed values shifted by one.
                    next_val = torch.cat([val[1:], next_val_all[-1:]], dim=0)
                adv, ret = self._gae(rew, val, done, next_val)
        else:
            adv = self._to_tensor(batch["adv"])
            ret = self._to_tensor(batch["ret"])
        adv = (adv - adv.mean()) / (adv.std(unbiased=False) + 1e-8)

        d = self.actor.dist(obs)
        # The Beta log-density is infinite at exactly 0 or 1 (values the safety
        # projector can produce), so evaluate it strictly inside the interval.
        safe_act = torch.clamp(act, self._ACT_EPS, 1.0 - self._ACT_EPS)
        logp = d.log_prob(safe_act).sum(-1)
        ent = d.entropy().sum(-1)
        ref_logp = old_logp if old_logp is not None else logp.detach()
        ratio = torch.exp(logp - ref_logp)
        pg_loss = -(ratio * adv).mean() - self.ent_coef * ent.mean()

        v_pred = self.critic(obs)
        v_loss = F.mse_loss(v_pred, ret)

        # The actor and critic graphs share no nodes, so neither backward
        # pass needs retain_graph.
        self.opt_actor.zero_grad(set_to_none=True)
        pg_loss.backward()
        nn.utils.clip_grad_norm_(self.actor.parameters(), self.clip_grad)
        self.opt_actor.step()

        self.opt_critic.zero_grad(set_to_none=True)
        v_loss.backward()
        nn.utils.clip_grad_norm_(self.critic.parameters(), self.clip_grad)
        self.opt_critic.step()

        loss = pg_loss.detach() + v_loss.detach()
        return {
            "loss_total": float(loss.item()),
            "loss_actor": float(pg_loss.item()),
            "loss_critic": float(v_loss.item()),
            "entropy": float(ent.mean().item()),
        }

    def save(self, path):
        """Write the actor and critic state dicts to ``path``."""
        ckpt = {"actor": self.actor.state_dict(), "critic": self.critic.state_dict()}
        torch.save(ckpt, path)

    def load(self, path, map_location=None):
        """Load actor and critic weights from a checkpoint written by ``save``.

        Args:
            path: Checkpoint location.
            map_location: Passed to ``torch.load``; defaults to ``self.device``.
        """
        # SECURITY: torch.load unpickles the file; only load checkpoints from a
        # trusted source (or pass weights_only=True where the installed torch
        # supports it).
        ckpt = torch.load(path, map_location=map_location or self.device)
        self.actor.load_state_dict(ckpt["actor"])
        self.critic.load_state_dict(ckpt["critic"])

class PlanningBaseline:
    """Placeholder value-iteration baseline that tracks a synthetic value vector.

    No model is evaluated here: each call decays the stored values by 0.95
    and adds a small random perturbation, then reports convergence figures.

    Args:
        tol: Threshold on the max-norm change below which ``converged`` is set.
        max_iter: Stored as ``self.max_iter``; not read by ``value_iteration``.
    """

    def __init__(self, tol=1e-4, max_iter=1000):
        self.tol = float(tol)
        self.max_iter = int(max_iter)
        self.iter_count = 0
        self.prev_values = None

    def value_iteration(self, eval_snapshots):
        """Advance the stored value vector by one sweep.

        Args:
            eval_snapshots: A list or tuple sets the value-vector length to
                its size; any other type uses a length of 16.  The elements
                themselves are not inspected.

        Returns:
            Dict with ``delta_inf`` (max absolute change this sweep),
            ``policy_unchanged`` (true from the 10th call on) and
            ``converged`` (``delta_inf < tol`` or ``policy_unchanged``).
        """
        n = len(eval_snapshots) if isinstance(eval_snapshots, (list, tuple)) else 16
        # Start over when the snapshot count changes; otherwise the update below
        # would fail (or silently broadcast) on mismatched lengths.
        if self.prev_values is None or self.prev_values.shape[0] != n:
            self.prev_values = torch.zeros(n)
        new_values = 0.95*self.prev_values + 0.05*torch.randn(n)*1e-5
        # torch.max raises on an empty tensor; an empty sweep changes nothing.
        if n == 0:
            delta_inf = 0.0
        else:
            delta_inf = torch.max(torch.abs(new_values - self.prev_values)).item()
        self.prev_values = new_values
        self.iter_count += 1
        policy_unchanged = self.iter_count >= 10
        converged = (delta_inf < self.tol) or policy_unchanged
        return {"delta_inf": float(delta_inf), "policy_unchanged": bool(policy_unchanged), "converged": bool(converged)}

def build_actor_critic_from_belief_spec(cfg=None):
    """Create an ``ActorCritic`` sized for the belief-mean observation vector.

    The observation width is the number of belief fields listed below (10)
    and the action width is fixed at 4.

    Args:
        cfg: Optional configuration dict forwarded to ``ActorCritic``; a
            falsy value is replaced by an empty dict.

    Returns:
        The constructed ``ActorCritic`` instance.
    """
    belief_fields = (
        "soil",
        "acc_offset",
        "pump_eta",
        "valve_gain_scale",
        "cyl_leak_scale",
        "pb_pump",
        "oil_T",
        "mu",
        "soc",
        "P_pump",
    )
    n_actions = 4
    return ActorCritic(len(belief_fields), n_actions, cfg=cfg or {})

if __name__ == "__main__":
    # Smoke run: build an agent, pick one action, do one update on random
    # data, run one planning sweep and print a one-line summary.
    torch.manual_seed(0)
    demo_cfg = {
        "device": "cpu",
        "net": {"actor_hidden": [128, 128], "critic_hidden": [128, 128]},
        "optim": {"actor_lr": 3e-4, "critic_lr": 3e-4},
        "loss": {"ent_coef": 0.001, "clip_grad": 1.0},
        "rl": {"gamma": 0.99, "gae_lambda": 0.95},
        "safety": {
            "rail_lo_mpa": 5.0,
            "rail_hi_mpa": 35.0,
            "acc_lo_mpa": 12.0,
            "acc_hi_mpa": 28.0,
            "spool_rate_per_step": 0.02,
            "pump_rate_mpa_per_step": 0.5,
        },
    }
    agent = build_actor_critic_from_belief_spec(demo_cfg)

    demo_belief = {
        "mean": {
            "soil": 1.0,
            "acc_offset": 0.02,
            "pump_eta": 0.98,
            "valve_gain_scale": 0.99,
            "cyl_leak_scale": 0.05,
            "pb_pump": 0.0,
            "oil_T": 45.0,
            "mu": 0.035,
            "soc": 0.5,
            "P_pump": 22.0,
        }
    }
    action, log_prob, entropy = agent.select_action(demo_belief, deterministic=False)

    # Dict values are evaluated top to bottom, so the random draws happen in
    # the order obs, act, rew.
    n_steps = 64
    rollout = {
        "obs": torch.randn(n_steps, agent.obs_dim),
        "act": torch.rand(n_steps, 4),
        "rew": torch.randn(n_steps),
        "done": torch.zeros(n_steps),
    }
    update_stats = agent.losses(rollout)

    planner = PlanningBaseline()
    sweep = planner.value_iteration([None] * 16)
    print(action.shape, float(log_prob.mean()), float(entropy.mean()), update_stats["loss_total"], sweep["converged"])
