import gymnasium as gym
import json
import torch
from libs.spinup.logx import EpochLogger
from device import device
from libs.sample_positions import sample_positions_valid


# Experiment name -> gymnasium environment id.
# All experiments currently share a single environment id.
env_name = dict.fromkeys(
    ('IBC-DMP', 'DEMO-DMP', 'DDPG-DMP', 'EBC-DMP'),
    'edmp-v0',
)

# Environment id -> "module:Class" entry point handed to gymnasium's `register`.
env_entry = {'edmp-v0': 'edmp.envs:EDMPEnv'}

# Global experiment parameters, read once at import time from `param.json`
# (a relative path, so it is resolved against the current working directory).
# JSON text is UTF-8 by specification; decode it explicitly instead of relying
# on the platform's locale encoding.
with open('param.json', 'r', encoding='utf-8') as file:
    pr = json.load(file)


def test_simple_goal(root_path, exp_name, num=0):
    """Roll out a trained policy once on the fixed positions from ``param.json``.

    The environment mapped to ``exp_name`` is reset with ``pr["positions"]``.
    It is then driven by the policy loaded from
    ``<root_path><pr["paths"]["training"]>/<exp_name>/run_<num>/ddpg_model.pt``.
    The rollout stops when the episode terminates, is truncated, or reaches
    ``int(max_time / sampling_time)`` steps.

    One tabular row is dumped per step to
    ``<root_path><pr["paths"]["test_simple_goals"]>/<exp_name>/run_<num>``.
    Each row holds:

    - time
    - initial, goal and obstacle positions
    - ``obs[0:3]`` (as position) and ``obs[3:6]`` (as velocity)
    - the step's ``collision_count``
    - the running return

    Args:
        root_path: Prefix joined to the configured paths by plain string
            concatenation (no separator is inserted).
        exp_name: Experiment key; must be a key of ``env_name``.
        num: Run index selecting the ``run_<num>`` directory.

    Raises:
        KeyError: If ``exp_name`` is not a key of ``env_name``.
    """
    from gymnasium.envs.registration import register
    # Register only on first use: re-registering an existing id makes
    # gymnasium warn about overriding the registry entry on every call.
    if env_name[exp_name] not in gym.registry:
        register(
            id=env_name[exp_name],
            entry_point=env_entry[env_name[exp_name]]
        )

    env_fn = lambda: gym.make(env_name[exp_name])
    env = env_fn()
    try:
        max_ep_len = int(pr["globals"]["max_time"] / pr["globals"]["sampling_time"])
        options = pr["positions"]

        logger = EpochLogger(output_dir=root_path + pr["paths"]["test_simple_goals"] + '/' + exp_name + '/run_' + str(num))
        # locals() at this point is what gets saved as the run config.
        # NOTE(review): 'dmp_original' is not a key of env_name, so with the
        # current table this condition is always true.
        if exp_name != 'dmp_original':
            logger.save_config(locals())

        # Load the saved policy object (it must expose ``act``).
        # NOTE(review): torch.load unpickles arbitrary objects, so only use
        # trusted checkpoints. On torch>=2.6 (weights_only=True by default),
        # loading a whole pickled module may need weights_only=False.
        ac = torch.load(root_path + pr["paths"]["training"] + '/' + exp_name + '/run_' + str(num) + '/ddpg_model.pt', map_location=device)

        obs, _ = env.reset(options=options)
        ter, tru, ep_ret, ep_len = False, False, 0, 0
        # `>=` (not `==`) so the step cap is a hard upper bound.
        while not (ter or tru or ep_len >= max_ep_len):
            action = ac.act(torch.tensor(obs, dtype=torch.float32))

            obs, r, ter, tru, info = env.step(action)
            ep_ret += r
            ep_len += 1

            logger.log_tabular('Time', ep_len * pr["globals"]["sampling_time"])
            # Per-axis logging for the three spatial axes: obs[i] is logged
            # as position and obs[i + 3] as velocity.
            for i in range(3):
                axis = str(i + 1)
                logger.log_tabular('InitPos_' + axis, options['initial'][i])
                logger.log_tabular('GoalPos_' + axis, options['goal'][i])
                logger.log_tabular('ObstPos_' + axis, options['obstacle'][i])
                logger.log_tabular('Pos_' + axis, obs[i])
                logger.log_tabular('Vel_' + axis, obs[i + 3])
            logger.log_tabular('CollisionFlag', info['collision_count'])
            logger.log_tabular('TestEpRet', ep_ret)
            logger.dump_tabular()
    finally:
        # Release the environment's resources even if the rollout fails.
        env.close()


