import numpy as np
import pandas as pd
import pathlib
import json
import math
import time
from collections import deque

class RingBuffer:
    """Bounded, column-oriented store of dict records.

    Every key in *keys* owns a ``deque`` capped at *capacity* entries, so once
    the buffer is full each new record evicts the oldest one.
    """

    def __init__(self, capacity, keys):
        self.capacity = int(capacity)
        self.keys = list(keys)
        self.data = {}
        for key in self.keys:
            self.data[key] = deque(maxlen=self.capacity)

    def append(self, item):
        """Record one row; any key absent from *item* is stored as ``None``."""
        for key in self.keys:
            value = item.get(key)
            self.data[key].append(value)

    def to_dataframe(self):
        """Return the buffered rows as a ``DataFrame`` with one column per key.

        Columns are trimmed to the shortest column's length, keeping the newest
        rows. With no rows, an empty frame with the expected columns is returned.
        """
        if not self.data:
            rows = 0
        else:
            rows = min(map(len, self.data.values()))
        if rows == 0:
            return pd.DataFrame(columns=self.keys)
        columns = {}
        for key in self.keys:
            columns[key] = list(self.data[key])[-rows:]
        return pd.DataFrame(columns)

def _linmap(x, x0, x1, y0, y1):
    if x1 == x0:
        return y0
    t = (x - x0) / (x1 - x0)
    return y0 + t * (y1 - y0)

def _clamp(x, lo, hi):
    return lo if x < lo else hi if x > hi else x

def _rand(seed_seq, n):
    rs = np.random.RandomState(seed_seq)
    return rs.rand(n)

def _ar1_noise(n, sigma, tau_steps, seed):
    rs = np.random.RandomState(seed)
    a = math.exp(-1.0 / max(1, tau_steps))
    eps = rs.randn(n) * sigma * math.sqrt(1 - a * a)
    y = np.zeros(n)
    for i in range(1, n):
        y[i] = a * y[i - 1] + eps[i]
    return y

