from numpy.linalg import norm
import numpy as np
from enum import Enum
import torch
from device import device


class SimulationMode(Enum):
    """Simulation variants selectable by the surrounding code.

    Members carry fixed integer values; keep them stable in case they are
    persisted or compared by value elsewhere (not visible here).
    """

    # NOTE(review): the exact meaning of each variant is defined by callers.
    ORIGINAL_DMP = 0
    ONLY_ROBOT = 1
    DMP_WITH_ROBOT = 2


def nm(vec):
    """Return the infinity norm of *vec* (largest absolute entry for 1-D input)."""
    return np.linalg.norm(vec, np.inf)


def nm2(vec):
    """Return the 2-norm of *vec* (Euclidean length for 1-D input)."""
    return np.linalg.norm(vec, 2)


def radius_to_degree(vec):
    """Convert a sequence of angles from radians to degrees.

    Despite the function name, the input is in radians. Negative angles are
    shifted up by one full turn (2*pi) before conversion, so inputs in
    [-2*pi, 2*pi) map into [0, 360).

    Args:
        vec: 1-D sequence of scalar angles in radians; each element must be
            accepted by ``float()``.

    Returns:
        A new float64 NumPy array with the same length as ``vec``, in degrees.
    """
    angles = np.array([float(a) for a in vec], dtype=float)
    # Use exact pi: the truncated 6.28 / 3.14 constants skewed every result
    # by ~0.05% (e.g. pi mapped to ~180.09 degrees).
    wrapped = np.where(angles < 0, angles + 2 * np.pi, angles)
    return np.degrees(wrapped)


def push(obj):
    """Move a tensor onto the module-level ``device`` unless it is already on CUDA.

    Non-tensor arguments are returned unchanged after printing a notice.
    """
    if not torch.is_tensor(obj):
        print("Given parameter is not a tensor!")
        return obj
    return obj if obj.is_cuda else obj.to(device)


def pull(obj):
    """Bring a CUDA tensor back to host memory; CPU tensors are returned as-is.

    Non-tensor arguments are returned unchanged after printing a notice.
    """
    if not torch.is_tensor(obj):
        print("Given parameter is not a tensor!")
        return obj
    return obj.cpu() if obj.is_cuda else obj


def dist_to_cylinder(displacement, radius):
    """Distance from a point to a vertical cylinder.

    The geometry is implied by the branches below. The cylinder has
    horizontal radius ``radius``. Its top face sits at the reference point's
    height (displacement z == 0), and it is treated as unbounded below.

    Args:
        displacement: point minus the cylinder's top-centre reference point;
            indices 0 and 1 are horizontal, index 2 is vertical.
        radius: cylinder radius.

    Returns:
        Distance to the cylinder surface:
        - Beside the cylinder (z <= 0): distance to the side wall. This is
          negative when the point is horizontally inside the cylinder.
        - Directly above the top face: the height above it.
        - Otherwise: the distance to the top rim.
    """
    horizontal = np.hypot(displacement[0], displacement[1])
    height = displacement[2]
    if height <= 0:
        return horizontal - radius
    if horizontal < radius:
        return height
    # Above and outside the radius, the nearest surface point lies on the top
    # rim, not at the top-face centre; measuring to the centre overestimated
    # the distance by ~radius and made it jump at the top edge.
    return np.hypot(horizontal - radius, height)


def potential_field(dist_to_goal, field_range, dist_tolerance, max_value=100000):
    """Inverse-square repulsive potential used as a proximity penalty.

    ``calculate_reward`` uses it for obstacle and ground clearances, despite
    the parameter name ``dist_to_goal``.

    Args:
        dist_to_goal: the distance being penalised.
        field_range: distance at and beyond which the potential is 0.
        dist_tolerance: distance at or below which the potential saturates.
        max_value: saturation value (defaults to the previous hard-coded 100000).

    Returns:
        - ``max_value`` if ``dist_to_goal <= dist_tolerance``.
        - ``1/(d - tol)**2 - 1/(range - tol)**2``, capped at ``max_value``,
          if ``dist_tolerance < dist_to_goal < field_range``.
        - 0 otherwise.

        The middle band is empty when ``field_range <= dist_tolerance``.
    """
    if dist_to_goal <= dist_tolerance:
        return max_value
    if dist_tolerance < dist_to_goal < field_range:
        value = 1 / (dist_to_goal - dist_tolerance) ** 2 - 1 / (field_range - dist_tolerance) ** 2
        # Without the cap the term explodes just outside the tolerance and
        # exceeds the saturated value used inside it, so the penalty would be
        # larger for a near-miss than for a violation.
        return min(value, max_value)
    return 0


