import argparse
import os
import time
import math
import json
import yaml
import numpy as np
import pandas as pd
import torch
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
from collections import defaultdict, deque
from pathlib import Path
from env import ExcavatorEnv
from belief import BootstrapPF
from policy import build_actor_critic_from_belief_spec, PlanningBaseline

def load_cfg(path):
    """Read a UTF-8 configuration file, parsing it as YAML with a JSON fallback.

    Args:
        path: Location of the configuration file.

    Returns:
        The object produced by ``yaml.safe_load``; if that call raises any
        ``Exception``, the object produced by ``json.loads`` instead.

    Raises:
        json.JSONDecodeError: If the YAML attempt failed and the text is not
            valid JSON either.
    """
    with open(path, mode="r", encoding="utf-8") as handle:
        raw_text = handle.read()
    try:
        parsed = yaml.safe_load(raw_text)
    except Exception:
        # Deliberate best-effort: anything the YAML loader rejects gets one
        # more chance through the JSON parser.
        parsed = json.loads(raw_text)
    return parsed

def set_seed(seed):
    """Seed NumPy's global RNG and PyTorch's RNG with the same value.

    Args:
        seed: Integer seed forwarded to ``np.random.seed`` and
            ``torch.manual_seed`` (in that order).
    """
    for seeder in (np.random.seed, torch.manual_seed):
        seeder(seed)

def make_dirs(outdir):
    """Create the output directory and its ``logs``, ``figs`` and ``ckpts`` children.

    Missing parents are created and already-existing directories are left
    untouched, so the call is safe to repeat.

    Args:
        outdir: Root output directory.
    """
    Path(outdir).mkdir(parents=True, exist_ok=True)
    for child in ("logs", "figs", "ckpts"):
        Path(os.path.join(outdir, child)).mkdir(parents=True, exist_ok=True)

def make_env(cfg, seed):
    """Seed the RNGs, then build an environment, a particle filter and an actor-critic.

    The construction order is: seed, environment + reset, particle filter +
    belief initialisation, actor-critic.

    Args:
        cfg: Configuration object passed unchanged to every constructor.
        seed: Seed used for ``set_seed`` and for the environment reset.

    Returns:
        Tuple ``(env, pf, ac, obs0, belief0)`` where ``obs0`` is the value
        returned by ``env.reset`` and ``belief0`` the value returned by
        ``pf.init_from_obs(obs0)``.
    """
    set_seed(seed)
    environment = ExcavatorEnv(cfg=cfg)
    first_obs = environment.reset(seed=seed, cfg=cfg)
    particle_filter = BootstrapPF(cfg=cfg)
    first_belief = particle_filter.init_from_obs(first_obs)
    actor_critic = build_actor_critic_from_belief_spec(cfg=cfg)
    return environment, particle_filter, actor_critic, first_obs, first_belief

def to_vec_from_belief(belief):
    """Convert a belief mapping into a float32 feature vector.

    Args:
        belief: Mapping that contains either an ``"obs_vec"`` entry or a
            ``"mean"`` mapping of named state estimates.

    Returns:
        ``belief["obs_vec"]`` as a float32 array when that key is present;
        otherwise a length-10 float32 array built from ``belief["mean"]`` in a
        fixed key order, with 0.0 substituted for missing keys.
    """
    if "obs_vec" not in belief:
        means = belief["mean"]
        ordered_keys = (
            "soil",
            "acc_offset",
            "pump_eta",
            "valve_gain_scale",
            "cyl_leak_scale",
            "pb_pump",
            "oil_T",
            "mu",
            "soc",
            "P_pump",
        )
        features = [float(means.get(name, 0.0)) for name in ordered_keys]
        return np.array(features, dtype=np.float32)
    return np.asarray(belief["obs_vec"], dtype=np.float32)

def normal_cdf(x):
    """Return the standard normal cumulative distribution function at ``x``."""
    scaled = x / math.sqrt(2.0)
    return 0.5 * (1.0 + math.erf(scaled))

def _beta_continued_fraction(a, b, x):
    """Evaluate the continued fraction of the incomplete beta function (modified Lentz)."""
    tiny = 1e-300
    eps = 3e-14
    qab = a + b
    qap = a + 1.0
    qam = a - 1.0
    c = 1.0
    d = 1.0 - qab * x / qap
    if abs(d) < tiny:
        d = tiny
    d = 1.0 / d
    h = d
    for m in range(1, 10001):
        m2 = 2 * m
        # Even step of the recurrence.
        aa = m * (b - m) * x / ((qam + m2) * (a + m2))
        d = 1.0 + aa * d
        if abs(d) < tiny:
            d = tiny
        c = 1.0 + aa / c
        if abs(c) < tiny:
            c = tiny
        d = 1.0 / d
        h *= d * c
        # Odd step of the recurrence.
        aa = -(a + m) * (qab + m) * x / ((a + m2) * (qap + m2))
        d = 1.0 + aa * d
        if abs(d) < tiny:
            d = tiny
        c = 1.0 + aa / c
        if abs(c) < tiny:
            c = tiny
        d = 1.0 / d
        delta = d * c
        h *= delta
        if abs(delta - 1.0) < eps:
            break
    return h


