import numpy as np
from scipy.spatial.transform import Rotation as R
from scipy.signal import savgol_filter
from scipy import interpolate as inter
from .commons import nm2
import pandas as pd


def retrieve(data_frame, with_obstacle=False, degree=-21, normalize=True, avr_spd=0.2):
    """Load a recorded hand trajectory into the calibrated frame and summarise it.

    Every recorded column is shifted so the first hand sample becomes the
    origin (plus a +0.05 offset on the recorded ``y`` column), the axes are
    remapped as ``(x, y, z) <- (z_rec, x_rec, y_rec + 0.05)``, and the result
    is rotated about ``z`` by ``degree`` degrees.

    Args:
        data_frame: table with ``time``, ``x_hand``, ``y_hand``, ``z_hand``
            columns (and ``x_obstacle``, ``y_obstacle``, ``z_obstacle`` when
            ``with_obstacle`` is true); rows are samples in recording order.
        with_obstacle: also map the obstacle columns into the same frame.
        degree: calibration rotation about ``z``, in degrees.
        normalize: rescale time so the whole path is traversed at ``avr_spd``.
        avr_spd: target average speed used when ``normalize`` is true.

    Returns:
        ``(t_scale, traj, traj_v_f, obst_pos, num_length, stats)`` where
        ``traj`` / ``traj_v_f`` are ``(3, N)`` position / smoothed-velocity
        arrays, ``obst_pos`` is a ``(3, N)`` array (or ``[]`` without an
        obstacle), ``num_length`` is ``N`` and ``stats`` is
        ``(total_length, total_time, max_spd, avg_spd_gt)``.

    The velocity smoothing uses a 15-sample Savitzky-Golay window, so very
    short recordings are expected to fail inside ``savgol_filter``.
    """
    rot = R.from_euler('z', degree, degrees=True)
    t = np.array(data_frame.time)
    t = t - t[0]

    # Origin = first hand sample, taken positionally (like ``t[0]`` above) so a
    # DataFrame whose index does not start at label 0 is handled correctly.
    origin = (np.asarray(data_frame.x_hand)[0],
              np.asarray(data_frame.y_hand)[0],
              np.asarray(data_frame.z_hand)[0])

    def _to_frame(x_col, y_col, z_col):
        """Shift recorded columns to the hand origin, remap axes, rotate; return (3, N)."""
        x0, y0, z0 = origin
        shifted = np.array([np.asarray(z_col) - z0,
                            np.asarray(x_col) - x0,
                            np.asarray(y_col) - y0 + 0.05])
        # Rotation.apply rotates every row of an (N, 3) array in one call.
        return np.ascontiguousarray(rot.apply(shifted.T).T)

    traj = _to_frame(data_frame.x_hand, data_frame.y_hand, data_frame.z_hand)
    num_length = len(t)

    # Path length: sum of the norms of consecutive position increments.
    total_length = sum((nm2(step) for step in np.diff(traj, axis=1).T), 0)

    if normalize:
        total_time = total_length / avr_spd
        t_scale = t * total_time / t[-1]
        sampling_time = np.mean(np.diff(t_scale))
    else:
        # NOTE(review): a fixed 0.01 sampling period is assumed, and N samples
        # span N - 1 periods, so total_time (and avg_spd_gt below) may be off
        # by one period — confirm intent before relying on these values.
        sampling_time = 0.01
        total_time = num_length * sampling_time
        t_scale = t

    # Finite-difference velocity (first sample padded with 0), then a cubic
    # Savitzky-Golay filter with a 15-sample window along the time axis.
    raw_vel = np.hstack([np.zeros((3, 1)), np.diff(traj, axis=1) / sampling_time])
    traj_v_f = savgol_filter(raw_vel, 15, 3, axis=1)
    speed = np.array([nm2(v) for v in traj_v_f.T])

    if with_obstacle:
        obst_pos = _to_frame(data_frame.x_obstacle, data_frame.y_obstacle, data_frame.z_obstacle)
    else:
        obst_pos = []

    max_spd = max(speed)
    avg_spd_gt = avr_spd if normalize else total_length / total_time

    return t_scale, traj, traj_v_f, obst_pos, num_length, (total_length, total_time, max_spd, avg_spd_gt)


def resample(time, signal, resampling_time):
    """Evaluate a first-order spline through ``(time, signal)`` at ``resampling_time``.

    ``signal`` may be 1-D or have the sample axis last; query points outside
    the range of ``time`` raise ``ValueError`` (interp1d's default bounds check).
    """
    return inter.interp1d(time, signal, kind='slinear')(resampling_time)


