import os
import matplotlib.pyplot as plt
import numpy as np
from scipy.interpolate import interp1d
from steam_sdk.analyses.AnalysisSTEAM import AnalysisSTEAM
import pandas as pd
from matplotlib import colormaps

## Input parameters/settings
# Magnet name; it is used as a sub-folder name below "output_thesis".
# Other option kept for reference: "MBH_1in1".
multipole_name = "SMC"
# Geometry sub-folders (inside the magnet folder) that are scanned for meshes.
geometry_folders = [
    "Geometry_TSA",
    "Geometry_REF",
]
# Half-turn numbers whose temperature is plotted over time (column 'HT<number>').
ht_compare = [120]

# Helper function to interpolate the data
def interpolate_column(column):
    """Resample one column of the module-level ``df`` onto ``new_time_instants``.

    Meant to be handed to ``DataFrame.apply``. Both ``df`` (the frame the
    column belongs to) and ``new_time_instants`` (the target time grid) are
    looked up in module scope at call time, so they must be bound first.

    The column named 'Time' is replaced by ``new_time_instants`` itself. Any
    other column is interpolated cubically over ``df['Time']``; because
    ``bounds_error=True`` is passed, a target instant outside the range of
    ``df['Time']`` makes the interpolator raise ``ValueError``.
    """
    if column.name == 'Time':
        return new_time_instants

    resampler = interp1d(df['Time'], column, kind='cubic', bounds_error=True)
    return resampler(new_time_instants)


## Grab data
# NOTE(review): `a` is not read anywhere in the visible script; the call is kept
# in case constructing AnalysisSTEAM has side effects -- confirm before removing.
a = AnalysisSTEAM(file_name_analysis='input/FiQuS.yaml', verbose=True)

# reading the mesh data and categorizing it
# The lists below are parallel: index i of each one describes the same mesh folder.
mesh_type_nums = []  # folder name after 'Mesh_', i.e. '<type>_<number>'
mesh_type = []       # '<type>' token; compared against 'REF' below
mesh_nums = []       # '<number>' token converted to float
mesh_nums_ref = []
mesh_folders = []    # full path of each mesh folder

for geometry_name in geometry_folders:
    mesh_path = os.path.join(os.getcwd(), "output_thesis", multipole_name, geometry_name)
    # Keep only sub-directories whose name contains 'Mesh_'; any other
    # directory would make the name parsing below fail with an IndexError.
    mesh_folders_idx = [
        entry for entry in os.listdir(mesh_path)
        if 'Mesh_' in entry and os.path.isdir(os.path.join(mesh_path, entry))
    ]

    for entry in mesh_folders_idx:
        mesh_type_num = entry.split('Mesh_')[1]
        tokens = mesh_type_num.split('_')
        mesh_folders.append(os.path.join(mesh_path, entry))
        mesh_type_nums.append(mesh_type_num)
        mesh_type.append(tokens[0])
        mesh_nums.append(float(tokens[1]))

# Without a REF mesh every candidate below would be +inf and index 0 would be
# picked silently, so fail loudly instead.
if 'REF' not in mesh_type:
    raise ValueError(
        f"No 'REF' mesh folder found for geometries {geometry_folders}; "
        "cannot select a reference solution."
    )

# Non-REF meshes are masked with +inf so the minimum is taken over REF meshes only.
mesh_nums_ref = [num if kind == 'REF' else np.inf for num, kind in zip(mesh_nums, mesh_type)]
finest_mesh = mesh_folders[mesh_nums_ref.index(min(mesh_nums_ref))]
print(f"fines mesh is {finest_mesh}")


# Slot 0 holds the reference solution; slot 1 is available for the solution compared against it.
data_frames, data_frames_new = [pd.DataFrame(), pd.DataFrame()], [pd.DataFrame(), pd.DataFrame()]
# `finest_mesh` is already a full path (it was built from os.getcwd()), so it is
# used directly rather than being joined onto the last value of `mesh_path`.
data_frames[0] = pd.read_csv(os.path.join(finest_mesh, 'Solution_1', 'half_turn_temperatures_over_time.csv'))

## Plotting preparations
line_styles = dict(REF='dashed', TSA='solid')
markers = dict(REF='o', TSA='s')