def _regularized_incomplete_beta(a, b, x):
    """Return the regularized incomplete beta function I_x(a, b) for a, b > 0."""
    if x <= 0.0:
        return 0.0
    if x >= 1.0:
        return 1.0
    log_front = (
        math.lgamma(a + b) - math.lgamma(a) - math.lgamma(b)
        + a * math.log(x) + b * math.log1p(-x)
    )
    front = math.exp(log_front)
    # Use whichever form of the continued fraction converges quickly.
    if x < (a + 1.0) / (a + b + 2.0):
        return front * _beta_continued_fraction(a, b, x) / a
    return 1.0 - front * _beta_continued_fraction(b, a, 1.0 - x) / b


def two_sided_p_from_t(t_stat, n1, n2):
    """Return the two-sided p-value of a t statistic.

    The reference distribution is Student's t with ``max(1, n1 + n2 - 2)``
    degrees of freedom, evaluated through the identity
    ``P(|T| > |t|) = I_{df / (df + t^2)}(df / 2, 1 / 2)``. Earlier versions
    computed ``df`` but then used a standard normal tail, which understates
    p-values for small samples.

    Args:
        t_stat: Observed t statistic (sign is ignored).
        n1: Size of the first sample.
        n2: Size of the second sample.

    Returns:
        Two-sided p-value in ``[0, 1]`` as a float; NaN if ``t_stat`` is NaN.
    """
    t_abs = abs(float(t_stat))
    if math.isnan(t_abs):
        return float("nan")
    df = max(1, n1 + n2 - 2)
    if df > 1_000_000:
        # Student's t is numerically indistinguishable from the normal here;
        # erfc(z / sqrt(2)) equals 2 * (1 - Phi(z)) without cancellation.
        return math.erfc(t_abs / math.sqrt(2.0))
    x = df / (df + t_abs * t_abs)
    p = _regularized_incomplete_beta(0.5 * df, 0.5, x)
    return min(1.0, max(0.0, p))

def holm_bonferroni(pvals, alpha=0.05):
    """Return Holm-Bonferroni adjusted p-values in the input order.

    The p-values are sorted ascending, the k-th smallest (1-based) is
    multiplied by ``m - k + 1``, a running maximum enforces that adjusted
    values never decrease along the sorted order, and the result is capped
    at 1. The running maximum is what makes this the Holm step-down
    adjustment; without it a larger raw p-value could receive a smaller
    adjusted value than a smaller one.

    Args:
        pvals: One-dimensional sequence of raw p-values.
        alpha: Unused; kept so existing callers that pass it keep working.

    Returns:
        ``numpy.ndarray`` of adjusted p-values with the same length and order
        as ``pvals`` (empty for empty input).
    """
    raw = np.asarray(pvals, dtype=float)
    m = len(raw)
    adjusted = np.zeros(m)
    if m == 0:
        return adjusted
    order = np.argsort(raw)
    multipliers = m - np.arange(m)
    stepped = np.maximum.accumulate(raw[order] * multipliers)
    adjusted[order] = np.minimum(1.0, stepped)
    return adjusted

def compute_effect_size(a, b):
    """Return Cohen's d for two samples using the pooled standard deviation.

    Args:
        a: First sample (sequence of numbers).
        b: Second sample (sequence of numbers).

    Returns:
        ``(mean(a) - mean(b)) / s_pooled`` as a float. Returns 0.0 when the
        pooled standard deviation is zero, and NaN when either sample has
        fewer than two values (the sample variance is undefined there).
    """
    xa = np.asarray(a, dtype=float).ravel()
    xb = np.asarray(b, dtype=float).ravel()
    na, nb = xa.size, xb.size
    if na < 2 or nb < 2:
        # Explicit guard instead of relying on numpy's NaN-with-warning path.
        return float("nan")
    var_a = float(np.var(xa, ddof=1))
    var_b = float(np.var(xb, ddof=1))
    # Exact pooled variance; no additive fudge term is needed because the
    # denominator is at least 2 after the guard above.
    pooled_var = ((na - 1) * var_a + (nb - 1) * var_b) / (na + nb - 2)
    if pooled_var <= 0:
        return 0.0
    return float((np.mean(xa) - np.mean(xb)) / math.sqrt(pooled_var))

def baseline_action_fixed(belief):
    """Return the constant action of the fixed baseline controller.

    Args:
        belief: Ignored; present so all baseline controllers share a signature.

    Returns:
        float32 array ``[0.5, 0.5, 0.5, 22/35]``. ``run_episode`` logs the
        four entries as boom, arm and bucket valve commands followed by the
        pump command.
    """
    valve_cmd = 0.5
    pump_cmd = 22.0 / 35.0
    return np.array([valve_cmd, valve_cmd, valve_cmd, pump_cmd], dtype=np.float32)

