import os
import numpy as np
import copy
import matplotlib.pyplot as plt
import fiqus.utils.Utils as utl

'''
This simple script plots the results of the simple thermal model for different mesh sizes of the insulation.
It requires that the simple thermal model has been run for different mesh sizes of the insulation, 
both for TSA and surface mesh insulation. This can be done by running the run_simple_thermal_model.py script.
The surface mesh insulation is referred to by the subscript _ref in this script.
'''

# Input parameters/settings
# Name of the magnet; this should correspond to the name of the magnet used in run_simple_thermal_model.py.
magnet = "simple_thermal_model"
# Mesh sizes of the insulation in meters, ordered from finest to coarsest. These should correspond
# to values that were already run: str(size) is used further down as the name of the reference
# result folder and as part of the TSA result folder name (f'Mesh_{size}').
target_size_insulation = [1e-6, 10**(-5.66), 10**(-5.33), 1e-5, 10**(-4.66), 10**(-4.33), 1e-4]

# Indices of the runs whose temperature over time is plotted.
# plot_T_over_time_idx_ref indexes target_size_insulation directly
# (0 is the first mesh size, 1 is the second mesh size, etc.).
# plot_T_over_time_idx_tsa indexes target_size_insulation[1:], because the TSA loop skips the
# first (finest) mesh size; TSA index k therefore corresponds to reference index k + 1.
# With the values below, both cases plot the last entry of target_size_insulation (1e-4 m).
plot_T_over_time_idx_ref = 6
plot_T_over_time_idx_tsa = 5

# Preparing some paths: everything is read from / written to <cwd>/output_thesis/<magnet>.
output_path = os.path.join(os.getcwd(), 'output_thesis', magnet )
# Root folders of the results, keyed by case name ('Ref': surface mesh insulation, 'TSA': TSA).
paths = {'Ref': os.path.join(output_path, magnet + '_ref'),
         'TSA': os.path.join(output_path, 'Geometry_TSA')}


## some helper functions
def parse_values(file_name, case, mesh_folder, values_dict):
    '''Parse values from FiQuS output files.

    One text file is read per entry of values_dict. The file for the k-th entry (in dict order)
    is named f"{file_name}_{k}.txt". Every non-blank line is split on whitespace; the second
    column is converted to float, passed through utl.GeometricFunctions.sig_dig and appended to
    the list stored under that entry.

    :param file_name: base name of the output files, e.g. 'T_avg'
    :param case: key into the module-level paths dict ('Ref' or 'TSA')
    :param mesh_folder: folder (relative to paths[case]) that holds the results of one run
    :param values_dict: dict mapping an element key to a list; the lists are filled in place
    :return: list with the first column (time steps) of the file that belongs to key 1;
             empty if values_dict has no key 1
    '''
    # Only the TSA 'T_avg' output lives in an extra sub-folder named after the quantity itself.
    if file_name == 'T_avg' and case == 'TSA':
        folder = os.path.join(paths[case], mesh_folder, file_name)
    else:
        folder = os.path.join(paths[case], mesh_folder)

    t_steps = []
    # enumerate gives the position of the key in the dict, which is the index used in the file name
    for file_idx, (el, val) in enumerate(values_dict.items()):
        with open(os.path.join(folder, f"{file_name}_{file_idx}.txt")) as f:
            for line in f:
                row_entries = line.split()
                if not row_entries:
                    # tolerate blank lines instead of failing with an IndexError
                    continue
                # the time steps are taken once, from the file that belongs to key 1
                if el == 1:
                    t_steps.append(float(row_entries[0]))
                val.append(utl.GeometricFunctions.sig_dig(float(row_entries[1])))
    return t_steps

# Plot settings shared by all figures; dicts are keyed by case name.
# The 'QH' entries and QH_line_styles are defined here but not referenced by the visible part of this script.
marker_size = 12
markers = ['o', 's', 'D', '^', '*', 'x']
QH_line_styles = ['dashed', 'dotted', 'dashdot', 'solid']
line_styles = dict(Ref='solid', TSA='dashed', QH='dashed')
marker_frequency = dict(Ref=0.2, TSA=0.23, QH=0)
colors = {case: plt.cm.viridis(level) for case, level in (('Ref', 0.25), ('TSA', 0.75))}