def label(env_fn, len_episode, des_pos, des_vel, obst_pos, *, kp=1500, kd=40):
    """Roll out a PD tracking controller in a fresh environment and record transitions.

    Args:
        env_fn: zero-argument factory returning an environment that exposes
            ``act_dim``, ``obs_dim``, ``reset(options=...)`` returning
            ``(obs, info)`` and ``step(action)`` returning a 5-tuple
            ``(obs, reward, terminated, truncated, info)``.
        len_episode: maximum number of steps to simulate.
        des_pos, des_vel: desired positions / velocities with one column per
            step (presumably ``(act_dim, len_episode)``; they are transposed
            and indexed by step below).
        obst_pos: obstacle positions in the same layout; only the first
            column is used, as the reset ``'obstacle'`` option.
        kp, kd: proportional and derivative gains of the tracking law
            (defaults reproduce the previously hard-coded 1500 / 40).

    Returns:
        ``(obs, next_obs, actions, rewards, dones)``, truncated to the number
        of steps actually taken. The rollout stops as soon as the environment
        reports termination or truncation.
    """
    des_pos_pt = np.transpose(des_pos)
    des_vel_pt = np.transpose(des_vel)
    obst_pos_pt = np.transpose(obst_pos)

    options = {'initial': des_pos_pt[0],
               'goal': des_pos_pt[-1],
               'obstacle': obst_pos_pt[0]}

    env = env_fn()
    act_dim = env.act_dim[0]  # loop-invariant; read once

    a = np.zeros([len_episode, act_dim])
    r = np.zeros([len_episode, 1])
    d = np.zeros([len_episode, 1])
    # One extra row so the observation following the last step is kept.
    o = np.zeros([len_episode + 1, env.obs_dim[0]])

    o[0], _ = env.reset(options=options)

    len_end = len_episode
    for t in range(len_episode):
        # NOTE(review): the slicing assumes the observation starts with
        # act_dim position entries followed by act_dim velocity entries — confirm against the env.
        err_p = o[t][:act_dim] - des_pos_pt[t]
        err_v = o[t][act_dim:2 * act_dim] - des_vel_pt[t]
        # PD law scaled by the last observation entry (its meaning is env-specific).
        a[t] = (-kp * err_p - kd * err_v) / o[t][-1]
        o[t + 1], r[t], ter, trn, _ = env.step(a[t])
        d[t] = ter | trn
        if d[t]:
            len_end = t + 1
            break

    return o[:len_end], o[1:len_end + 1], a[:len_end], r[:len_end], d[:len_end]


def fit_dmp(env_fn, file_name, avr_spd, sample_time, degree=-21):
    """Turn a recorded demonstration CSV into labelled environment transitions.

    The CSV is loaded and calibrated with ``retrieve`` (obstacle included,
    time normalized to ``avr_spd``). Then the position, velocity and obstacle
    signals are resampled on a uniform grid with step ``sample_time``. Finally
    the result is rolled out with ``label``.

    Args:
        env_fn: zero-argument environment factory, forwarded to ``label``.
        file_name: path of the demonstration CSV.
        avr_spd: target average speed used to rescale time.
        sample_time: step of the uniform resampling grid.
        degree: calibration rotation about ``z`` in degrees, forwarded to
            ``retrieve`` (default matches the previously hard-coded value).

    Returns:
        ``(obs, next_obs, actions, rewards, dones)`` as returned by ``label``.
    """
    t_scale, des_pos, des_vel, obst_pos, _, _ = retrieve(pd.read_csv(file_name), with_obstacle=True, degree=degree,
                                                         normalize=True, avr_spd=avr_spd)

    # Uniform grid on [0, t_scale[-1]); the end point is excluded by arange.
    rs_t_scale = np.arange(0, t_scale[-1], sample_time)

    # resample interpolates along the last axis, so each (3, N) signal is
    # resampled for all three axes in one call, giving (3, len(rs_t_scale)).
    rs_des_pos = resample(t_scale, des_pos, rs_t_scale)
    rs_des_vel = resample(t_scale, des_vel, rs_t_scale)
    rs_obst_pos = resample(t_scale, obst_pos, rs_t_scale)

    return label(env_fn, len(rs_t_scale), rs_des_pos, rs_des_vel, rs_obst_pos)