import numpy as np
import glob
import pandas as pd
from .trajectory import retrieve, fit_dmp
import json
import os
from libs.sample_positions import sample_positions_trivial, sample_positions_valid

# Load the experiment configuration once, at import time. The path is relative,
# so it is resolved against the current working directory of the process that
# imports this module.
with open('param.json', 'r') as file:
    p = json.load(file)
# Per-episode step cap used by the rollout loops in this module:
# globals.max_time / globals.sampling_time, truncated by int().
max_ep_len = int(p["globals"]["max_time"]/p["globals"]["sampling_time"])


def human_demo(env_fn, root_path):
    """Build the demonstration replay buffer from the recorded CSV files.

    Every ``*.csv`` file in the demo directory is passed to ``fit_dmp``. The
    2-D arrays it returns are joined column-wise as ``[o, a, r, o2, d]`` and
    the per-file results are stacked row-wise into a single buffer. The buffer
    and the summed reward of each file are written with ``np.save`` into the
    buffer directory (created if missing).

    Args:
        env_fn: Forwarded unchanged to ``fit_dmp``.
        root_path: Prefix to which the ``paths`` entries of ``param.json`` are
            appended by plain string concatenation.

    Returns:
        None. The results are only written to disk.

    Raises:
        FileNotFoundError: If the demo directory contains no ``*.csv`` file.
    """
    buffer_path = root_path + p["paths"]["buffer"]
    demo_path = root_path + p["paths"]["demo"]

    # glob yields matches in arbitrary order; sorting keeps the row order of
    # the saved buffer/reward arrays reproducible between runs and machines.
    file_list = sorted(glob.glob(demo_path + '/*.csv'))
    if not file_list:
        raise FileNotFoundError("Demonstration data not found. Refer to ReadMe for details.")

    os.makedirs(buffer_path, exist_ok=True)

    normal_speed = p["globals"]["normal_speed"]
    sampling_time = p["globals"]["sampling_time"]

    # Collect the per-file blocks and concatenate once at the end. This avoids
    # re-copying the whole buffer for every file and does not rely on
    # np.any(buf) as an emptiness test (which is also False for all-zero data).
    chunks = []
    demo_rew = []
    num_list = len(file_list)
    for i, csv_file in enumerate(file_list):

        o, o2, a, r, d = fit_dmp(env_fn, csv_file, normal_speed, sampling_time)
        demo_rew = np.append(demo_rew, sum(r))
        chunks.append(np.concatenate([o, a, r, o2, d], axis=1))

        # To examine the efficacy of PID action generation
        # if i == 325:
        #    np.save(buffer_path + p["files"]["comparison"], o)

        print('Progress:' + str(int(i*100/num_list)) + '%')

    # file_list is guaranteed non-empty here, so chunks has at least one entry.
    buf = np.concatenate(chunks, axis=0)

    np.save(buffer_path + p["files"]["buffer"], buf)
    np.save(buffer_path + p["files"]["reward"], demo_rew)

    return None