def test_random_goals(root_path, num_goal, exp_name, num=0):
    """Evaluate a trained policy on ``num_goal`` randomly sampled scenarios.

    For each scenario, initial/goal/obstacle positions are drawn with
    ``sample_positions_valid(pr["positions"], pr["sample_ranges_valid"])[0]``.
    The environment is reset with them and driven by the policy loaded from
    ``<root_path><pr["paths"]["training"]>/<exp_name>/run_<num>/ddpg_model.pt``.
    The rollout stops when the episode terminates, is truncated, or reaches
    ``int(max_time / sampling_time)`` steps.

    One tabular row is dumped per scenario to
    ``<root_path><pr["paths"]["test_random_goals"]>/<exp_name>/run_<num>``.
    Each row holds:

    - the sampled positions
    - the sum of the per-step ``collision_count`` values
    - the episode return

    Args:
        root_path: Prefix joined to the configured paths by plain string
            concatenation (no separator is inserted).
        num_goal: Number of random scenarios (episodes) to run.
        exp_name: Experiment key; must be a key of ``env_name``.
        num: Run index selecting the ``run_<num>`` directory.

    Raises:
        KeyError: If ``exp_name`` is not a key of ``env_name``.
    """
    from gymnasium.envs.registration import register
    # Register only on first use: re-registering an existing id makes
    # gymnasium warn about overriding the registry entry on every call.
    if env_name[exp_name] not in gym.registry:
        register(
            id=env_name[exp_name],
            entry_point=env_entry[env_name[exp_name]]
        )

    env_fn = lambda: gym.make(env_name[exp_name])
    env = env_fn()
    try:
        max_ep_len = int(pr["globals"]["max_time"] / pr["globals"]["sampling_time"])

        logger = EpochLogger(output_dir=root_path + pr["paths"]["test_random_goals"] + '/' + exp_name + '/run_' + str(num))
        # locals() at this point is what gets saved as the run config.
        # NOTE(review): 'dmp_original' is not a key of env_name, so with the
        # current table this condition is always true.
        if exp_name != 'dmp_original':
            logger.save_config(locals())

        # Load the saved policy object (it must expose ``act``).
        # NOTE(review): torch.load unpickles arbitrary objects, so only use
        # trusted checkpoints. On torch>=2.6 (weights_only=True by default),
        # loading a whole pickled module may need weights_only=False.
        ac = torch.load(root_path + pr["paths"]["training"] + '/' + exp_name + '/run_' + str(num) + '/ddpg_model.pt', map_location=device)

        for _ in range(num_goal):
            options = sample_positions_valid(pr["positions"], pr["sample_ranges_valid"])[0]
            obs, _reset_info = env.reset(options=options)
            ter, tru, ep_ret, ep_len, collision_counts = False, False, 0, 0, 0
            # `>=` (not `==`) so the step cap is a hard upper bound.
            while not (ter or tru or ep_len >= max_ep_len):
                action = ac.act(torch.tensor(obs, dtype=torch.float32))

                obs, r, ter, tru, info = env.step(action)
                ep_ret += r
                ep_len += 1
                # Summed over steps. NOTE(review): if the env already reports
                # a cumulative count, this over-counts — TODO confirm.
                collision_counts += info['collision_count']

            for i in range(3):
                axis = str(i + 1)
                logger.log_tabular('InitPos_' + axis, options['initial'][i])
                logger.log_tabular('GoalPos_' + axis, options['goal'][i])
                logger.log_tabular('ObstPos_' + axis, options['obstacle'][i])
            logger.log_tabular('CollisionFlag', collision_counts)
            logger.log_tabular('TestEpRet', ep_ret)
            logger.dump_tabular()
    finally:
        # Release the environment's resources even if an episode fails.
        env.close()