def _new_axes(fig_num, x_label, y_label, log_log=False):
    """Open figure *fig_num*, add axes to it, label them and return the axes.

    With ``log_log=True`` both axes are switched to a base-10 logarithmic scale.
    """
    plt.figure(fig_num)
    axes = plt.axes()
    if log_log:
        axes.set_yscale("log", base=10)
        axes.set_xscale("log", base=10)
    axes.set_xlabel(x_label)
    axes.set_ylabel(y_label)
    for axis_name in ('x', 'y'):
        axes.tick_params(axis=axis_name)
    return axes


# Despite the `fig_` prefix, each of these names is bound to an Axes object.
fig_err_over_mesh = _new_axes(
    0, 'Element size [mm]', 'Relative error [%]', log_log=True)
fig_err_T_max_over_mesh = _new_axes(
    1, 'Element size [mm]', 'Relative error T hotspot [%]', log_log=True)
fig_T_over_t = _new_axes(
    2, 'Time [s]', f'Temperature of half turn(s) {ht_compare} [K]')
fig_err_maxT_over_t = _new_axes(
    3, 'Time [s]', 'Rel. error hotspot T [%]')
fig_maxT_over_t = _new_axes(
    4, 'Time [s]', 'Hotspot T [K]')
fig_maxErr_over_t = _new_axes(
    5, 'Time [s]', 'Max rel. error of all HT [%]')

# Get evenly spaced viridis colormap samples
# The same run of samples is stacked twice, so rows k and k + samples_per_half
# of `viridis_colors` hold the same colour.
cmap = colormaps['viridis']
samples_per_half = int(len(mesh_nums) / 2) + 1
half_palette = cmap(np.linspace(0, 1, samples_per_half))
viridis_colors = np.vstack((half_palette, half_palette))

# Reference solution: time axis and hottest half turn at every time step.
ref_frame = data_frames[0]
ref_time = ref_frame['Time']
ref_hotspot = np.nanmax(ref_frame.drop(columns='Time').to_numpy(), axis=1)
ref_style = {
    'color': viridis_colors[0],
    'linestyle': line_styles['REF'],
    'label': 'REF',
}

for ht in ht_compare:
    fig_T_over_t.plot(ref_time, ref_frame['HT' + str(ht)], **ref_style)

fig_maxT_over_t.plot(ref_time, ref_hotspot, **ref_style)

# Text output goes next to the simulation results (path relative to the cwd).
txt_output_dir = os.path.join("output_thesis", multipole_name)
os.makedirs(txt_output_dir, exist_ok=True)

# Two columns: time, reference hotspot temperature.
np.savetxt(
    os.path.join(txt_output_dir, "T_hotspot_ref.csv"),
    np.column_stack((ref_time.to_numpy(), ref_hotspot)),
    delimiter=',',
)