class ExcavatorEnv:
    """Simplified hydraulic-excavator digging environment with a gym-style API.

    ``reset(seed)`` returns an observation dict. ``step(action)`` returns
    ``(obs, reward, done, info)``. An episode lasts ``total_steps`` steps of
    ``dt`` seconds, during which the machine cycles through the phases
    lower -> cut -> lift -> dump on a fixed time schedule.

    Slow disturbances (ambient/oil temperature, viscosity, soil factor, wear,
    sensor biases) come from CSV files in ``cfg["paths"]["uncertainty_dir"]``
    when present; otherwise they are generated from the episode seed. Fast
    noise (sensor noise, joint-velocity jitter) is drawn from ``self.rng``,
    which ``reset`` re-seeds, so an episode is reproducible from its seed.
    """

    # Digging phases in order; "dump" wraps back to "lower" and starts a new cycle.
    _PHASES = ("lower", "cut", "lift", "dump")
    # Integer phase codes written to the per-step ring log.
    _PHASE_CODE = {"lower": 0, "cut": 1, "lift": 2, "dump": 3}
    # Safety-request flags raised per step and counted per cycle.
    _EVENT_KEYS = ("overpressure_req", "underpressure_req", "acc_req", "softlimit_req", "port_relief")
    # (state key, trajectory key) pairs copied from self.traj every step.
    _TRAJ_CHANNELS = (
        ("oil_T", "oil_T"),
        ("mu", "mu"),
        ("soil", "soil"),
        ("valve_gain_scale", "valve_gain"),
        ("cyl_leak_scale", "cyl_leak"),
        ("pump_eta", "pump_eta"),
        ("pb_pump", "pb_pump"),
    )

    def __init__(self, cfg=None):
        self.cfg = cfg or {}
        self.dt = 0.01
        self.rate_hz = 100.0
        self.total_seconds = 20 * 60.0
        self.total_steps = int(self.total_seconds * self.rate_hz)
        self.extended_cycles = 120
        self._apply_cfg()
        self.state = {}
        self.obs = {}
        self.info = {}
        self.phase = "lower"
        self.phase_time = 0.0
        self.cycle_index = 0
        self.done = False
        self.step_idx = 0
        self.seed = 0
        self.ring = RingBuffer(100000, keys=["t","cycle","phase","P_pump","Q_pump","SoC","soil","oil_T","mu","energy_J","reward_total","reward_prog","reward_energy","reward_safety","overpressure_req","underpressure_req","acc_req","softlimit_req"])
        # Per-joint soft limits as [lo, hi] rows (presumably degrees).
        self.angles_soft = np.array([[-10.0, 70.0], [-40.0, 55.0], [-110.0, 35.0]])
        self.spool_rate_limit = 0.02
        self.pump_rate_limit_mpa = 0.5
        self.rail_lo = 5.0
        self.rail_hi = 35.0
        self.acc_lo = 12.0
        self.acc_hi = 28.0
        self.pc_drop = 1.0
        self.section_gains = np.array([1.0, 1.0, 1.0])
        self.port_relief = np.array([32.0, 32.0, 32.0])
        self.acc_volume_l = 10.0
        # Last rate-limited command: [0:3] spool commands in [0, 1],
        # [3] pump setpoint as a fraction of rail_hi.
        self.prev_action = np.zeros(4, dtype=np.float32)
        # Spool positions after deadband/hysteresis (the valve's own state).
        self.spool_pos = np.zeros(3, dtype=np.float32)
        # Noise stream for observations and velocity jitter; re-seeded by reset().
        self.rng = np.random.RandomState(0)

    def _apply_cfg(self):
        """Derive config-dependent attributes (Simulink flag, data paths) from ``self.cfg``."""
        self.simulink_enabled = bool(self.cfg.get("simulink", {}).get("enabled", False))
        self.paths = {k: self.cfg.get("paths", {}).get(k, None) for k in ["uncertainty_dir", "logs_dir"]}

    def _zero_events(self):
        """Return a fresh all-zero counter dict keyed by the safety-request names."""
        return {k: 0 for k in self._EVENT_KEYS}

    def _sync_disturbances(self, idx):
        """Copy the disturbance samples at step *idx* from ``self.traj`` into ``self.state``."""
        for state_key, traj_key in self._TRAJ_CHANNELS:
            self.state[state_key] = float(self.traj[traj_key][idx])
        self.state["pb_ports"] = self.traj["pb_ports"][idx].astype(np.float32)
        self.state["enc_bias"] = float(self.traj["enc_bias"][idx])

    def _load_or_generate_trajectories(self, seed):
        """Build the per-step disturbance trajectories for one episode.

        Every series has ``self.total_steps`` samples. If
        ``paths["uncertainty_dir"]`` is set and contains the named CSV, the last
        column of that file is used (truncated, or edge-padded if too short).
        Otherwise the series is generated deterministically from *seed*.

        Returns:
            dict with these entries:
            - 1-D arrays ``ambient``, ``oil_T``, ``mu``, ``soil``,
              ``valve_gain``, ``cyl_leak``, ``pump_eta``, ``pb_pump`` and
              ``enc_bias``;
            - ``pb_ports``, a ``(total_steps, 6)`` array;
            - scalars ``acc_precharge_offset``, ``valve_hyst`` (cfg key
              ``valve_hysteresis``, default 0.03) and ``valve_dead`` (cfg key
              ``valve_deadband``, default 0.02).
        """
        n = self.total_steps
        base = pathlib.Path(self.paths["uncertainty_dir"]) if self.paths["uncertainty_dir"] else None

        def _load(name, gen):
            # Use the recorded series <uncertainty_dir>/<name> (last column)
            # when it exists and is non-empty; otherwise call the generator.
            if base is not None:
                p = base / name
                if p.exists():
                    v = pd.read_csv(p).iloc[:n, -1].to_numpy()
                    if len(v) > 0:  # an empty series cannot be edge-padded
                        if len(v) < n:
                            # Hold the last recorded value to the end of the episode.
                            v = np.pad(v, (0, n - len(v)), mode="edge")
                        return v[:n]
            return gen()

        # Ambient temperature: 18 -> 34 ramp over the first 12 minutes, then flat.
        # The ramp is capped at the episode length so short episodes still work.
        ramp = min(int(12 * 60 * self.rate_hz), n)
        ambient = _load("env_temperature.csv", lambda: np.concatenate([np.linspace(18.0, 34.0, ramp), np.full(n - ramp, 34.0)]))

        # Oil temperature: first-order lag (time constant 180 s) toward ambient,
        # starting at >= 35 and clamped to [35, 65] every step.
        oil_T = np.zeros(n)
        oil_T[0] = max(35.0, ambient[0])
        alpha = 1.0 / max(1, int(180.0 * self.rate_hz))
        for i in range(1, n):
            oil_T[i] = _clamp(oil_T[i - 1] + (ambient[i] - oil_T[i - 1]) * alpha, 35.0, 65.0)

        # Viscosity falls linearly from 0.050 at 35 to 0.020 at 65 (unclamped map).
        mu = _load("viscosity.csv", lambda: _linmap(oil_T, 35.0, 65.0, 0.050, 0.020))

        # Soil factor: piecewise-constant segment means plus AR(1) noise
        # (sigma = 10 % of the mean, 0.5 s time constant), clipped to 0.5x-1.5x the mean.
        seg_lengths_s = [240, 180, 300, 240, 240]
        seg_means = [0.9, 1.2, 1.0, 1.3, 0.8]
        seg_steps = [int(s * self.rate_hz) for s in seg_lengths_s]
        soil = np.zeros(n)
        idx = 0
        rs = np.random.RandomState(seed + 123)
        for m, L in zip(seg_means, seg_steps):
            L = min(L, n - idx)
            noise = _ar1_noise(L, sigma=0.1 * m, tau_steps=int(0.5 * self.rate_hz), seed=rs.randint(1, 1_000_000))
            soil[idx:idx + L] = np.clip(np.full(L, m) + noise, 0.5 * m, 1.5 * m)
            idx += L
            if idx >= n:
                break
        if idx < n:
            # Episode longer than the segment schedule: hold the last value.
            soil[idx:] = soil[idx - 1]

        # Wear: linear drifts across the episode.
        valve_gain = _load("wear_valve.csv", lambda: np.linspace(1.0, 0.94, n))
        cyl_leak = _load("wear_leak.csv", lambda: np.linspace(0.0, 0.4, n))
        pump_eta = _load("wear_pump.csv", lambda: np.linspace(1.0, 0.97, n))

        fs = 40.0  # full-scale value the pressure-bias spans are scaled by

        def rw(total_span, std_step, seed_local, lo, hi):
            # Gaussian random walk starting at 0, rescaled so its 99th-percentile
            # magnitude is ~total_span, then clipped to +/-total_span and [lo, hi].
            rs2 = np.random.RandomState(seed_local)
            w = np.zeros(n)
            # Same draws and the same sequential sums as a per-step loop
            # (cumsum accumulates left to right), just without the Python loop.
            w[1:] = np.cumsum(rs2.randn(max(n - 1, 0)) * std_step)
            scale = total_span / max(1e-9, (np.percentile(np.abs(w), 99) + 1e-6))
            w = np.clip(w * scale, -total_span, total_span)
            return np.clip(w, lo, hi)

        pb_pump = _load("press_bias_pump.csv", lambda: rw(0.002 * fs, 1e-4, seed + 1, -0.002 * fs, 0.002 * fs))
        pb_ports = []
        for k in range(6):
            # Default argument binds each port's seed now (avoids late binding).
            pb_ports.append(_load(f"press_bias_port{k}.csv", lambda s=seed + k + 2: rw(0.002 * fs, 1e-4, s, -0.002 * fs, 0.002 * fs)))
        pb_ports = np.stack(pb_ports, axis=1)
        enc_bias = _load("pos_bias.csv", lambda: rw(0.5, 1e-3, seed + 99, -0.5, 0.5))

        acc_precharge_offset = self.cfg.get("acc_precharge_offset", None)
        if acc_precharge_offset is None:
            acc_precharge_offset = np.random.RandomState(seed + 777).uniform(-0.10, 0.10)
        valve_hyst = self.cfg.get("valve_hysteresis", 0.03)
        valve_dead = self.cfg.get("valve_deadband", 0.02)
        return {"ambient": ambient, "oil_T": oil_T, "mu": mu, "soil": soil, "valve_gain": valve_gain, "cyl_leak": cyl_leak, "pump_eta": pump_eta, "pb_pump": pb_pump, "pb_ports": pb_ports, "enc_bias": enc_bias, "acc_precharge_offset": acc_precharge_offset, "valve_hyst": valve_hyst, "valve_dead": valve_dead}

    def reset(self, seed=0, cfg=None):
        """Start a new episode and return its first observation.

        Args:
            seed: integer seed for the disturbance trajectories and for the
                observation/velocity noise stream.
            cfg: optional replacement config. When given, the derived
                ``paths`` and ``simulink_enabled`` attributes are refreshed too,
                so the new ``uncertainty_dir`` is actually used.
        """
        if cfg is not None:
            self.cfg = cfg
            self._apply_cfg()
        self.seed = int(seed)
        # Per-episode noise stream. The array seed keys it differently from
        # the integer-seeded trajectory generators.
        self.rng = np.random.RandomState([self.seed % (2 ** 32), 0x5EED])
        self.step_idx = 0
        self.done = False
        self.cycle_index = 0
        self.phase = "lower"
        self.phase_time = 0.0
        self.prev_action = np.zeros(4, dtype=np.float32)
        self.spool_pos = np.zeros(3, dtype=np.float32)
        self.traj = self._load_or_generate_trajectories(self.seed)
        self.state = {
            "angles": np.array([30.0, 0.0, -20.0], dtype=np.float32),
            "ang_vel": np.zeros(3, dtype=np.float32),
            "press_A": np.full(3, 18.0, dtype=np.float32),
            "press_B": np.full(3, 16.0, dtype=np.float32),
            "P_pump": 20.0,
            "Q_pump": 0.0,
            "SoC": 0.5,
        }
        self._sync_disturbances(0)
        # NOTE(review): acc_offset is stored but not used by the dynamics in step().
        self.state["acc_offset"] = float(self.traj["acc_precharge_offset"])
        self.energy_acc_J = 0.0
        self.energy_cycle_J = 0.0
        self.reward_cycle = {"energy": 0.0, "progress": 0.0, "safety": 0.0}
        self.events_cycle = self._zero_events()
        self.ring = RingBuffer(100000, keys=self.ring.keys)
        return self._observe()

    def _observe(self):
        """Build a noisy sensor observation from the true state and store it in ``self.obs``.

        Returns:
            dict with:
            - ``angles``: 3 float32 values, with encoder bias and noise;
            - ``press_ports``: 6 float32 values (A ports then B ports), with
              per-port bias and noise;
            - ``pump_press``: pump pressure with bias and noise;
            - ``pump_rpm``: 1800 with noise;
            - ``pump_disp``: ``Q_pump / 220`` clamped to [0, 1];
            - ``oil_T``.
        """
        enc_noise = 0.15
        press_noise_pc = 0.005  # pressure noise std as a fraction of a 40.0 full scale
        angles_obs = self.state["angles"] + self.state["enc_bias"] + self.rng.randn(3).astype(np.float32) * enc_noise
        ports = np.concatenate([self.state["press_A"], self.state["press_B"]], axis=0)
        press_noise = self.rng.randn(6).astype(np.float32) * press_noise_pc * 40.0
        press_obs_ports = ports + self.state["pb_ports"] + press_noise
        pump_press_obs = self.state["P_pump"] + self.state["pb_pump"] + self.rng.randn() * press_noise_pc * 40.0
        pump_rpm_obs = 1800.0 + 20.0 * self.rng.randn()
        pump_disp_obs = _clamp(_linmap(self.state["Q_pump"], 0.0, 220.0, 0.0, 1.0), 0.0, 1.0)
        obs = {"angles": angles_obs.astype(np.float32), "press_ports": press_obs_ports.astype(np.float32), "pump_press": float(pump_press_obs), "pump_rpm": float(pump_rpm_obs), "pump_disp": float(pump_disp_obs), "oil_T": float(self.state["oil_T"])}
        self.obs = obs
        return obs

    def _phase_schedule(self):
        """Return the duration in seconds of each digging phase."""
        return {"lower": 1.5, "cut": 2.0, "lift": 1.5, "dump": 1.0}

    def _advance_phase(self):
        """Move to the next phase.

        Wrapping from "dump" to "lower" increments ``cycle_index`` and clears
        the per-cycle energy, reward and event accumulators.
        """
        i = self._PHASES.index(self.phase)
        if i < len(self._PHASES) - 1:
            self.phase = self._PHASES[i + 1]
        else:
            self.phase = self._PHASES[0]
            self.cycle_index += 1
            self.energy_cycle_J = 0.0
            self.reward_cycle = {"energy": 0.0, "progress": 0.0, "safety": 0.0}
            self.events_cycle = self._zero_events()
        self.phase_time = 0.0

    def _desired_motion(self):
        """Return the nominal joint velocities for the current phase (presumably deg/s)."""
        if self.phase == "lower":
            return np.array([-8.0, -3.0, 0.0])
        if self.phase == "cut":
            return np.array([0.0, 2.0, 4.0])
        if self.phase == "lift":
            return np.array([8.0, 3.0, 0.0])
        return np.array([0.0, 0.0, 12.0])

    def _compute_section_loads(self, soil):
        """Return the per-section load pressure for the current phase.

        The phase's base load is scaled by the soil factor and clipped to
        [5, 32], in the same units as the rail pressures.
        """
        if self.phase == "lower":
            base = np.array([10.0, 8.0, 6.0])
        elif self.phase == "cut":
            base = np.array([20.0, 22.0, 18.0])
        elif self.phase == "lift":
            base = np.array([26.0, 18.0, 12.0])
        else:
            base = np.array([12.0, 10.0, 8.0])
        return np.clip(base * soil, 5.0, 32.0)

    def _acc_pressure_from_soc(self, soc):
        """Map accumulator state of charge in [0, 1] linearly onto [acc_lo, acc_hi]."""
        return _linmap(soc, 0.0, 1.0, self.acc_lo, self.acc_hi)

    def _soc_from_pressure(self, p):
        """Map accumulator pressure linearly from [acc_lo, acc_hi] back to state of charge."""
        return _linmap(p, self.acc_lo, self.acc_hi, 0.0, 1.0)

    def _spool_with_hysteresis(self, cmd, prev_cmd, dead, hyst):
        """Apply a deadband and a hysteresis band to spool commands.

        For each element:
        - ``|cmd| < dead`` gives 0.
        - A command more than *hyst* above or below the previous output
          gives ``cmd -/+ hyst/2``.
        - Otherwise the previous output is held.
        The result is clamped to [0, 1].

        *prev_cmd* must be the previous valve **output** (spool position).
        NOTE(review): the switching threshold is ``hyst`` but the offset is
        ``hyst/2``, so the output jumps when it starts moving; confirm intent.
        """
        out = np.zeros_like(cmd)
        for i in range(len(cmd)):
            c = cmd[i]
            p = prev_cmd[i]
            if abs(c) < dead:
                c_eff = 0.0
            elif c > p + hyst:
                c_eff = c - hyst * 0.5
            elif c < p - hyst:
                c_eff = c + hyst * 0.5
            else:
                c_eff = p
            out[i] = _clamp(c_eff, 0.0, 1.0)
        return out

    def step(self, action):
        """Advance the simulation by one control step of ``dt`` seconds.

        Args:
            action: length-4 sequence. Elements [0:3] are spool commands;
                element [3] is the pump pressure setpoint as a fraction of
                ``rail_hi``.

        Returns:
            ``(obs, reward, done, info)``. ``info`` holds:
            - the true state;
            - this step's request flags (``events``);
            - step and cycle energy;
            - the reward breakdown;
            - the phase and cycle index after this step.
            Once the episode is done, returns ``(obs, 0.0, True, {})``.
        """
        if self.done:
            return self._observe(), 0.0, True, {}
        raw_action = np.array(action, dtype=np.float32)

        # Rate-limit the new command relative to the previous *command*.
        spool_step = self.spool_rate_limit
        pump_step = self.pump_rate_limit_mpa / self.rail_hi
        action_rate = np.copy(raw_action)
        action_rate[:3] = np.clip(raw_action[:3], self.prev_action[:3] - spool_step, self.prev_action[:3] + spool_step)
        action_rate[3] = np.clip(raw_action[3], self.prev_action[3] - pump_step, self.prev_action[3] + pump_step)

        requests = self._zero_events()
        # Flag configurations already within 1 unit of a soft joint limit.
        requests["softlimit_req"] = int(np.any((self.state["angles"] < self.angles_soft[:, 0] + 1.0) | (self.state["angles"] > self.angles_soft[:, 1] - 1.0)))

        # The valve hysteresis is referenced to the previous spool *position*,
        # while the rate limiter above uses the previous *command*. When both
        # used the position, the command could never get more than
        # spool_rate_limit (0.02) away from it. That is less than the default
        # hysteresis band (0.03), so the spools never opened.
        spools = self._spool_with_hysteresis(action_rate[:3], self.spool_pos, self.traj["valve_dead"], self.traj["valve_hyst"])
        pump_target_mpa = _clamp(action_rate[3] * self.rail_hi, self.rail_lo, self.rail_hi)

        self._sync_disturbances(self.step_idx)
        soil = self.state["soil"]

        # Hydraulics: per-section flow from spool opening and pressure margin.
        loads = self._compute_section_loads(soil)
        valve_gain = self.section_gains * self.state["valve_gain_scale"]
        section_flow = valve_gain * spools * np.maximum(0.0, pump_target_mpa - loads) / max(1e-6, self.pc_drop + 1e-3)
        section_flow = np.clip(section_flow, 0.0, 120.0)
        requests["port_relief"] = int(np.any(loads > self.port_relief))
        self.state["Q_pump"] = min(float(np.sum(section_flow)), 220.0)
        # Pump pressure moves halfway to the target each step.
        p_dyn = self.state["P_pump"] + (pump_target_mpa - self.state["P_pump"]) * 0.5
        self.state["P_pump"] = _clamp(p_dyn, self.rail_lo, self.rail_hi)

        acc_p = self._acc_pressure_from_soc(self.state["SoC"])
        if acc_p < self.acc_lo + 0.2 or acc_p > self.acc_hi - 0.2:
            requests["acc_req"] = 1
        if self.state["P_pump"] > self.rail_hi - 0.2:
            requests["overpressure_req"] = 1
        if self.state["P_pump"] < self.rail_lo + 0.2:
            requests["underpressure_req"] = 1

        # Kinematics. NOTE(review): joint velocity depends on pump pressure and
        # valve gain but not on spool opening or section flow; confirm intent.
        eff = 1.0 - 0.3 * np.clip((loads - 18.0) / 18.0, 0.0, 1.0)
        vel = self._desired_motion() * eff * (self.state["P_pump"] / 25.0) * self.state["valve_gain_scale"]
        vel += self.rng.randn(3) * 0.1
        angles = self.state["angles"] + vel * self.dt
        lo, hi = self.angles_soft[:, 0], self.angles_soft[:, 1]
        if np.any(angles < lo) or np.any(angles > hi):
            requests["softlimit_req"] = 1
        angles = np.clip(angles, lo, hi)
        self.state["ang_vel"] = vel.astype(np.float32)
        self.state["angles"] = angles.astype(np.float32)

        # Cylinder port pressures: load +/- 2, minus leakage, low-pass filtered.
        leak = self.state["cyl_leak_scale"] * 0.02
        press_A = np.clip(np.minimum(loads + 2.0, self.rail_hi) - leak, self.rail_lo, self.rail_hi)
        press_B = np.clip(np.maximum(self.rail_lo, loads - 2.0) - leak, self.rail_lo, self.rail_hi)
        self.state["press_A"] = (0.8 * self.state["press_A"] + 0.2 * press_A).astype(np.float32)
        self.state["press_B"] = (0.8 * self.state["press_B"] + 0.2 * press_B).astype(np.float32)

        # Accumulator: charges while lowering, discharges while lifting.
        soc = self.state["SoC"]
        if self.phase == "lower":
            soc += 0.002 * np.mean(spools) * (self.state["P_pump"] / 28.0)
        elif self.phase == "lift":
            soc -= 0.0025 * np.mean(spools) * (loads[0] / 26.0)
        soc = _clamp(soc, 0.0, 1.0)
        acc_p_new = self._acc_pressure_from_soc(soc)
        if acc_p_new >= self.acc_hi:
            soc = self._soc_from_pressure(self.acc_hi)
            requests["acc_req"] = 1
        if acc_p_new <= self.acc_lo:
            soc = self._soc_from_pressure(self.acc_lo)
            requests["acc_req"] = 1
        self.state["SoC"] = soc

        # Hydraulic energy: Q_pump * dt / 60 gives litres; pressure * litres *
        # 1000 gives joules when P_pump is in MPa (1 MPa * 1 L = 1 kJ).
        liters = self.state["Q_pump"] * (self.dt / 60.0)
        energy_J = self.state["P_pump"] * liters * 1000.0

        self.phase_time += self.dt
        phase_complete = self.phase_time >= self._phase_schedule()[self.phase]
        reward_energy = -energy_J / 1000.0
        reward_prog = 1.0 if phase_complete else 0.0
        safety_pen = -1.0 * sum(requests[k] for k in self._EVENT_KEYS)
        reward = reward_energy + reward_prog + safety_pen * 0.1

        # Book this step into the cycle it belongs to *before* a possible wrap.
        # Advancing first would discard this step's cycle energy and charge
        # its reward and events to the next cycle.
        self.energy_acc_J += energy_J
        self.energy_cycle_J += energy_J
        self.reward_cycle["energy"] += reward_energy
        self.reward_cycle["progress"] += reward_prog
        self.reward_cycle["safety"] += safety_pen
        for k in self.events_cycle:
            self.events_cycle[k] += requests.get(k, 0)
        if phase_complete:
            self._advance_phase()

        # Remember the rate-limited command (bounded like the valve range)
        # and the valve position separately.
        self.prev_action[:3] = np.clip(action_rate[:3], 0.0, 1.0)
        self.prev_action[3] = pump_target_mpa / self.rail_hi
        self.spool_pos = spools

        obs = self._observe()
        info = {}
        info["state_true"] = {k: (float(v) if isinstance(v, (np.floating, float)) else (v.tolist() if isinstance(v, np.ndarray) else v)) for k, v in self.state.items()}
        info["events"] = requests
        info["step_energy_J"] = float(energy_J)
        info["cycle_energy_J"] = float(self.energy_cycle_J)
        info["phase"] = self.phase
        info["cycle_index"] = int(self.cycle_index)
        info["reward_breakdown"] = {"energy": float(reward_energy), "progress": float(reward_prog), "safety": float(safety_pen)}
        t = self.step_idx * self.dt
        self.ring.append({"t": t, "cycle": self.cycle_index, "phase": self._PHASE_CODE.get(self.phase, 3), "P_pump": self.state["P_pump"], "Q_pump": self.state["Q_pump"], "SoC": self.state["SoC"], "soil": self.state["soil"], "oil_T": self.state["oil_T"], "mu": self.state["mu"], "energy_J": energy_J, "reward_total": reward, "reward_prog": reward_prog, "reward_energy": reward_energy, "reward_safety": safety_pen, "overpressure_req": requests["overpressure_req"], "underpressure_req": requests["underpressure_req"], "acc_req": requests["acc_req"], "softlimit_req": requests["softlimit_req"]})
        self.step_idx += 1
        if self.step_idx >= self.total_steps:
            self.done = True
        return obs, float(reward), bool(self.done), info

    def render(self, mode=None):
        """Rendering is not implemented; always returns ``None``."""
        return None

    def close(self):
        """No resources to release; always returns ``None``."""
        return None

if __name__ == "__main__":
    # Smoke run: up to 2000 steps with a constant action, then write the
    # per-step ring log to <logs_dir>/per_step_env.csv and print the return.
    cfg = {"paths": {"uncertainty_dir": None, "logs_dir": "./logs"}, "simulink": {"enabled": False}}
    env = ExcavatorEnv(cfg)
    obs = env.reset(seed=42)
    constant_action = np.array([0.2, 0.2, 0.2, 0.6], dtype=np.float32)
    episode_return = 0.0
    for _step in range(2000):
        obs, reward, finished, info = env.step(constant_action)
        episode_return += reward
        if finished:
            break
    step_log = env.ring.to_dataframe()
    log_dir = pathlib.Path(cfg["paths"]["logs_dir"])
    log_dir.mkdir(parents=True, exist_ok=True)
    step_log.to_csv(log_dir / "per_step_env.csv", index=False)
    print(episode_return)