def baseline_action_operator(belief):
    """Return the rule-based "operator" baseline action for the current belief.

    When the estimated oil temperature exceeds 55.0 or the estimated soil
    value exceeds 1.15, the pump command is raised by a gain capped at 0.10
    and the valve commands are scaled by 0.85; otherwise the nominal action
    is returned.

    Args:
        belief: Mapping with a ``"mean"`` mapping; ``"soil"`` defaults to 1.0
            and ``"oil_T"`` to 45.0 when absent.

    Returns:
        float32 array ``[valve, valve, valve, pump]`` with the valve command
        clipped to ``[0.2, 0.8]`` and the pump command to ``[5/35, 1.0]``.
    """
    estimates = belief["mean"]
    soil_est = float(estimates.get("soil", 1.0))
    oil_temp = float(estimates.get("oil_T", 45.0))

    if oil_temp > 55.0 or soil_est > 1.15:
        raw_gain = 0.02 + 0.004*max(0.0, oil_temp-55.0) + 0.15*max(0.0, soil_est-1.15)
        pump_gain = min(0.10, raw_gain)
        valve_scale = max(0.85, 1.0 - 0.15)
    else:
        pump_gain = 0.0
        valve_scale = 1.0

    valve_cmd = np.clip(0.5*valve_scale, 0.2, 0.8)
    pump_cmd = np.clip((22.0*(1.0+pump_gain))/35.0, 5.0/35.0, 1.0)
    return np.array([valve_cmd, valve_cmd, valve_cmd, pump_cmd], dtype=np.float32)

def run_episode(env, pf, controller, cfg, horizon_steps, seed, mode, log_step=True, deterministic=False, projector=None, ablate=None):
    """Roll out one episode and collect per-step and per-cycle records.

    Each step selects an action for ``controller``, advances the environment,
    runs the particle filter predict/update/log cycle and refreshes the
    belief from ``pf.estimate()``.

    Args:
        env: Environment exposing ``reset(seed=..., cfg=...)`` and ``step(act)``
            returning ``(obs, reward, done, info)``.
        pf: Filter exposing ``init_from_obs``, ``predict``, ``update``,
            ``log_step`` and ``estimate``.
        controller: ``"BRL"`` uses ``projector``; ``"Traditional"`` uses the
            fixed baseline; any other value uses the operator baseline.
        cfg: Configuration mapping; ``env.dt`` (default 0.01) is the step time.
        horizon_steps: Maximum number of steps.
        seed: Seed passed to ``env.reset`` and written into every record.
        mode: Not used inside this function.
        log_step: When true, a record is appended for every step.
        deterministic: Forwarded to ``projector.select_action``.
        projector: Policy object; only required when ``controller == "BRL"``.
        ablate: Optional mapping; when ``no_proj`` is truthy and the
            controller is ``"BRL"``, the action is replaced by the actor mean
            clipped to ``[0, 1]``.

    Returns:
        ``(step_rows, cycle_rows, ep_ret)``: two lists of dicts and the summed
        total reward as a float.
    """
    def total_reward(r):
        # Scalar rewards are used as-is; mapping rewards contribute "total".
        return float(r if np.isscalar(r) else r.get("total", 0.0))

    def reward_part(r, name):
        # Named components only exist for mapping rewards.
        return float(0.0 if np.isscalar(r) else r.get(name, 0.0))

    dt = cfg.get("env", {}).get("dt", 0.01)
    per_step = []
    per_cycle = []
    episode_return = 0.0
    belief = pf.init_from_obs(env.reset(seed=seed, cfg=cfg))

    for t in range(horizon_steps):
        if controller == "BRL":
            chosen, _logp, _entropy = projector.select_action(belief, deterministic=deterministic)
            act = chosen.cpu().numpy().reshape(-1)
        elif controller == "Traditional":
            act = baseline_action_fixed(belief)
        else:
            act = baseline_action_operator(belief)

        if ablate is not None and controller == "BRL" and ablate.get("no_proj", False):
            # Ablation: bypass select_action's output and use the raw actor mean.
            encoded = projector.encode_belief(belief)
            actor_mean = projector.actor.act_mean(encoded)
            act = np.clip(actor_mean.detach().cpu().numpy().reshape(-1), 0.0, 1.0)

        obs_next, reward, done, info = env.step(act)
        pf.predict(act, dt)
        diag = pf.update(obs_next)
        truth = info.get("truth", {})
        pf.log_step(t, obs_next, diag=diag, truth=truth)
        belief = pf.estimate()

        step_total = total_reward(reward)
        episode_return += step_total

        if log_step:
            est_mean = belief["mean"]
            est_std = belief["std"]
            # Key order is preserved because it becomes the DataFrame column order.
            per_step.append({
                "step": t,
                "time_s": t*dt,
                "seed": seed,
                "controller": controller,
                "reward_total": step_total,
                "reward_energy": reward_part(reward, "energy"),
                "reward_progress": reward_part(reward, "progress"),
                "reward_safety": reward_part(reward, "safety"),
                "E_pump_kJ": float(info.get("E_pump_kJ", np.nan)),
                "phase": int(info.get("phase", -1)),
                "P_pump": float(info.get("P_pump", np.nan)),
                "SoC": float(info.get("SoC", np.nan)),
                "soil_true": float(truth.get("soil", np.nan)),
                "soil_est": float(est_mean.get("soil", np.nan)),
                "soil_std": float(est_std.get("soil", np.nan)),
                "acc_offset_true": float(truth.get("acc_offset", np.nan)),
                "acc_offset_est": float(est_mean.get("acc_offset", np.nan)),
                "acc_offset_std": float(est_std.get("acc_offset", np.nan)),
                "act_v_boom": float(act[0]),
                "act_v_arm": float(act[1]),
                "act_v_bucket": float(act[2]),
                "act_pump": float(act[3]),
                "events_over": int(info.get("event_overpressure", 0)),
                "events_under": int(info.get("event_underpressure", 0)),
                "events_bounds": int(info.get("event_bounds", 0)),
            })

        finished_cycle = info.get("cycle_result", None)
        if finished_cycle is not None:
            per_cycle.append({
                "seed": seed,
                "controller": controller,
                "cycle_idx": int(finished_cycle.get("idx", -1)),
                "E_m_kJ": float(finished_cycle.get("E_m_kJ", np.nan)),
                "T_cycle_s": float(finished_cycle.get("T_cycle_s", np.nan)),
                "precision_pct": float(finished_cycle.get("precision_pct", np.nan)),
                "safety_events": int(finished_cycle.get("safety_events", 0)),
            })

        if done:
            break

    return per_step, per_cycle, episode_return

