import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
import os
import json
import seaborn as sns
import glob
import gymnasium as gym
from hand_motion.commons import data_for_cylinder_along_z
from libs.trajectory import retrieve, fit_dmp
from libs.test import test_simple_goal

# Shared configuration read once at import time; every function below looks up its
# paths, file names and global constants in ``pr``.
# NOTE(review): 'param.json' is resolved relative to the current working directory,
# not this module's location — confirm scripts are always launched from the repo root.
# JSON is UTF-8 by specification, so the encoding is pinned instead of relying on the
# platform default (which is not UTF-8 on e.g. Windows).
with open('param.json', 'r', encoding='utf-8') as file:
    pr = json.load(file)


def demo_reward_histogram(root_path):
    """Show a histogram of the rewards stored in the demonstration buffer.

    The reward array is loaded with ``np.load`` from
    ``root_path + pr['paths']['buffer'] + pr['files']['reward']``. This is plain
    string concatenation, so ``root_path`` and the configured path segments must
    already carry their own separators. The figure is 3x2 inches, and its x axis is
    clipped to [-400, 0].

    Args:
        root_path: Prefix prepended to the configured buffer directory.
    """
    reward_file = root_path + pr['paths']['buffer'] + pr['files']['reward']
    demo_rew = np.load(reward_file)
    # NOTE(review): columns=['Reward'] assumes the stored array is 1-D (or has a single
    # column) — confirm against the code that writes the buffer.
    demo_rew_df = pd.DataFrame(demo_rew, columns=['Reward'])

    plt.figure(figsize=(3, 2))
    plt.grid(linestyle='-.')
    plt.xlim([-400, 0])
    # 8-bit RGB channels are scaled by their maximum value (255) into matplotlib's
    # [0, 1] range; the previous /256 was off by one and inconsistent with the other
    # colours in this file.
    bar_color = np.array([102, 163, 255]) / 255
    sns.histplot(data=demo_rew_df, x="Reward", bins=5000, color=bar_color)
    plt.subplots_adjust(top=0.88, bottom=0.215, left=0.165, right=0.9)
    plt.show()


def check_dmp_fitting(root_path, demo_name='recording_396.csv'):
    """Overlay a recorded demonstration with the trajectory produced by a fitted DMP.

    The demonstration CSV is read from
    ``root_path + pr['paths']['demo'] + '/' + demo_name``. Its desired positions
    (from ``retrieve``) are plotted in 3-D together with the first output of
    ``fit_dmp``, which runs against the ``edmp-v0`` gymnasium environment.

    Args:
        root_path: Prefix prepended to the configured demo directory.
        demo_name: File name of the demonstration to check; defaults to the
            recording that was previously hard-coded.

    Raises:
        FileNotFoundError: If the demonstration CSV does not exist.
    """
    from gymnasium.envs.registration import register, registry

    # A direct existence check replaces glob.glob(path)[0]. Without a wildcard, glob
    # only tested existence, failed with an opaque IndexError, and misread paths that
    # contain glob metacharacters such as '[' or '*'.
    demo_file = root_path + pr['paths']['demo'] + '/' + demo_name
    if not os.path.isfile(demo_file):
        raise FileNotFoundError(f'Demonstration file not found: {demo_file}')

    _, des_pos, _, _, _, _ = retrieve(pd.read_csv(demo_file), with_obstacle=True, degree=-21,
                                      normalize=True, avr_spd=pr["globals"]["normal_speed"])

    plt.figure(figsize=(3, 2.5))
    ax = plt.axes(projection='3d')

    # NOTE(review): des_pos is indexed as rows x, y, z — presumably shape (3, N); confirm.
    ax.plot3D(des_pos[0], des_pos[1], des_pos[2])

    # Re-registering an existing id makes gymnasium log an "Overriding environment"
    # warning, so register only on the first call in this process.
    if 'edmp-v0' not in registry:
        register(
            id='edmp-v0',
            entry_point='edmp.envs:EDMPEnv'
        )

    env_fn = lambda: gym.make('edmp-v0')
    dmp_out, _, _, _, _ = fit_dmp(env_fn, demo_file, pr["globals"]["normal_speed"],
                                  pr["globals"]["sampling_time"])

    # fit_dmp's first output is transposed so that its rows are the x, y, z series.
    dmp_xyz = np.transpose(dmp_out)
    ax.plot3D(dmp_xyz[0], dmp_xyz[1], dmp_xyz[2])

    ax.set_xlim3d(-0.1, 0.6)
    ax.set_ylim3d(0, 0.5)
    ax.set_zlim3d(0, 0.3)

    ax.set_xlabel("x")
    ax.set_ylabel("y")
    ax.set_zlabel("z")

    plt.show()