# Figure 1: relative error over element size, both axes logarithmic
fig_err, ax_err = plt.subplots()
ax_err.set(xlabel='Element size [m]', ylabel='Relative error [-]')
ax_err.set_xscale("log", base=10)
ax_err.set_yscale("log", base=10)
ax_err.tick_params(axis='x')
ax_err.tick_params(axis='y')

# Figure 2: average temperature over time
fig_T, ax_T = plt.subplots()
ax_T.set(xlabel='Time [s]', ylabel='Average HT temperature [K]')

# Figure 3: relative error over time, logarithmic error axis
fig_err_time, ax_err_time = plt.subplots()
ax_err_time.set(xlabel='Time [s]', ylabel='Relative error [-]')
ax_err_time.set_yscale("log", base=10)

def _max_rel_error_over_time(T_case, T_finest):
    '''Return, for every time step, the largest relative error of T_case w.r.t. T_finest.

    For each key (HT) of T_finest both temperature series are truncated to their common length
    and |T_case - T_finest| / T_finest is evaluated. The element-wise maximum over all keys is
    returned as a 1-D array.

    :param T_case: dict mapping HT key to the list of temperatures of the run to be assessed
    :param T_finest: dict mapping HT key to the list of temperatures of the finest reference run
    :return: 1-D numpy array with one relative error per time step
    '''
    errors_per_ht = []
    for ht, T_fine in T_finest.items():
        n_common = min(len(T_case[ht]), len(T_fine))
        T_fine_arr = np.array(T_fine[:n_common])
        errors_per_ht.append(np.abs(np.array(T_case[ht][:n_common]) - T_fine_arr) / T_fine_arr)
    return np.max(errors_per_ht, axis=0)


# The first mesh size only serves as reference, so at least one more size is needed to compare against it.
if len(target_size_insulation) < 2:
    raise ValueError("target_size_insulation needs at least two mesh sizes: the finest reference and one size to compare against it")

# run reference with surface insulation mesh
output_path_ref = os.path.join(output_path, magnet + '_ref')
# results are collected in lists and stacked once after the loop
_ref_errors_time, _ref_errors_mesh, _ref_sizes = [], [], []
for i, size in enumerate(target_size_insulation):

    T_ref = {1: [], 2: [], 3: [], 4: []}
    time_steps = parse_values(file_name='T_avg', case='Ref', mesh_folder=str(size), values_dict=T_ref)

    if i == 0:
        # this is the finest surface insulation mesh --> our reference
        T_ref_finest = copy.deepcopy(T_ref)
    else:
        # otherwise calculate relative error; the scalar error of a run is the maximum over HTs and time
        rel_error_time_ht = _max_rel_error_over_time(T_ref, T_ref_finest)
        rel_error_ref = np.max(rel_error_time_ht)
        ax_err.plot(size, rel_error_ref, color=colors['Ref'], linestyle='None', marker=markers[1], markersize=marker_size, label='Ref' if i == 1 else None)

        _ref_errors_time.append(rel_error_time_ht)
        _ref_errors_mesh.append(rel_error_ref)
        _ref_sizes.append(size)

    if i == plot_T_over_time_idx_ref:
        for ht, T in T_ref.items():
            ax_T.plot(time_steps, T, label=f'Ref HT {ht}', color=colors['Ref'], linestyle=line_styles['Ref'], markersize=marker_size, markevery=marker_frequency['Ref'], marker=markers[ht-1])
        # one row per HT, one column per time step
        T_time_ref = np.vstack(list(T_ref.values()))

# one row per compared mesh size (all sizes except the finest one)
rel_error_time_ref = np.vstack(_ref_errors_time)
rel_error_mesh_ref = np.array(_ref_errors_mesh).reshape(-1, 1)
rel_error_mesh_sizes = np.array(_ref_sizes).reshape(-1, 1)