def train(cfg, outdir):
    """Train the actor-critic over a number of windows and checkpoint the best one.

    Per window: the actor/critic learning rates are linearly annealed, one
    stochastic rollout is executed through ``run_episode``, and a second
    rollout is collected step by step into a batch that is handed to
    ``ac.losses``. Once ``planner.value_iteration`` reports convergence, the
    entropy coefficient is decayed over the following five windows. The
    window with the highest mean step reward is saved as ``best.pt``; the
    final weights are saved as ``last.pt``.

    Args:
        cfg: Configuration mapping. Reads ``seed``, ``device``,
            ``env.rate_hz``, ``env.dt`` and ``train.windows``.
        outdir: Output directory; checkpoints are written to ``<outdir>/ckpts``.

    Returns:
        Path at which the best checkpoint is written (``<outdir>/ckpts/best.pt``).
    """
    make_dirs(outdir)
    seed = int(cfg.get("seed", 0))
    env, pf, ac, _obs0, _belief0 = make_env(cfg, seed)
    best_path = os.path.join(outdir, "ckpts", "best.pt")
    try:
        device = torch.device(cfg.get("device", "cpu"))
        ac.actor.to(device)
        ac.critic.to(device)

        env_cfg = cfg.get("env", {})
        rate_hz = float(env_cfg.get("rate_hz", 100.0))
        dt = env_cfg.get("dt", 0.01)
        mini_seconds = 60.0
        horizon_steps = int(mini_seconds*rate_hz)

        ent_coef = 0.01
        ac.ent_coef = ent_coef
        actor_init, actor_final = 1e-4, 1e-5
        critic_init, critic_final = 3e-4, 3e-5
        total_windows = int(cfg.get("train", {}).get("windows", 40))

        best_score = -1e18
        val_snap = [None]*256
        planner = PlanningBaseline(tol=1e-4, max_iter=1000)
        ent_decay_started = False
        ent_decay_left = 5

        for w in range(total_windows):
            # Linear decay from the initial to the final learning rate; the
            # epsilon keeps the division defined when there is a single window.
            lr_scale = max(0.0, 1.0 - w/(total_windows-1 + 1e-9))
            for pg in ac.opt_actor.param_groups:
                pg["lr"] = actor_final + (actor_init-actor_final)*lr_scale
            for pg in ac.opt_critic.param_groups:
                pg["lr"] = critic_final + (critic_init-critic_final)*lr_scale

            # NOTE(review): nothing returned by this rollout is used below
            # (step logging is disabled). It is kept because it still drives
            # env/pf, including pf.log_step - confirm whether it is needed.
            run_episode(env, pf, "BRL", cfg, horizon_steps, seed+w, mode="train", log_step=False, deterministic=False, projector=ac)

            traj_obs = []
            traj_act = []
            traj_rew = []
            traj_done = []
            traj_next_obs = []

            # Initialise the filter from the observation of THIS reset; using
            # the observation captured when the environment was first built
            # would start the belief from a different episode's state.
            obs = env.reset(seed=seed+w+100, cfg=cfg)
            pf.init_from_obs(obs)
            belief = pf.estimate()
            for _ in range(horizon_steps):
                a_t, _logp, _ent = ac.select_action(belief, deterministic=False)
                act = a_t.cpu().numpy().reshape(-1)
                obs_next, reward, done, info = env.step(act)
                pf.predict(act, dt)
                pf.update(obs_next)
                belief_next = pf.estimate()
                traj_obs.append(to_vec_from_belief(belief))
                traj_act.append(act)
                traj_rew.append(float(reward if np.isscalar(reward) else reward.get("total", 0.0)))
                traj_done.append(0.0 if not done else 1.0)
                traj_next_obs.append(to_vec_from_belief(belief_next))
                belief = belief_next
                if done:
                    break

            batch = {
                "obs": torch.tensor(np.array(traj_obs), dtype=torch.float32, device=device),
                "act": torch.tensor(np.array(traj_act), dtype=torch.float32, device=device),
                "rew": torch.tensor(np.array(traj_rew), dtype=torch.float32, device=device),
                "done": torch.tensor(np.array(traj_done), dtype=torch.float32, device=device),
                "next_obs": torch.tensor(np.array(traj_next_obs), dtype=torch.float32, device=device),
            }
            # NOTE(review): only the batch is passed in; whether ac.losses also
            # applies the optimizer step cannot be seen from this file.
            ac.losses(batch)

            if not ent_decay_started:
                res = planner.value_iteration(val_snap)
                if res["converged"]:
                    ent_decay_started = True
            if ent_decay_started and ent_decay_left > 0:
                ac.ent_coef = ent_coef * (ent_decay_left/5.0)
                ent_decay_left -= 1

            # Higher mean reward is better. The previous negated score made
            # "best.pt" the window with the LOWEST mean reward.
            score = float(np.mean(traj_rew))
            if score > best_score:
                best_score = score
                ac.save(best_path)

        ac.save(os.path.join(outdir, "ckpts", "last.pt"))
    finally:
        # Release the environment even when training aborts with an error.
        env.close()
    return best_path

