import math
import matplotlib.pyplot as plt
from libs.spinup.plot import plot_data, get_datasets
import json
import pandas as pd

# Experiment configuration (data paths under "paths", run settings under
# "globals") read once, at import time, from ``param.json`` in the current
# working directory.
with open('param.json') as param_file:
    pr = json.load(param_file)


# Line colour for each training variant; the keys are the accepted values of
# ``learning_curves``'s ``option_dir`` argument.
_COLOR_DICT = {
    'DEMO-DMP': "#990000",
    'DDPG-DMP': "#660066",
    'IBC-DMP': "#1a4c1a",
    'IBC-DMP-1': "#1a4c1a",
    'EBC-DMP': "#003399",
    'EBC-DMP-1': "#003399",
    'DDPG': "#008000",
    'DMP': "#cca300",
}

# y-axis limits of the L-ARPE plot for each variant (same keys as _COLOR_DICT).
_YLIM_DICT = {
    'DEMO-DMP': [-10, -4],
    'DDPG-DMP': [-10, -4],
    'IBC-DMP': [-10, -4],
    'IBC-DMP-1': [-10, -4],
    'EBC-DMP': [-10, -4],
    'EBC-DMP-1': [-10, -4],
    'DDPG': [-12, -8],
    'DMP': [-10, -4],
}


def _to_l_arpe(value):
    """Map one ``AverageEpRet`` value ``v`` to L-ARPE, i.e. ``-log(1 - v)``."""
    return -math.log(1 - value)


def learning_curves(root_path, option_dir):
    """Show the smoothed L-ARPE learning curve of one training variant.

    Loads every run found under
    ``root_path + pr["paths"]["training"] + '/' + option_dir`` with
    ``get_datasets``. It rescales each run's ``AverageEpRet`` column to
    L-ARPE = -log(1 - AverageEpRet) and plots it against ``Epoch`` with a
    smoothing window of 10 in a new 4x3 inch figure. Finally it calls
    ``plt.show()``.

    Args:
        root_path: Prefix joined by plain string concatenation (no separator
            is inserted) with ``pr["paths"]["training"]``.
        option_dir: Variant sub-directory name; must be a key of
            ``_COLOR_DICT`` / ``_YLIM_DICT`` (e.g. ``'DDPG-DMP'``).

    Raises:
        KeyError: If ``option_dir`` has no colour / y-limit entry. This is
            checked before any data is loaded.
        ValueError: If some ``AverageEpRet`` value is >= 1, making
            ``log(1 - v)`` undefined.
    """
    # Fail fast with a helpful message instead of a bare KeyError after the
    # (possibly slow) dataset loading.
    if option_dir not in _COLOR_DICT or option_dir not in _YLIM_DICT:
        raise KeyError(
            f"unknown option_dir {option_dir!r}; expected one of "
            f"{sorted(_COLOR_DICT)}"
        )

    max_episodes = pr["globals"]["max_episodes"]
    data_dir = root_path + pr["paths"]["training"]

    data = get_datasets(data_dir + '/' + option_dir, condition='mean L-ARPE')

    # Rescale the whole column with one assignment per run. The previous
    # per-element ``df.AverageEpRet[i] = ...`` was chained assignment, which
    # is not guaranteed to write back (and never does under pandas
    # Copy-on-Write). Looping to ``max_episodes`` also broke on runs
    # shorter than that (KeyError). On longer runs it left the extra rows
    # in the raw scale.
    for frame in data:
        frame['AverageEpRet'] = frame['AverageEpRet'].map(_to_l_arpe)

    plt.figure(figsize=(4, 3))
    # NOTE(review): ``color_list`` is a keyword of the project's modified
    # plot_data; here it receives a single hex string — confirm that is the
    # expected type.
    plot_data(data, xaxis='Epoch', value="AverageEpRet", smooth=10,
              color_list=_COLOR_DICT[option_dir])
    plt.ylim(_YLIM_DICT[option_dir])
    plt.xlim([0, max_episodes])
    # Raw string: ``\,`` (LaTeX thin space) is not a valid Python escape.
    plt.xlabel(r'#$\,$Interactions', fontsize=13)
    plt.ylabel('L-ARPE', fontsize=13)
    # With scilimits=(-100, 100) scientific notation only applies outside
    # 1e-100..1e100, i.e. the y tick labels effectively stay in plain form.
    plt.ticklabel_format(axis="y", style="sci", scilimits=(-100, 100))
    plt.tick_params(axis='both', which='major', labelsize=13)
    plt.subplots_adjust(top=0.938, bottom=0.221, left=0.219, right=0.949)
    plt.show()