from matplotlib import pyplot as plt
import numpy as np
from getdp_create_preresolution import getdp_create_preresolution
from getdp_runner import solve_ivp_getdp
from parareal import solve_ivp_parareal
from datetime import datetime
import os
import time
from mpi4py import MPI
import sys
from scipy import integrate

# MPI context of this process: communicator, number of processes, own rank.
comm = MPI.COMM_WORLD
size = comm.Get_size()
rank = comm.Get_rank()

# IMPORTANT: adapt path to your own system
getdp_path = "C:/Program Files/CERNGetDP/cerngetdp_2025.2.1/getdp_2025.2.1.exe"
output_path = "output"

# One time window per MPI process; with a single process fall back to 24 windows.
n_time_wind = size if size > 1 else 24

# When False, the parareal call further down is skipped and an empty
# placeholder result is used instead.
solve_parallel = True

# First and last entry are handed to the solvers as start and end time.
t_range = np.array([0, 0.6])

# Fine model: solver name, GetDP problem file and the option set that is
# passed both to the parareal driver and to the sequential reference run.
solver_f = "solve_ivp_getdp"
fun_f = "getdp_models/coupled_pancake_thesis/ERIK_THESIS_PARAREAL.pro"
getdp_opts_f = {
    # executable and mesh
    "exe": getdp_path,
    "mesh": "getdp_models/coupled_pancake_thesis/ERIK_THESIS_PARAREAL.msh",
    # general switches
    "verbose": "0",
    "nl_iteration": 1,
    "write_solution_to_disk": "all",
    # time-stepping keys
    "adaptive_time": 1,
    "rel_tol_adaptive": 0,
    "abs_tol_adaptive": 0.1,
    "t_step_init": 0.001,
    "t_step_min": 1e-12,
    "t_step_max": 1,
    # nonlinear-iteration keys
    "rel_tol_nl": 0,
    "abs_tol_nl": 0.1,
    "n_max_it_nl": 10,
    "relax_factor": 0.7,
}

# Coarse model: same problem file and mesh as the fine model, but with
# "adaptive_time" set to 0 and an explicit "time_step" entry.
solver_c = "solve_ivp_getdp"
fun_c = "getdp_models/coupled_pancake_thesis/ERIK_THESIS_PARAREAL.pro"
getdp_opts_c = {
    # executable and mesh
    "exe": getdp_path,
    "mesh": "getdp_models/coupled_pancake_thesis/ERIK_THESIS_PARAREAL.msh",
    # general switches
    "verbose": "0",
    "nl_iteration": 1,
    "write_solution_to_disk": "all",
    # tolerances and nonlinear-iteration keys
    "rel_tol_nl": 0,
    "abs_tol_nl": 0.1,
    "rel_tol_adaptive": 0,
    "abs_tol_adaptive": 0.1,
    "n_max_it_nl": 10,
    "relax_factor": 0.7,
    # time-stepping keys
    "adaptive_time": 0,
    "t_step_init": 0.0001,
    "t_step_min": 0,
    "t_step_max": 1,
    "time_step": 0.1,
}

# Extra option set handed to solve_ivp_parareal as its last argument.
# NOTE(review): presumably used for the first coarse sweep (see the
# "first_coarse_adaptive" entry in para_opts) -- confirm in parareal.py.
getdp_opts_c_first = {
    # executable and mesh
    "exe": getdp_path,
    "mesh": "getdp_models/coupled_pancake_thesis/ERIK_THESIS_PARAREAL.msh",
    # general switches
    "verbose": "0",
    "write_solution_to_disk": "all",
    # time-stepping keys
    "adaptive_time": 1,
    "rel_tol_adaptive": 0,
    "abs_tol_adaptive": 0.1,
    "t_step_init": 0.01,
    "t_step_min": 0,
    "t_step_max": 1,
    # nonlinear-iteration keys
    "nl_iteration": 1,
    "rel_tol_nl": 0,
    "abs_tol_nl": 0.1,
    "n_max_it_nl": 10,
    "relax_factor": 0.7,
    "track_times": True,
}

# Parareal options: window count, iteration count and process count are all
# tied to n_time_wind.
para_opts = {
    "n_time_windows": n_time_wind,
    "NIter": n_time_wind,
    "plotEach": True,
    "joblib_parallel": True,
    "NProc": n_time_wind,
    "abs_tol": 1e-3,
    "rel_tol": 0,
    "first_coarse_adaptive": True,
    "output_folder": output_path,
}