def _mean_ci95(values):
    """Return ``(mean, 1.96 * standard error)`` of a 1-D array, NaN where undefined."""
    n = len(values)
    if n == 0:
        return float("nan"), float("nan")
    mean = float(np.mean(values))
    if n < 2:
        # The sample standard deviation needs at least two values.
        return mean, float("nan")
    return mean, float(1.96*np.std(values, ddof=1)/math.sqrt(n))


def evaluate(cfg, outdir, mode="eval", ablate=None):
    """Run evaluation episodes, write logs and summary statistics.

    In ``"ablate"`` mode only the BRL controller is run. In every other mode
    the Traditional and Operator baselines AND the BRL controller are run;
    BRL used to be missing from that list, which made the
    Traditional-vs-BRL test below and the BRL plots unreachable.

    Files written under ``<outdir>/logs``: ``per_step.parquet`` (or
    ``per_step.csv`` when the parquet write fails), ``per_cycle.csv``,
    ``summary.csv`` and, when both Traditional and BRL have more than one
    cycle, ``tests.txt``.

    Args:
        cfg: Configuration mapping. Reads ``eval.seeds``, ``env.rate_hz`` and
            ``device``.
        outdir: Output directory; the BRL checkpoint is looked up at
            ``<outdir>/ckpts/best.pt``.
        mode: ``"ablate"`` restricts the run to BRL.
        ablate: Optional ablation mapping forwarded to ``run_episode``.

    Returns:
        ``(df_step, df_cycle, df_sum)`` as pandas DataFrames. ``df_cycle`` and
        ``df_sum`` always carry their column headers, even when empty.
    """
    make_dirs(outdir)
    seeds = list(range(int(cfg.get("eval", {}).get("seeds", 20))))
    rate_hz = float(cfg.get("env", {}).get("rate_hz", 100.0))
    T_min = 20.0
    horizon_steps = int(T_min*60.0*rate_hz)
    controllers = ["BRL"] if mode == "ablate" else ["Traditional", "Operator", "BRL"]
    ckpt = os.path.join(outdir, "ckpts", "best.pt")

    step_logs = []
    cycle_logs = []
    for controller in controllers:
        # The offset only affects make_env (global seeding / construction);
        # run_episode resets the environment with the un-offset seed.
        seed_offset = 2000 if controller == "BRL" else 1000
        for seed in seeds:
            env, pf, ac, _obs0, _belief0 = make_env(cfg, seed + seed_offset)
            try:
                # NOTE(review): when the checkpoint is missing, BRL runs with
                # freshly built weights and no warning is emitted.
                if controller == "BRL" and os.path.exists(ckpt):
                    ac.load(ckpt, map_location=cfg.get("device", "cpu"))
                step_rows, cycle_rows, _ret = run_episode(env, pf, controller, cfg, horizon_steps, seed, mode="eval", log_step=True, deterministic=True, projector=ac, ablate=ablate)
            finally:
                env.close()
            step_logs.extend(step_rows)
            cycle_logs.extend(cycle_rows)

    df_step = pd.DataFrame(step_logs)
    # Explicit columns keep the "controller" lookups below valid when no
    # cycle finished during the run.
    cycle_cols = ["seed", "controller", "cycle_idx", "E_m_kJ", "T_cycle_s", "precision_pct", "safety_events"]
    df_cycle = pd.DataFrame(cycle_logs, columns=cycle_cols)

    try:
        import pyarrow as pa
        import pyarrow.parquet as pq
        table = pa.Table.from_pandas(df_step)
        pq.write_table(table, os.path.join(outdir, "logs", "per_step.parquet"))
    except Exception:
        # Deliberate best-effort: any parquet problem falls back to CSV.
        df_step.to_csv(os.path.join(outdir, "logs", "per_step.csv"), index=False)
    df_cycle.to_csv(os.path.join(outdir, "logs", "per_cycle.csv"), index=False)

    summary_cols = ["controller", "E_m_mean", "E_m_ci95", "T_cycle_mean", "T_cycle_ci95", "Precision_mean", "Precision_ci95", "Safety_per100"]
    summary_rows = []
    for controller in df_cycle["controller"].unique():
        sub = df_cycle[df_cycle["controller"] == controller]
        Em = sub["E_m_kJ"].dropna().values
        if Em.size == 0:
            continue
        Se = sub["safety_events"].dropna().values
        e_mean, e_ci = _mean_ci95(Em)
        t_mean, t_ci = _mean_ci95(sub["T_cycle_s"].dropna().values)
        p_mean, p_ci = _mean_ci95(sub["precision_pct"].dropna().values)
        summary_rows.append({
            "controller": controller,
            "E_m_mean": e_mean,
            "E_m_ci95": e_ci,
            "T_cycle_mean": t_mean,
            "T_cycle_ci95": t_ci,
            "Precision_mean": p_mean,
            "Precision_ci95": p_ci,
            "Safety_per100": float(100.0*np.sum(Se)/max(1, len(Se))),
        })
    df_sum = pd.DataFrame(summary_rows, columns=summary_cols)
    df_sum.to_csv(os.path.join(outdir, "logs", "summary.csv"), index=False)

    ctrl_list = df_cycle["controller"].unique().tolist()
    if "Traditional" in ctrl_list and "BRL" in ctrl_list:
        a = df_cycle[df_cycle["controller"] == "Traditional"]["E_m_kJ"].dropna().values
        b = df_cycle[df_cycle["controller"] == "BRL"]["E_m_kJ"].dropna().values
        if len(a) > 1 and len(b) > 1:
            ma, mb = np.mean(a), np.mean(b)
            sa, sb = np.std(a, ddof=1), np.std(b, ddof=1)
            # Unequal-variance t statistic; the tiny constant avoids a zero
            # denominator when both samples are constant.
            t_stat = (ma-mb)/math.sqrt((sa*sa/len(a))+(sb*sb/len(b))+1e-12)
            p = two_sided_p_from_t(t_stat, len(a), len(b))
            d = compute_effect_size(a, b)
            with open(os.path.join(outdir, "logs", "tests.txt"), "w", encoding="utf-8") as f:
                f.write(f"E_m Traditional vs BRL: t={t_stat:.3f}, p={p:.3g}, d={d:.3f}\n")
    return df_step, df_cycle, df_sum