chosen = 0  # NOTE(review): not read anywhere in the visible script
# Mesh types that already own a legend entry in the two error-over-mesh figures.
labelled_types = set()
for i, mesh in enumerate(mesh_folders):
    # The finest REF mesh is the reference itself, so there is nothing to compare.
    if mesh == finest_mesh:
        continue

    csv_path = os.path.join(mesh, 'Solution_1', 'half_turn_temperatures_over_time.csv')
    try:
        data_frames[1] = pd.read_csv(csv_path)
    except (OSError, pd.errors.EmptyDataError, pd.errors.ParserError) as err:
        # Meshes without a readable solution are skipped, but not silently.
        print(f"Skipping {mesh}: could not read {csv_path} ({err})")
        continue

    color = viridis_colors[i]
    line_style = line_styles[mesh_type[i]]
    curve_label = mesh_type_nums[i]

    for ht in ht_compare:
        fig_T_over_t.plot(data_frames[1]['Time'], data_frames[1]['HT' + str(ht)], color=color,
                          linestyle=line_style, label=curve_label)

    fig_maxT_over_t.plot(data_frames[1]['Time'],
                         np.nanmax(data_frames[1].drop(columns='Time').to_numpy(), axis=1),
                         color=color, linestyle=line_style, label=curve_label)

    # Common time grid: 2500 instants from 0 up to the shorter of the two end times.
    new_time_instants = pd.Series(np.linspace(0, min(data_frames[0]['Time'][:].max(), data_frames[1]['Time'][:].max()), 2500))

    # interpolate_column reads the module-level `df`, so it must be rebound
    # to the frame being resampled before each apply().
    df = data_frames[0]
    data_frames_new[0] = df.apply(interpolate_column)
    df = data_frames[1]
    data_frames_new[1] = df.apply(interpolate_column)

    data_frame_compare = (data_frames_new[0] - data_frames_new[1]).abs()
    data_frame_compare = data_frame_compare.drop(columns='Time')
    data_frames_ref = data_frames_new[0].drop(columns='Time')
    data_frame_non_ref = data_frames_new[1].drop(columns='Time')

    # Hotspot = hottest half turn at each time instant.
    max_temp_ref = np.nanmax(data_frames_ref.to_numpy(), axis=1)
    max_temp_non_ref = np.nanmax(data_frame_non_ref.to_numpy(), axis=1)
    max_T_err_over_t = np.abs(max_temp_ref - max_temp_non_ref) / max_temp_ref

    # Relative error of every half turn at every instant (rows: time, columns: half turns).
    rel_err = data_frame_compare.to_numpy() / data_frames_ref.to_numpy()
    max_rel_err_over_t = np.nanmax(rel_err, axis=1)

    np.savetxt(os.path.join(txt_output_dir, f"T_hotspot_{mesh_type_nums[i]}.csv"),
               np.vstack((data_frames_new[1]['Time'], max_temp_non_ref)).T, delimiter=',')

    # The plots show percent, while the CSV files store the errors as fractions.
    fig_err_maxT_over_t.plot(new_time_instants, max_T_err_over_t * 100,
                             label=curve_label, linestyle=line_style, color=color)
    np.savetxt(os.path.join(txt_output_dir, f"err_hotspot_{mesh_type_nums[i]}.csv"),
               np.vstack((new_time_instants, max_T_err_over_t)).T, delimiter=',')

    fig_maxErr_over_t.plot(new_time_instants, max_rel_err_over_t * 100,
                           label=curve_label, linestyle=line_style, color=color)
    np.savetxt(os.path.join(txt_output_dir, f"err_{mesh_type_nums[i]}.csv"),
               np.vstack((new_time_instants, max_rel_err_over_t)).T, delimiter=',')

    max_T_err_hotspot = np.nanmax(max_T_err_over_t)
    max_T_err = np.nanmax(rel_err)

    # Report the column holding the overall maximum, not one index per time step.
    if np.isnan(max_T_err):
        worst_half_turn = 'n/a'  # every entry is NaN, so there is no maximum to locate
    else:
        worst_col = np.unravel_index(np.nanargmax(rel_err), rel_err.shape)[1]
        worst_half_turn = data_frame_compare.columns[worst_col]

    print(f"For mesh size {mesh_nums[i] * 1e3} mm, the relative error is {max_T_err * 100} % and the relative error of the hotspot is {max_T_err_hotspot * 100} %. The maximum is in half turn {worst_half_turn}.")

    # One legend entry per mesh type: only the first plotted point of a type is labelled.
    first_of_type = mesh_type[i] not in labelled_types
    labelled_types.add(mesh_type[i])
    point_style = dict(
        color=viridis_colors[0] if mesh_type[i] == 'REF' else viridis_colors[-1],
        linestyle='None',
        marker=markers[mesh_type[i]],
        label=mesh_type[i] if first_of_type else None,
    )

    fig_err_T_max_over_mesh.plot(mesh_nums[i] * 1e3, max_T_err_hotspot * 100, **point_style)
    fig_err_over_mesh.plot(mesh_nums[i] * 1e3, max_T_err * 100, **point_style)

# Add a legend to every open figure, then display them all.
for fig_num in plt.get_fignums():
    # plt.figure(num) makes that figure current and returns it.
    current_axes = plt.figure(fig_num).gca()
    # Calling legend() without labelled artists only produces a warning,
    # so figures that received no labelled curves are left without a legend.
    handles, _ = current_axes.get_legend_handles_labels()
    if handles:
        current_axes.legend()

plt.show()