# The pre-resolution step returns the value used as the length of the
# initial-state vector.
numdof = getdp_create_preresolution(fun_c, getdp_opts_c)

# NOTE(review): 15 * zeros is still an all-zero vector, so the factor has no
# effect. If a uniform initial value of 15 was intended this should be
# np.full(numdof, 15.0) -- confirm before changing, results depend on it.
init = 15 * np.zeros(numdof)

# Parareal solution (wall-clock timed on every process that runs this script)
start_parallel = time.time()
if not solve_parallel:
    # Parareal disabled: wait one second and build an empty result object with
    # five empty slots per field so the plotting/export code below still runs.
    # NOTE(review): the sleep presumably keeps t_parallel away from zero for
    # the speedup ratio computed later -- confirm.
    time.sleep(1)
    placeholder_fields = (
        "y",
        "err",
        "trange",
        "y_all",
        "t_all",
        "sol_c_all",
        "t_c_all",
        "y_events",
        "t_events",
    )
    parallel_sol = integrate._ivp.ivp.OdeResult(
        # "t" carries one extra level of nesting compared with the other fields.
        t=[[[] for _ in range(5)]],
        **{field: [[] for _ in range(5)] for field in placeholder_fields},
    )
else:
    parallel_sol = solve_ivp_parareal(
        t_range,
        init,
        fun_c,
        solver_c,
        getdp_opts_c,
        fun_f,
        solver_f,
        getdp_opts_f,
        para_opts,
        getdp_opts_c_first,
    )
end_parallel = time.time()
t_parallel = end_parallel - start_parallel
print(f"Parareal took {t_parallel} s")


if rank == 0:
    # Sequential reference run with the fine model over the whole interval,
    # timed so it can be compared against the parareal wall-clock time.
    start_seq = time.time()
    y0 = init
    seq_sol = solve_ivp_getdp(fun_f, np.array([t_range[0], t_range[-1]]), y0, **getdp_opts_f)
    end_seq = time.time()

    t_seq = end_seq - start_seq
    ratio_t = t_seq/t_parallel

    print(f"Parareal took {t_parallel} s, sequential took {t_seq} s. Speedup is {ratio_t}.")

    # One subplot per quantity; entry idx of t_events / y_events is plotted
    # in subplot idx for both solutions.
    quantities = ["temperature", "field", "voltage", "resistiveHeating", "magneticEnergy"]
    fig, ax = plt.subplots(len(quantities))
    fig.subplots_adjust(hspace=1)
    fig.set_figheight(2.4 * 3)
    fig.set_figwidth(4 * 3)

    for idx, quantity in enumerate(quantities):
        ax[idx].set_title(f"{quantity}")
        ax[idx].plot(parallel_sol.t_events[idx], parallel_sol.y_events[idx], '*', label='parallel')
        # Label the reference curve too, otherwise it never shows up in a legend.
        ax[idx].plot(seq_sol.t_events[idx], seq_sol.y_events[idx], color="green", label='sequential')
        # A bare plt.legend() only decorates the current (last) axes, so the
        # legend is attached to every subplot explicitly.
        ax[idx].legend()

    # Axis labels go on the bottom subplot, addressed explicitly instead of
    # relying on pyplot's notion of the current axes.
    ax[-1].set_xlabel("Time")
    ax[-1].set_ylabel("Solution")
    now = datetime.now().strftime("%H_%M_%S")

    # Make sure the target folder exists before anything is written to it.
    os.makedirs(output_path, exist_ok=True)
    # Save this figure explicitly rather than whatever figure is current.
    fig.savefig(os.path.join(output_path, now))

    np.savetxt(os.path.join(output_path, "y_par_" + now + ".csv"), parallel_sol.y_events, delimiter=',')
    np.savetxt(os.path.join(output_path, "t_par_" + now + ".csv"), parallel_sol.t_events, delimiter=',')
    np.savetxt(os.path.join(output_path, "err_par_" + now + ".csv"), parallel_sol.err, delimiter=',')

    np.savetxt(os.path.join(output_path, "y_seq_" + now + ".csv"), seq_sol.y_events, delimiter=',')
    np.savetxt(os.path.join(output_path, "t_seq_" + now + ".csv"), seq_sol.t_events, delimiter=',')

# Runs on every rank: this call sits outside the rank-0 branch.
plt.show()