def eval_extension(cfg, outdir, controller="BRL", cycles=120):
    """Run one long rollout until ``cycles`` cycles complete and plot energy per cycle.

    The rollout stops when the requested number of cycles has been recorded,
    when the environment reports ``done``, or after ``3600 * env.rate_hz``
    steps, whichever comes first. Results go to
    ``<outdir>/logs/extension_<controller>.csv`` and
    ``<outdir>/figs/figC_energy_<controller>.png``.

    Args:
        cfg: Configuration mapping. Reads ``seed``, ``device``,
            ``env.rate_hz`` and ``env.dt``.
        outdir: Output directory; the BRL checkpoint is looked up at
            ``<outdir>/ckpts/best.pt``.
        controller: ``"BRL"``, ``"Traditional"`` or anything else for the
            operator baseline.
        cycles: Number of completed cycles to record.

    Returns:
        DataFrame with columns ``cycle``, ``E_m_kJ`` and ``controller``
        (no columns when no cycle completed).
    """
    make_dirs(outdir)
    seed = int(cfg.get("seed", 0)) + 4242
    env, pf, ac, _obs0, _belief0 = make_env(cfg, seed)
    cycle_rows = []
    try:
        if controller == "BRL":
            ck = os.path.join(outdir, "ckpts", "best.pt")
            # NOTE(review): a missing checkpoint is silently ignored and the
            # freshly built policy is evaluated instead.
            if os.path.exists(ck):
                ac.load(ck, map_location=cfg.get("device", "cpu"))
        env_cfg = cfg.get("env", {})
        rate_hz = float(env_cfg.get("rate_hz", 100.0))
        dt = env_cfg.get("dt", 0.01)
        max_steps = int(60.0*60.0*rate_hz)

        obs = env.reset(seed=seed, cfg=cfg)
        belief = pf.init_from_obs(obs)
        for _ in range(max_steps):
            if len(cycle_rows) >= cycles:
                break
            if controller == "BRL":
                a_torch, _logp, _ent = ac.select_action(belief, deterministic=True)
                act = a_torch.cpu().numpy().reshape(-1)
            elif controller == "Traditional":
                act = baseline_action_fixed(belief)
            else:
                act = baseline_action_operator(belief)
            obs_next, reward, done, info = env.step(act)
            pf.predict(act, dt)
            pf.update(obs_next)
            belief = pf.estimate()
            cyc = info.get("cycle_result", None)
            if cyc is not None:
                cycle_rows.append({"cycle": len(cycle_rows)+1, "E_m_kJ": float(cyc.get("E_m_kJ", np.nan)), "controller": controller})
            if done:
                break
    finally:
        # Release the environment even if the rollout raises.
        env.close()

    df = pd.DataFrame(cycle_rows)
    df.to_csv(os.path.join(outdir, "logs", f"extension_{controller}.csv"), index=False)
    plt.figure(figsize=(6,4), dpi=160)
    if len(df) > 0:
        x = np.arange(1, len(df["cycle"])+1)
        y = df["E_m_kJ"].values
        plt.plot(x, y)
    plt.xlabel("Cycle")
    plt.ylabel("Energy per cycle (kJ)")
    plt.title(f"Energy over {cycles} cycles - {controller}")
    plt.tight_layout()
    plt.savefig(os.path.join(outdir, "figs", f"figC_energy_{controller}.png"))
    plt.close()
    return df

def _plot_estimate_with_band(t, est, std, truth, ylabel, title, path):
    """Save a figure with the estimate, a +/-2 std band and, if any value is finite, the truth."""
    plt.figure(figsize=(6,4), dpi=160)
    plt.plot(t, est)
    plt.fill_between(t, est-2*std, est+2*std, alpha=0.2)
    if np.isfinite(truth).any():
        plt.plot(t, truth)
    plt.xlabel("Time (s)")
    plt.ylabel(ylabel)
    plt.title(title)
    plt.tight_layout()
    plt.savefig(path)
    plt.close()