def random_demo(env_fn, root_path):
    """Build a replay buffer of random-action rollouts seeded by the demos.

    For every ``*.csv`` file in the demo directory, ``retrieve`` is used to
    obtain the desired positions and obstacle positions. The environment is
    reset with the first/last desired position as ``initial``/``goal`` and the
    first obstacle position as ``obstacle``, then stepped ``max_ep_len`` times
    with actions drawn from ``env.action_space.sample()``. Each step is stored
    as one flattened row ``[o, a, r, o2, d]`` and the stacked rows are written
    with ``np.save`` into the buffer directory (created if missing).

    Args:
        env_fn: Zero-argument callable returning the environment.
        root_path: Prefix to which the ``paths`` entries of ``param.json`` are
            appended by plain string concatenation.

    Returns:
        None. The buffer is only written to disk.

    Raises:
        FileNotFoundError: If the demo directory contains no ``*.csv`` file.
    """
    buffer_path = root_path + p["paths"]["buffer"]
    demo_path = root_path + p["paths"]["demo"]

    # glob yields matches in arbitrary order; sorting keeps the row order of
    # the saved buffer reproducible between runs and machines.
    file_list = sorted(glob.glob(demo_path + '/*.csv'))
    if not file_list:
        raise FileNotFoundError("Demonstration data not found. Refer to ReadMe for details.")

    os.makedirs(buffer_path, exist_ok=True)

    env = env_fn()
    # Rows are gathered in a list and stacked once at the end: appending is
    # O(1) per step, and it does not rely on np.any(buf) as an emptiness test
    # (which is also False when the stored data happens to be all zeros).
    rows = []
    num_list = len(file_list)
    normal_speed = p["globals"]["normal_speed"]

    for i, csv_file in enumerate(file_list):
        _, des_pos, _, obst_pos, _, _ = retrieve(pd.read_csv(csv_file), with_obstacle=True, degree=-21,
                                                 normalize=True, avr_spd=normal_speed)

        options = {'initial': des_pos.T[0],
                   'goal': des_pos.T[-1],
                   'obstacle': obst_pos.T[0]}

        o, _ = env.reset(options=options)

        # NOTE(review): the rollout always runs for max_ep_len steps and keeps
        # stepping even after the env reports terminated/truncated. That
        # matches the previous behaviour; confirm it is intended for this env.
        for _ in range(max_ep_len):
            a = env.action_space.sample()
            o2, r, ter, trn, _ = env.step(a)
            d = ter | trn

            # Store the transition BEFORE advancing the observation, so the
            # row holds (o, a, r, o2, d) rather than (o2, a, r, o2, d).
            rows.append(np.concatenate((o, a, r, o2, d), axis=None))
            o = o2

        print('Progress:' + str(int(i * 100 / num_list)) + '%')

    buf = np.array(rows)
    np.save(buffer_path + p["files"]["buffer_no_bc"], buf)

    return None


def original(env_fn, root_path, total_steps):
    """Collect ``total_steps`` random-action transitions into a replay buffer.

    The environment is reset with options drawn from
    ``sample_positions_trivial`` and stepped with actions from
    ``env.action_space.sample()``. Each step is stored as one flattened row
    ``[o, a, r, o2, d]``. The environment is reset again whenever the stored
    done flag is set or the episode reaches ``max_ep_len`` steps. The stacked
    rows are written with ``np.save`` into the buffer directory (created if
    missing).

    Args:
        env_fn: Zero-argument callable returning the environment.
        root_path: Prefix to which the ``paths`` entries of ``param.json`` are
            appended by plain string concatenation.
        total_steps: Number of environment steps (buffer rows) to collect.

    Returns:
        None. The buffer is only written to disk.
    """
    env = env_fn()

    # Create the output directory up front, like the sibling buffer builders
    # do, so np.save cannot fail on a missing directory after the rollout.
    buffer_path = root_path + p["paths"]["buffer"]
    os.makedirs(buffer_path, exist_ok=True)

    # Rows are gathered in a list and stacked once at the end: appending is
    # O(1) per step, and it does not rely on np.any(buf) as an emptiness test
    # (which is also False when the stored data happens to be all zeros).
    rows = []

    o, _ = env.reset(options=sample_positions_trivial(p["positions"], p["sample_ranges_trivial"])[0])
    ep_len = 0

    for t in range(total_steps):

        a = env.action_space.sample()
        o2, r, ter, trn, _ = env.step(a)
        ep_len += 1

        # Reaching the step cap ends the episode but is stored as done=False.
        hit_step_cap = ep_len == max_ep_len
        d = False if hit_step_cap else (ter | trn)

        rows.append(np.concatenate((o, a, r, o2, d), axis=None))
        print('Progress:' + str(int(t * 100 / total_steps)) + '%')

        o = o2

        if d or hit_step_cap:
            o, _ = env.reset(options=sample_positions_trivial(p["positions"], p["sample_ranges_trivial"])[0])
            ep_len = 0

    buf = np.array(rows)
    np.save(buffer_path + p["files"]["buffer_original"], buf)

    return None