
import argparse, yaml, os, json
from rlmots.env import CloudEdgeEnv
from rlmots.dqn_agent import DQNAgent

def run(cfg):
    """Train a DQN agent over a series of freshly constructed environments.

    ``cfg`` is passed whole to ``CloudEdgeEnv`` and is indexed here with
    ``"episodes"`` (how many episodes to run) and ``"dqn"`` (forwarded to
    ``DQNAgent`` as its ``cfg``).

    Returns a dict with two lists: ``"rewards"`` holds the summed reward of
    each episode, and ``"loss"`` holds every truthy value returned by
    ``agent.update()``, in step order.
    """
    # This first environment is only used to size the agent; the loop below
    # rebinds ``env`` to a new instance at the start of every episode.
    env = CloudEdgeEnv(cfg)
    obs = env.state()
    obs_dim = len(obs)
    act_dim = env.action_space
    agent = DQNAgent(obs_dim=obs_dim, act_dim=act_dim, cfg=cfg["dqn"])

    episode_rewards = []
    losses = []
    for ep in range(cfg["episodes"]):
        env = CloudEdgeEnv(cfg)
        state = env.state()
        ep_reward = 0.0
        while True:
            action = agent.act(state, env.action_space)
            next_state, reward, done, _ = env.step(action)
            agent.push((state, action, reward, next_state, float(done)))
            loss = agent.update()
            # Truthiness test: a None result and an exact zero are both skipped.
            if loss:
                losses.append(loss)
            state = next_state
            ep_reward += reward
            if done:
                break
        episode_rewards.append(ep_reward)
        print(f"Episode {ep+1}/{cfg['episodes']} reward={ep_reward:.3f} eps={agent.eps:.3f}")
    return {"rewards": episode_rewards, "loss": losses}

if __name__ == "__main__":
    # Script entry point: read a YAML config, run training, write logs.json.
    ap = argparse.ArgumentParser()
    ap.add_argument("--config", required=True, help="path to the YAML configuration file")
    ap.add_argument(
        "--outdir",
        default="results/google2019_small",
        help="directory in which logs.json is written",
    )
    args = ap.parse_args()
    # Decode explicitly: the default text encoding is locale-dependent, which
    # can mis-read a UTF-8 config on platforms with a different default.
    with open(args.config, "r", encoding="utf-8") as f:
        cfg = yaml.safe_load(f)
    # safe_load returns None for an empty file and may return a list or
    # scalar; run() indexes cfg by string key, so reject anything else here
    # with a clear message instead of an opaque TypeError later.
    if not isinstance(cfg, dict):
        ap.error(f"config file {args.config!r} must contain a YAML mapping at the top level")
    # Created before training, so an unusable output path fails before run() starts.
    os.makedirs(args.outdir, exist_ok=True)
    logs = run(cfg)
    with open(os.path.join(args.outdir, "logs.json"), "w", encoding="utf-8") as f:
        json.dump(logs, f, indent=2)
    print("Saved results to", args.outdir)