def plot_adaptation(cfg, outdir, df_step):
    """Plot belief adaptation and actions for the BRL controller.

    Writes ``figA_soil.png``, ``figA_acc.png`` and ``figA_actions.png`` to
    ``<outdir>/figs``. Nothing is plotted when ``df_step`` has no
    ``controller`` column (e.g. an empty frame) or no BRL rows.

    Args:
        cfg: Unused; kept for a uniform plotting signature.
        outdir: Output directory.
        df_step: Per-step log as produced by ``evaluate``.
    """
    make_dirs(outdir)
    # An empty DataFrame built from zero rows has no columns at all, so the
    # column must be checked before it is indexed.
    if "controller" not in df_step.columns:
        return
    df = df_step[df_step["controller"]=="BRL"].copy()
    if len(df)==0:
        return
    # NOTE(review): rows of all seeds are drawn as one line, so time_s restarts
    # at each seed boundary - confirm whether per-seed plots are wanted.
    figs_dir = os.path.join(outdir, "figs")
    t = df["time_s"].values

    _plot_estimate_with_band(
        t,
        df["soil_est"].values,
        df["soil_std"].values,
        df["soil_true"].values,
        "Soil resistance",
        "Estimated vs. true soil multiplier",
        os.path.join(figs_dir, "figA_soil.png"),
    )
    _plot_estimate_with_band(
        t,
        df["acc_offset_est"].values,
        df["acc_offset_std"].values,
        df["acc_offset_true"].values,
        "Accumulator precharge offset",
        "Estimated vs. true precharge offset",
        os.path.join(figs_dir, "figA_acc.png"),
    )

    # Pump pressure on the left axis, boom valve command on a twin right axis.
    plt.figure(figsize=(6,4), dpi=160)
    plt.plot(t, df["P_pump"].values)
    plt.xlabel("Time (s)")
    plt.ylabel("Pump pressure (MPa)")
    plt.title("Pump pressure setpoint and valve opening")
    plt.twinx()
    plt.plot(t, df["act_v_boom"].values)
    plt.ylabel("Valve opening")
    plt.tight_layout()
    plt.savefig(os.path.join(figs_dir, "figA_actions.png"))
    plt.close()

def plot_main_results(cfg, outdir, df_cycle, df_sum):
    """Plot the per-controller summary bars and energy histograms.

    Writes ``figB_bars.png`` (energy, cycle time and precision means with
    their ``*_ci95`` error bars) and one ``figD_hist_<controller>.png`` per
    controller to ``<outdir>/figs``. Returns without plotting when the
    summary is empty or has no ``controller`` column.

    Args:
        cfg: Unused; kept for a uniform plotting signature.
        outdir: Output directory.
        df_cycle: Per-cycle log as produced by ``evaluate``.
        df_sum: Summary table as produced by ``evaluate``.
    """
    make_dirs(outdir)
    # A summary built from zero rows may have no columns; set_index would
    # raise KeyError on it.
    if "controller" not in df_sum.columns or len(df_sum) == 0:
        return
    g = df_sum.set_index("controller")
    labels = g.index.tolist()
    x = np.arange(len(labels))
    width = 0.25

    def column_or_nan(name):
        # Missing summary columns are drawn as NaN instead of failing.
        if name in g.columns:
            return [g.loc[k, name] for k in labels]
        return [np.nan]*len(labels)

    plt.figure(figsize=(7,4), dpi=160)
    ax = plt.gca()
    # Labels + legend make the three bar groups distinguishable.
    ax.bar(x-0.3, column_or_nan("E_m_mean"), width, yerr=column_or_nan("E_m_ci95"), label="Energy (kJ)")
    ax.bar(x, column_or_nan("T_cycle_mean"), width, yerr=column_or_nan("T_cycle_ci95"), label="Cycle time (s)")
    ax.bar(x+0.3, column_or_nan("Precision_mean"), width, yerr=column_or_nan("Precision_ci95"), label="Precision (%)")
    ax.set_xticks(x)
    ax.set_xticklabels(labels)
    ax.set_ylabel("Value")
    ax.set_title("Energy, Cycle Time, Precision")
    ax.legend()
    plt.tight_layout()
    plt.savefig(os.path.join(outdir, "figs", "figB_bars.png"))
    plt.close()

    if "controller" not in df_cycle.columns or "E_m_kJ" not in df_cycle.columns:
        return
    for ctrl in df_cycle["controller"].unique():
        sub = df_cycle[df_cycle["controller"]==ctrl]
        if len(sub)==0:
            continue
        plt.figure(figsize=(6,4), dpi=160)
        plt.hist(sub["E_m_kJ"].dropna().values, bins=20, density=True)
        plt.xlabel("Energy per cycle (kJ)")
        plt.ylabel("Density")
        plt.title(f"Residual-like distribution {ctrl}")
        plt.tight_layout()
        plt.savefig(os.path.join(outdir, "figs", f"figD_hist_{ctrl}.png"))
        plt.close()

