
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from collections import deque
import random

class QNet(nn.Module):
    """Fully connected MLP: Linear+ReLU hidden layers followed by a linear output.

    Args:
        in_dim: size of the input feature vector (last dimension of ``x``).
        out_dim: number of output values produced per input row.
        hidden: widths of the hidden layers, in order; an empty sequence gives
            a single Linear(in_dim, out_dim).
    """

    def __init__(self, in_dim, out_dim, hidden=(256, 256)):
        # tuple default: a list default would be a single object shared by every call
        super().__init__()
        layers = []
        last = in_dim
        for width in hidden:
            layers += [nn.Linear(last, width), nn.ReLU()]
            last = width
        # no activation on the final projection, so outputs are unbounded
        layers.append(nn.Linear(last, out_dim))
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Apply the MLP to ``x`` and return a tensor whose last dim is ``out_dim``."""
        return self.net(x)

class Replay:
    """Bounded experience replay with optional proportional prioritization.

    Transitions are stored as tuples in a ``deque(maxlen=cap)``; a parallel
    deque holds one priority per transition. When ``prioritized`` is true,
    ``sample`` draws indices with probability ``p_i ** alpha / sum(p ** alpha)``
    (with replacement) and returns importance-sampling weights
    ``(N * P(i)) ** -beta`` normalized by their maximum. ``beta=1.0``
    reproduces full bias correction.
    """

    def __init__(self, cap=50000, prioritized=True, alpha=0.6, beta=1.0):
        self.buf = deque(maxlen=cap)
        self.prioritized = prioritized
        self.alpha = alpha
        self.beta = beta
        # same maxlen as buf, so both evict their oldest entry together
        self.pr = deque(maxlen=cap)
        # running max priority; default-priority pushes use it so new
        # transitions are likely to be replayed at least once
        self.max_priority = 1.0 + 1e-5
        # buffer positions drawn by the most recent prioritized sample
        # (None after a uniform sample); consumed by update_priorities
        self.last_indices = None

    def push(self, *exp, td_error=None):
        """Store one transition ``exp`` (as a tuple).

        Args:
            td_error: optional TD error; its magnitude plus 1e-5 becomes the
                priority. When None, the running max priority is used.
        """
        if td_error is None:
            priority = self.max_priority
        else:
            priority = abs(float(td_error)) + 1e-5
            self.max_priority = max(self.max_priority, priority)
        self.buf.append(exp)
        self.pr.append(priority)

    def sample(self, batch):
        """Draw ``batch`` transitions.

        Returns:
            (items, weights): a list of stored tuples and a float32 tensor of
            shape (batch,) with importance-sampling weights (all ones for
            uniform sampling).
        """
        n = len(self.buf)
        if self.prioritized and n == len(self.pr):
            # float64 keeps the normalized probabilities within np.random.choice's sum tolerance
            scaled = np.asarray(self.pr, dtype=np.float64) ** self.alpha
            probs = scaled / scaled.sum()
            idx = np.random.choice(n, size=batch, p=probs)
            items = [self.buf[i] for i in idx]
            weights = (n * probs[idx]) ** (-self.beta)
            weights /= weights.max()
            self.last_indices = idx
            return items, torch.tensor(weights, dtype=torch.float32)
        self.last_indices = None
        items = random.sample(self.buf, batch)
        return items, torch.ones(batch, dtype=torch.float32)

    def update_priorities(self, indices, td_errors):
        """Set priorities ``|td_error| + 1e-5`` at the given buffer positions.

        Positions refer to the buffer as it was at sampling time, so call this
        before any further ``push`` on a full buffer (which shifts positions).
        Duplicate indices (sampling is with replacement) keep the last value.
        """
        idx = np.asarray(indices).reshape(-1)
        new_pr = np.abs(np.asarray(td_errors, dtype=np.float64).reshape(-1)) + 1e-5
        if new_pr.size == 0:
            return
        for i, p in zip(idx, new_pr):
            self.pr[int(i)] = float(p)
        self.max_priority = max(self.max_priority, float(new_pr.max()))

class DQNAgent:
    """DQN agent with a target network, optional Double DQN and replay.

    ``cfg`` keys read: ``hidden_layers``, ``lr``, ``gamma``, ``batch_size``,
    ``target_update``, ``buffer_size``, ``epsilon`` (``start``/``end``/``decay``),
    and optionally ``double_dqn`` and ``prioritized_replay`` (both default True).
    """

    def __init__(self, obs_dim, act_dim, cfg):
        self.q = QNet(obs_dim, act_dim, hidden=cfg["hidden_layers"])
        self.tgt = QNet(obs_dim, act_dim, hidden=cfg["hidden_layers"])
        self.tgt.load_state_dict(self.q.state_dict())
        self.opt = optim.Adam(self.q.parameters(), lr=cfg["lr"])
        self.gamma = cfg["gamma"]
        self.batch = cfg["batch_size"]
        self.target_update = cfg["target_update"]
        self.double = cfg.get("double_dqn", True)
        self.replay = Replay(cfg["buffer_size"], prioritized=cfg.get("prioritized_replay", True))
        self.step_count = 0
        self.eps = cfg["epsilon"]["start"]
        self.eps_end = cfg["epsilon"]["end"]
        self.eps_decay = cfg["epsilon"]["decay"]

    def act(self, obs, act_space):
        """Epsilon-greedy action: random int in [0, act_space) or argmax of Q(obs)."""
        if np.random.rand() < self.eps:
            return np.random.randint(0, act_space)
        with torch.no_grad():
            q = self.q(torch.tensor(obs, dtype=torch.float32).unsqueeze(0))
        return int(q.argmax().item())

    def push(self, exp, td_error=None):
        """Store ``exp`` = (obs, act, rew, nxt, done); None uses the buffer's max priority."""
        self.replay.push(*exp, td_error=td_error)

    def update(self):
        """Run one training step and return the scalar loss.

        Returns 0.0 without training until the buffer holds ``batch_size``
        transitions. Otherwise it takes one optimizer step on the
        weight-scaled squared TD error. It refreshes the sampled priorities,
        syncs the target net every ``target_update`` steps, and decays epsilon.
        """
        if len(self.replay.buf) < self.batch:
            return 0.0
        batch, weights = self.replay.sample(self.batch)
        obs, act, rew, nxt, done = zip(*batch)
        # stack through numpy: torch.tensor on a tuple of ndarrays is very slow
        obs = torch.as_tensor(np.asarray(obs, dtype=np.float32))
        act = torch.as_tensor(np.asarray(act, dtype=np.int64)).unsqueeze(1)
        rew = torch.as_tensor(np.asarray(rew, dtype=np.float32)).unsqueeze(1)
        nxt = torch.as_tensor(np.asarray(nxt, dtype=np.float32))
        done = torch.as_tensor(np.asarray(done, dtype=np.float32)).unsqueeze(1)
        weights = weights.unsqueeze(1)

        q_values = self.q(obs).gather(1, act)
        with torch.no_grad():
            if self.double:
                # Double DQN: online net selects the action, target net evaluates it
                next_actions = self.q(nxt).argmax(1, keepdim=True)
                next_q = self.tgt(nxt).gather(1, next_actions)
            else:
                next_q = self.tgt(nxt).max(1, keepdim=True)[0]
            target = rew + (1.0 - done) * self.gamma * next_q

        td_error = target - q_values
        loss = (weights * (td_error ** 2)).mean()

        self.opt.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(self.q.parameters(), 10.0)
        self.opt.step()

        # feed |TD error| back so prioritized sampling actually tracks learning progress
        if self.replay.last_indices is not None:
            self.replay.update_priorities(
                self.replay.last_indices,
                td_error.detach().squeeze(1).cpu().numpy(),
            )

        self.step_count += 1
        if self.step_count % self.target_update == 0:
            self.tgt.load_state_dict(self.q.state_dict())

        # eps decay (multiplicative, floored at eps_end), once per training step
        self.eps = max(self.eps_end, self.eps * self.eps_decay)
        return float(loss.item())