def calculate_reward(counter, pos, acc, goal_pos, obst_pos, radius, threshold, ep_len):
    """Compute the per-step reward and episode-status flags.

    Args:
        counter: current step count; compared against ``ep_len``.
        pos: current position; ``pos[2]`` is used as the height above ground.
        acc: acceleration vector, penalised via its squared 2-norm.
        goal_pos: goal position.
        obst_pos: reference point of the cylindrical obstacle (see
            ``dist_to_cylinder``).
        radius: obstacle radius.
        threshold: goal distance above which the goal is not yet reached.
        ep_len: step count at which the episode is truncated.

    Returns:
        ``(reward, terminated, truncated, collision_count)``.
        ``collision_count`` is 1 when the obstacle clearance (after its 0.035
        tolerance) or the ground height is non-positive, else 0.
    """
    goal_gap = nm2(pos - goal_pos)
    # Fixed 0.035 tolerance subtracted from the obstacle distance.
    obstacle_gap = dist_to_cylinder(pos - obst_pos, radius) - 0.035
    ground_gap = pos[2]

    collision_count = 0 if obstacle_gap > 0 and ground_gap > 0 else 1
    missed_goal = goal_gap > threshold

    if counter >= ep_len:
        # Out of steps: truncate, with a steep penalty on the distance still
        # left beyond the goal threshold.
        reward = -(goal_gap - threshold) ** 2 * 100000 if missed_goal else 0
        return reward, False, True, collision_count

    if not missed_goal:
        # Goal reached before the step limit: terminate with zero reward.
        return 0, True, False, collision_count

    accel_cost = nm2(acc) ** 2 * 0.001
    goal_cost = goal_gap ** 2 * 10
    # NOTE(review): field_range (0.045) < dist_tolerance (0.05), so
    # potential_field only returns its saturation value or 0 here and the
    # inverse-square band is never used -- confirm these values are intended.
    obstacle_cost = 0.001 * potential_field(obstacle_gap, field_range=0.045, dist_tolerance=0.05)
    # NOTE(review): with the +0.05 offset, range 0.05 and tolerance 0.01, this
    # term is non-zero only when ground_gap < 0 (already below ground) --
    # confirm intended.
    ground_cost = 0.001 * potential_field(ground_gap + 0.05, field_range=0.05, dist_tolerance=0.01)
    reward = -accel_cost - goal_cost - obstacle_cost - ground_cost
    return reward, False, False, collision_count


def clean_up_ac(ac, depth, mean=0, std=0.1):
    """Re-initialise the first ``depth`` even-indexed layers of actor and critic.

    For each k in ``range(depth)``, layer ``2*k`` of ``ac.pi.pi`` (actor) has
    its weight and bias zeroed. Layer ``2*k`` of ``ac.q.q`` (critic) has its
    weight and bias drawn from N(mean, std**2).

    NOTE(review): the even indices presumably skip activation modules
    interleaved between linear layers -- confirm against the model definition.
    """
    init = torch.nn.init
    for k in range(depth):
        layer_idx = 2 * k
        # Call order is kept so the random draws (critic weight, then critic
        # bias) consume the RNG in the same sequence.
        init.zeros_(ac.pi.pi[layer_idx].weight)
        init.normal_(ac.q.q[layer_idx].weight, mean=mean, std=std)
        init.zeros_(ac.pi.pi[layer_idx].bias)
        init.normal_(ac.q.q[layer_idx].bias, mean=mean, std=std)