def gpu_timing_plot(outdir, device="cuda"):
    """Time a synthetic matmul workload for several batch sizes and plot the result.

    For each count ``n`` in ``[8, 16, 32, 64, 128]`` the total time of 100
    repetitions of ``relu(x @ y.T)`` with ``x`` and ``y`` of shape
    ``(n, 256)`` is measured in milliseconds and drawn as a bar in
    ``<outdir>/figs/figGPU_timing.png``.

    Tensor allocation and a warm-up call happen before the clock starts, the
    device is synchronised on both sides of the timed region, and the
    monotonic ``time.perf_counter`` clock is used, so the bars reflect the
    repeated operation only.

    Args:
        outdir: Output directory.
        device: Requested device; the CPU is used when CUDA is unavailable.
    """
    make_dirs(outdir)
    dev = torch.device(device if torch.cuda.is_available() else "cpu")
    on_cuda = dev.type == "cuda"
    counts = [8,16,32,64,128]
    times = []
    for n in counts:
        x = torch.randn(n, 256, device=dev)
        y = torch.randn(n, 256, device=dev)
        # Warm-up keeps one-off initialisation cost out of the measurement.
        torch.relu(x @ y.T)
        if on_cuda:
            # CUDA kernels run asynchronously; drain the queue before timing.
            torch.cuda.synchronize()
        t0 = time.perf_counter()
        for _ in range(100):
            torch.relu(x @ y.T)
        if on_cuda:
            torch.cuda.synchronize()
        times.append((time.perf_counter()-t0)*1000.0)
    plt.figure(figsize=(6,4), dpi=160)
    plt.bar(np.arange(len(counts)), times)
    plt.xticks(np.arange(len(counts)), [str(c) for c in counts])
    plt.xlabel("Number of snapshots")
    plt.ylabel("Inference time (ms)")
    plt.title("Snapshot count vs. inference time")
    plt.tight_layout()
    plt.savefig(os.path.join(outdir, "figs", "figGPU_timing.png"))
    plt.close()

def run_ablation(cfg, outdir):
    """Evaluate the BRL controller under each ablation and plot the results.

    Each ablation ``i`` gets its own directory ``<outdir>/ablate_<i>``.
    ``evaluate`` loads the policy from ``<its outdir>/ckpts/best.pt``, so the
    trained checkpoint from ``<outdir>/ckpts/best.pt`` is copied into every
    ablation directory first; without the copy the ablations would run on
    freshly built, untrained weights.

    Args:
        cfg: Configuration mapping forwarded to ``evaluate`` and the plotters.
        outdir: Root output directory that holds the trained checkpoint.
    """
    import shutil

    make_dirs(outdir)
    trained_ckpt = os.path.join(outdir, "ckpts", "best.pt")
    # NOTE(review): within this file only "no_proj" is consumed (by
    # run_episode); "no_pf" and "no_soc_reward" are passed through untouched -
    # confirm they are handled elsewhere.
    modes = [{"no_pf":True}, {"no_soc_reward":True}, {"no_proj":True}]
    for i, ab in enumerate(modes):
        sub_outdir = os.path.join(outdir, f"ablate_{i}")
        make_dirs(sub_outdir)
        if os.path.isfile(trained_ckpt):
            shutil.copyfile(trained_ckpt, os.path.join(sub_outdir, "ckpts", "best.pt"))
        df_step, df_cycle, df_sum = evaluate(cfg, sub_outdir, mode="ablate", ablate=ab)
        plot_adaptation(cfg, sub_outdir, df_step)
        plot_main_results(cfg, sub_outdir, df_cycle, df_sum)

def run_all(cfg_path):
    """Run the full pipeline: train, evaluate, plot, extension run, timing plot.

    Args:
        cfg_path: Path of the configuration file to load. The output
            directory is ``cfg["outdir"]`` (default ``"outputs"``).
    """
    cfg = load_cfg(cfg_path)
    outdir = cfg.get("outdir", "outputs")

    # The returned checkpoint path is not needed here.
    train(cfg, outdir)

    step_df, cycle_df, summary_df = evaluate(cfg, outdir, mode="eval")
    plot_adaptation(cfg, outdir, step_df)
    plot_main_results(cfg, outdir, cycle_df, summary_df)

    eval_extension(cfg, outdir, controller="BRL", cycles=120)
    gpu_timing_plot(outdir, device=cfg.get("device","cuda"))

def main():
    """Parse the command line and dispatch to the selected pipeline stage.

    ``--cfg`` (required) is the configuration file; ``--mode`` is one of
    ``train``, ``eval``, ``ablate`` or ``all`` (default ``all``).
    """
    cli = argparse.ArgumentParser()
    cli.add_argument("--cfg", type=str, required=True)
    cli.add_argument("--mode", type=str, choices=["train","eval","ablate","all"], default="all")
    args = cli.parse_args()

    # The configuration is loaded up front for every mode; run_all loads it
    # again from the path it is given.
    cfg = load_cfg(args.cfg)
    outdir = cfg.get("outdir", "outputs")
    selected = args.mode

    if selected == "train":
        train(cfg, outdir)
        return
    if selected == "ablate":
        run_ablation(cfg, outdir)
        return
    if selected != "eval":
        run_all(args.cfg)
        return

    step_df, cycle_df, summary_df = evaluate(cfg, outdir, mode="eval")
    plot_adaptation(cfg, outdir, step_df)
    plot_main_results(cfg, outdir, cycle_df, summary_df)
    eval_extension(cfg, outdir, controller="BRL", cycles=120)
    gpu_timing_plot(outdir, device=cfg.get("device","cuda"))

# Run the command-line entry point only when this file is executed as a script.
if __name__ == "__main__":
    main()
