import gymnasium as gym
import numpy as np
from libs.commons import nm, calculate_reward
import copy
import json


# Load the shared parameter set once, at import time. The path is relative to
# the current working directory, so the process must be started from the
# directory that contains `param.json`. This file reads the "dmp",
# "positions", "globals" and "obstacle" sections of the resulting dict.
# The encoding is pinned to UTF-8 so that decoding does not depend on the
# platform's locale default.
with open('param.json', 'r', encoding='utf-8') as file:
    pr = json.load(file)


class EDMPEnv(gym.Env):
    """Gymnasium environment that steps a second-order point system with a
    goal attractor, a decaying phase variable and an obstacle.

    All numeric parameters come from the module-level ``pr`` dict:
    ``pr["dmp"]`` (``dof``, ``tau``, ``alpha``, ``beta``, ``omega``),
    ``pr["globals"]`` (``sampling_time``, ``max_time``), ``pr["obstacle"]``
    (``radius``, ``reach_tolerance``) and ``pr["positions"]`` (default
    ``initial`` / ``goal`` / ``obstacle`` positions).

    Observation: the flat concatenation ``[x, xd, x - obst_position, s]``.
    Action: a vector of shape ``(dof,)`` bounded to ``[-5, 5]``; it is used as
    an additive forcing term on the acceleration, scaled by the phase ``s``
    and by ``nm(goal_position - init_position)``.
    """

    metadata = {'render_modes': ['human']}

    def __init__(self, seed=0):
        """Build the environment and its observation/action spaces.

        Args:
            seed: Seed forwarded to both ``Box`` spaces (used when sampling
                from them). The environment's own RNG is seeded through
                ``reset(seed=...)``.
        """
        self.dof = pr["dmp"]["dof"]
        self.act_dim = (self.dof, )
        self.init_position = None
        self.goal_position = None
        self.obst_position = None

        # Configure DMP simulation
        self.x = None        # position
        self.xd = None       # velocity
        self.xdd = None      # acceleration
        self.s = None        # phase variable, starts at 1 and decays
        self.counter = None  # number of steps taken since the last reset

        # Run one reset so the observation shape can be measured from a real
        # observation instead of being hard-coded.
        obs, _ = self.reset(options=pr["positions"])
        self.obs_dim = np.shape(obs)
        self.observation_space = gym.spaces.box.Box(
            -np.inf, np.inf, self.obs_dim, dtype=np.float32, seed=seed)
        self.action_space = gym.spaces.box.Box(
            -5, 5, self.act_dim, dtype=np.float32, seed=seed)

    def _get_obs(self):
        """Return the flat observation ``[x, xd, x - obst_position, s]``."""
        # axis=None flattens every part before joining, matching the
        # behaviour of chained np.append calls.
        return np.concatenate(
            (self.x, self.xd, self.x - self.obst_position, self.s), axis=None)

    def step(self, action):
        """Advance the simulation by one sampling period.

        Args:
            action: Array-like of shape ``(dof,)``; converted to float32.

        Returns:
            ``(obs, reward, terminated, truncated, info)`` where ``reward``,
            ``terminated``, ``truncated`` and ``info['collision_count']`` are
            produced by ``calculate_reward`` from ``libs.commons``.
        """
        dmp = pr["dmp"]
        dt = pr["globals"]["sampling_time"]
        # Accept lists/tuples as well as ndarrays.
        action = np.asarray(action, dtype=np.float32)

        # Disabled alternative phase input, kept for reference:
        #si = 1 - 1 / (1 + math.exp(-(0.2 * (self.counter - 50))))
        #si *= 2.5

        # Acceleration: goal attractor plus the phase-gated action term.
        # NOTE(review): `nm` comes from libs.commons; presumably a norm of the
        # start-to-goal vector -- confirm.
        self.xdd = dmp["tau"] * dmp["alpha"] * (
            dmp["beta"] * (self.goal_position - self.x) - self.xd
        ) + self.s * nm(self.goal_position - self.init_position) * action
        # Explicit Euler updates. The order matters: x uses the velocity from
        # before this step, and xd uses the acceleration computed above.
        self.s += - dmp["tau"] * dt * dmp["omega"] * self.s  #+ tau * Ts * omega * si
        self.x += dt * self.xd
        self.xd += dt * self.xdd
        self.counter += 1

        reward, terminated, truncated, collision_count = calculate_reward(
            self.counter,
            self.x,
            self.xdd,
            self.goal_position,
            self.obst_position,
            radius=pr["obstacle"]["radius"],
            threshold=pr["obstacle"]["reach_tolerance"],
            ep_len=int(pr["globals"]["max_time"] / dt))

        return self._get_obs(), reward, terminated, truncated, {'collision_count': collision_count}

    def reset(self, seed=None, options=None):
        """Reset the simulation state and return the first observation.

        Args:
            seed: Optional seed for the environment RNG; forwarded to
                ``gym.Env.reset``.
            options: Optional mapping with any of the keys ``'initial'``,
                ``'goal'`` and ``'obstacle'``. Missing keys (or
                ``options=None``) fall back to ``pr["positions"]``.

        Returns:
            ``(obs, info)`` with an empty ``info`` dict.
        """
        # Seeds self.np_random as required by the Gymnasium API.
        super().reset(seed=seed)

        # Start from the configured defaults so that a bare `env.reset()`
        # works, then apply whatever the caller overrides.
        positions = {**pr["positions"], **(options or {})}

        self.init_position = np.array(positions['initial']).astype(np.float32)
        self.goal_position = np.array(positions['goal']).astype(np.float32)
        self.obst_position = np.array(positions['obstacle']).astype(np.float32)

        # Copy, because x is updated in place in step() and init_position
        # must stay untouched.
        self.x = copy.copy(self.init_position)
        self.xd = np.zeros([self.dof, ], dtype=np.float32)
        self.xdd = np.zeros([self.dof, ], dtype=np.float32)
        self.s = np.array([1], dtype=np.float32)
        self.counter = 0
        return self._get_obs(), {}

    def render(self, mode='human'):
        """No-op; this environment has no visualisation."""
        return

    def close(self):
        """No-op; there are no resources to release."""
        return