def draw_fixed_goal(root_path, exp_name):
    """Plot the test trajectories of every seed for one fixed-goal experiment in 3-D.

    For each index ``i`` over ``pr["globals"]["seeds"]``, the log
    ``<root_path><pr["paths"]["test_simple_goals"]>/<exp_name>/run_<i>/progress.txt``
    is read. If it does not exist yet, ``test_simple_goal(root_path, exp_name, i)``
    is called first (presumably to generate it). Runs whose log cannot be read are
    reported on stdout and skipped.

    A trajectory whose ``CollisionFlag`` column sums to more than zero is drawn in
    red; all other trajectories are drawn in dark blue. The initial position, the goal
    and the cylindrical obstacle come from the first row of the first readable run.

    Args:
        root_path: Prefix prepended to the configured test-results directory.
        exp_name: Experiment sub-directory name.

    Raises:
        RuntimeError: If none of the run logs could be read.
    """
    data_dir = root_path + pr["paths"]["test_simple_goals"] + '/' + exp_name

    datasets = []
    for run_idx in range(len(pr["globals"]["seeds"])):
        log_file = os.path.join(data_dir, 'run_' + str(run_idx), 'progress.txt')
        if not os.path.exists(log_file):
            test_simple_goal(root_path, exp_name, run_idx)

        try:
            exp_data = pd.read_table(log_file)
        except (OSError, ValueError):
            # OSError covers missing or unreadable files. pandas' EmptyDataError and
            # ParserError, and UnicodeDecodeError, are all ValueError subclasses.
            # The previous bare `except:` also swallowed KeyboardInterrupt/SystemExit.
            print('Could not read from %s' % log_file)
            continue
        datasets.append(exp_data)

    if not datasets:
        # Without this check, datasets[0] below would fail with an opaque IndexError.
        raise RuntimeError('No readable progress.txt found under %s' % data_dir)

    # Scene geometry is identical across runs, so read it from the first run's first row.
    first_run = datasets[0]
    init_pos = (first_run.InitPos_1[0], first_run.InitPos_2[0], first_run.InitPos_3[0])
    goal_pos = (first_run.GoalPos_1[0], first_run.GoalPos_2[0], first_run.GoalPos_3[0])
    obst_x, obst_y, obst_z = first_run.ObstPos_1[0], first_run.ObstPos_2[0], first_run.ObstPos_3[0]

    plt.figure(figsize=(3, 2.5))
    ax = plt.axes(projection='3d')

    collision_color = (255 / 255, 51 / 255, 51 / 255)
    free_color = (21 / 255, 21 / 255, 81 / 255)
    for run in datasets:
        collided = np.sum(run.CollisionFlag.to_numpy()) > 0
        ax.plot3D(run.Pos_1.to_numpy(), run.Pos_2.to_numpy(), run.Pos_3.to_numpy(),
                  color=collision_color if collided else free_color, linewidth=1)

    ax.plot(*init_pos, alpha=0.8, color=(0.2, 0.5, 1), marker="o", markersize=8)
    ax.plot(*goal_pos, alpha=0.8, color=(128/255, 0, 127/255), marker="*", markersize=10)

    Xc, Yc, Zc = data_for_cylinder_along_z(obst_x, obst_y, pr["obstacle"]["radius"], obst_z)
    ax.plot_surface(Xc, Yc, Zc, color=(255/255, 212/255, 128/255))

    plt.subplots_adjust(top=0.98, bottom=0.155, left=0.0, right=0.9)

    # 3d range wrt cam #
    ax.set_xlim3d(-0.1, 0.5)
    ax.set_ylim3d(0, 0.5)
    ax.set_zlim3d(0, 0.3)

    ax.set_xlabel("x")
    ax.set_ylabel("y")
    ax.set_zlabel("z")

    plt.show()