# TSA
_tsa_errors_time, _tsa_errors_mesh = [], []
# the first (finest) mesh size is skipped for TSA, so i is shifted by one w.r.t. the reference loop
for i, size in enumerate(target_size_insulation[1:]):

    T_tsa = {1: [], 2: [], 3: [], 4: []}
    time_steps = parse_values(file_name='T_avg', case='TSA', mesh_folder=os.path.join(f'Mesh_{size}', 'Solution_1'), values_dict=T_tsa)

    # the TSA error is measured against the finest surface-mesh reference as well
    rel_error_time_ht = _max_rel_error_over_time(T_tsa, T_ref_finest)
    rel_error_tsa = np.max(rel_error_time_ht)
    # only the last point carries the legend label
    ax_err.plot(size, rel_error_tsa, color=colors['TSA'], linestyle='None', marker=markers[0], markersize=marker_size, label='TSA' if i + 2 == len(target_size_insulation) else None)

    _tsa_errors_time.append(rel_error_time_ht)
    _tsa_errors_mesh.append(rel_error_tsa)

    if i == plot_T_over_time_idx_tsa:
        for ht, T in T_tsa.items():
            ax_T.plot(time_steps, T, label=f'TSA HT {ht}', color=colors['TSA'], linestyle=line_styles['TSA'], markersize=marker_size, markevery=marker_frequency['TSA'], marker=markers[ht-1])
        # one row per HT, one column per time step
        T_time_tsa = np.vstack(list(T_tsa.values()))

rel_error_time_tsa = np.vstack(_tsa_errors_time)
rel_error_mesh_tsa = np.array(_tsa_errors_mesh).reshape(-1, 1)

# plotting and saving data to .csv

# error over mesh
fig_err.legend(loc='upper left', bbox_to_anchor=(0.1,0.9))
fig_err.savefig(os.path.join(output_path, 'error_over_mesh.pdf'), format='pdf', bbox_inches='tight')
fig_err.savefig(os.path.join(output_path, 'error_over_mesh.jpg'), format='jpg', bbox_inches='tight')

# NOTE: time_steps holds the time steps of the run that was parsed last
np.savetxt(os.path.join(output_path, 'time_steps.csv'), np.array(time_steps), delimiter=",")

np.savetxt(os.path.join(output_path, 'rel_error_mesh_ref.csv'), rel_error_mesh_ref, delimiter=",")
np.savetxt(os.path.join(output_path, 'rel_error_mesh_sizes.csv'), rel_error_mesh_sizes, delimiter=",")
np.savetxt(os.path.join(output_path, 'rel_error_mesh_tsa.csv'), rel_error_mesh_tsa, delimiter=",")

# temperature over time
fig_T.legend(loc='upper left', bbox_to_anchor=(0.1,0.9))
fig_T.savefig(os.path.join(output_path, 'T_over_time.pdf'), format='pdf', bbox_inches='tight')

# transposed so that each column of the .csv belongs to one HT
np.savetxt(os.path.join(output_path, 'T_time_ref.csv'), T_time_ref.T, delimiter=",")

np.savetxt(os.path.join(output_path, 'T_time_tsa.csv'), T_time_tsa.T, delimiter=",")

# error over time
# row i of both error arrays belongs to target_size_insulation[i + 1] (the finest size is the reference)
for i, size_insulation in enumerate(target_size_insulation[1:]):
    # cycle through the markers so that more mesh sizes than markers do not raise an IndexError
    marker = markers[i % len(markers)]

    ax_err_time.plot(time_steps, rel_error_time_ref[i], label=f'Ref {size_insulation} m',color=colors['Ref'], linestyle=line_styles['Ref'], markersize=marker_size, markevery=marker_frequency['Ref'], marker=marker)

    ax_err_time.plot(time_steps, rel_error_time_tsa[i], label=f'TSA {size_insulation} m',color=colors['TSA'], linestyle=line_styles['TSA'], markersize=marker_size, markevery=marker_frequency['TSA'], marker=marker)

fig_err_time.legend(loc='upper left', bbox_to_anchor=(0.1,0.9))
fig_err_time.savefig(os.path.join(output_path, 'error_over_time.pdf'), format='pdf', bbox_inches='tight')

# transposed so that each column of the .csv belongs to one mesh size
np.savetxt(os.path.join(output_path, 'rel_error_time_ref.csv'), rel_error_time_ref.T, delimiter=",")

# the TSA file must contain the TSA errors (the reference errors were written here before)
np.savetxt(os.path.join(output_path, 'rel_error_time_tsa.csv'), rel_error_time_tsa.T, delimiter=",")

plt.show()