from scipy import integrate
from getdp_runner import solve_ivp_getdp
from scipy.integrate import solve_ivp
import numpy as np
from copy import deepcopy
import matplotlib.pyplot as plt
import multiprocessing

from mpi4py import MPI
from joblib import Parallel, delayed
from datetime import datetime
import os
import sys


def regenerate_plot(ax, idx, sol, n_time_windows, trange_time_windows, colormap, output_folder, style="-", savefig=False):
    """
    Draw one event quantity of a windowed solution into ax, one curve per time window.
    :param ax: axis object to draw into
    :param idx: row index into t_events / y_events selecting the quantity to plot
    :param sol: sequence of ODE solution objects, one per time window
    :param n_time_windows: number of time intervals of parareal
    :param trange_time_windows: the time range of each time interval (indexed as [window, 0] for the start)
    :param colormap: per-window colors, indexed by window number
    :param output_folder: folder whose "plots" sub-folder receives the saved figure
    :param style: matplotlib format string used for the curves
    :param savefig: if True, the current figure is written as a time-stamped pdf
    """
    for window in range(n_time_windows):
        window_sol = sol[window]
        ax.plot(
            window_sol.t_events[idx],
            window_sol.y_events[idx],
            style,
            color=colormap[window],
            linewidth=1,
            markersize=5,
        )
        ax.relim()
        ax.autoscale()
        # dashed separator at the start of every window but the first;
        # these are redrawn on each call because remove_lines wipes them too
        if window > 0:
            window_start = trange_time_windows[window, 0]
            ax.axvline(window_start, ls='--', color='grey', linewidth=1)

    # give the GUI event loop a chance to refresh the figure
    plt.pause(0.1)
    timestamp = datetime.now().strftime("%H_%M_%S")
    if not savefig:
        return

    target = os.path.join(output_folder, "plots", timestamp + ".pdf")
    plt.savefig(target, bbox_inches='tight', format="pdf")


def remove_lines(ax):
    """
    Detach every line artist currently attached to ax.
    :param ax: axis object whose lines are removed
    """
    # iterate over a snapshot, since remove() shrinks ax.lines while we loop
    for artist in list(ax.lines):
        artist.remove()

def solve_ivp_parareal(trange, y0, fun_c, solver_c, optionsC, fun_f, solver_f, optionsF, paraOpt, optionsCFirst=None):
    """
    Solve initial value problem using Parareal.

    In every iteration the coarse integrator is run sequentially over all time
    windows (rank 0 only when running under MPI), then the fine integrator is
    run on every window (sequentially, with joblib threads, or one window per
    MPI rank). The iteration stops once the error measure is no longer > 1.
    The error measure is the scaled jump of event quantity 0 across the window
    interfaces (see para_AbsRel_norm_singleQOI); if a reference solution is
    given it is replaced by a pass/fail flag based on the maximum absolute
    deviation from that reference.

    Side effects: sets optionsC["type"], optionsF["type"] (and
    optionsCFirst["type"], optionsC["time_step_list"] for the adaptive first
    coarse run), creates output_folder and output_folder/plots, and writes
    t_events<k>.txt, t_events_c<k>.txt, y_events<k>.txt, y_events_c<k>.txt
    into output_folder for every iteration k.

    :param trange: numpy.array([starting time, end time]); if it has more than
        two entries, consecutive entries are used directly as window boundaries
        (presumably len(trange) - 1 must then equal n_time_windows)
    :param y0: numpy array with initial solution (flattened to 1D internally)
    :param fun_c: function to integrate for coarse integrator
    :param solver_c: method for coarse integrator, given as the name of a
        solver callable that is looked up in this module's globals()
    :param optionsC: options for coarse integrator (dict, passed as keyword arguments)
    :param fun_f: function to integrate for fine integrator (i.e., the .pro file)
    :param solver_f: method for fine integrator, looked up like solver_c
    :param optionsF: options for fine integrator (dict, passed as keyword arguments)
    :param paraOpt: Parareal options (dict). Recognized keys and defaults:
        "n_time_windows" (8), "NIter" (20), "NProc" (cpu count),
        "plotEach" (False), "joblib_parallel" (False),
        "first_coarse_adaptive" (False), "include_breakpoints" (None),
        "rearrange_according_to" ("timestep"), "abs_tol" (1E-10),
        "rel_tol" (1E-10), "output_folder" (current working directory),
        "ref_sol_path" (None; csv with time in column 0 and value in column 1),
        "ref_err_tol" (required if "ref_sol_path" is given)
    :param optionsCFirst: options for the adaptive first coarse run; required
        if paraOpt["first_coarse_adaptive"] is true
    :return: OdeResult with fields t, y, err, trange, y_all, t_all, sol_c_all,
        t_c_all, y_events, t_events (y_all, t_all, sol_c_all and t_c_all are
        empty placeholder arrays); None on MPI ranks other than 0
    :raises ValueError: if ref_err_tol or optionsCFirst is missing although
        required, or if NProc differs from the number of MPI processes
    :raises Exception: at the end of an iteration if more than NIter iterations
        (NIter + 1 with first_coarse_adaptive) have been executed
    """

    # ------------------------------------------------------------------ options
    # optional reference solution used as alternative convergence criterion
    t_ref = None
    y_ref = None
    ref_err_tol = None
    ref_sol_path = paraOpt.get("ref_sol_path")
    if ref_sol_path is not None:
        t_ref = np.loadtxt(ref_sol_path, delimiter=",", usecols=(0,))
        y_ref = np.loadtxt(ref_sol_path, delimiter=",", usecols=(1,))

        if "ref_err_tol" not in paraOpt:
            raise ValueError("ref_err_tol must be given if ref_sol_path is given")
        ref_err_tol = paraOpt["ref_err_tol"]

    n_time_windows = paraOpt.get("n_time_windows", 8)
    NIter = paraOpt.get("NIter", 20)
    NProc = paraOpt["NProc"] if "NProc" in paraOpt else multiprocessing.cpu_count()
    plotEach = paraOpt.get("plotEach", False)
    joblib_parallel = paraOpt.get("joblib_parallel", False)

    first_coarse_adaptive = paraOpt.get("first_coarse_adaptive", False)
    breakpoints = None
    rearrange_according_to = "timestep"
    if first_coarse_adaptive:
        if optionsCFirst is None:
            raise ValueError("optionsCFirst must be given if first_coarse_adaptive is True")

        breakpoints = paraOpt.get("include_breakpoints")
        optionsCFirst["type"] = "coarse"
        rearrange_according_to = paraOpt.get("rearrange_according_to", "timestep")

    # tell the solver which propagator it is running as
    optionsC["type"] = "coarse"
    optionsF["type"] = "fine"

    # MPI mode is enabled implicitly when started with more than one process
    comm = MPI.COMM_WORLD
    mpi_parallel = comm.Get_size() > 1
    rank = 0
    if mpi_parallel:
        size = comm.Get_size()
        rank = comm.Get_rank()

        if size != NProc:
            raise ValueError("NProc must be equal to the number of MPI processes")

    # rank 0 (or the only process) does coarse solves, plotting, file output and convergence checks
    is_root = rank == 0

    abs_tol = paraOpt.get("abs_tol", 1E-10)
    rel_tol = paraOpt.get("rel_tol", 1E-10)
    output_folder = paraOpt["output_folder"] if "output_folder" in paraOpt else os.getcwd()

    if is_root:
        # creates output_folder as well if it does not exist yet
        os.makedirs(os.path.join(output_folder, "plots"), exist_ok=True)

    # work with a flat state vector: rows of y_tot are 1D, and a 2D y0 cannot
    # be assigned into them or stacked with the solver output below
    y0 = np.ravel(y0)

    # ------------------------------------------------------------- plot set-up
    if plotEach and is_root:
        SMALL_SIZE = 6
        MEDIUM_SIZE = 8
        BIGGER_SIZE = 10

        plt.rc('font', size=SMALL_SIZE)          # controls default text sizes
        plt.rc('axes', titlesize=MEDIUM_SIZE)     # fontsize of the axes title
        plt.rc('axes', labelsize=MEDIUM_SIZE)    # fontsize of the x and y labels
        plt.rc('xtick', labelsize=SMALL_SIZE)    # fontsize of the tick labels
        plt.rc('ytick', labelsize=SMALL_SIZE)    # fontsize of the tick labels
        plt.rc('legend', fontsize=SMALL_SIZE)    # legend fontsize
        plt.rc('figure', titlesize=BIGGER_SIZE)  # fontsize of the figure title

        colors = plt.cm.viridis(np.linspace(0, 0.9, n_time_windows))  # do not use complete yellow spectrum of viridis

        # one row per event quantity; left column coarse, right column fine
        quantities = ["temperature", "field", "voltage", "resistiveHeating", "magneticEnergy"]
        fig, ax = plt.subplots(len(quantities), 2)
        fig.subplots_adjust(hspace=1)
        fig.set_figheight(2.4 * 3)
        fig.set_figwidth(4 * 3)

        for i, quantity in enumerate(quantities):
            ax[i, 0].set_title(f"{quantity} Coarse Propagator")
            ax[i, 0].set_xlabel("Time")
            ax[i, 0].set_ylabel(f"{quantity}")
            ax[i, 0].plot([], [])

            ax[i, 1].set_title(f"{quantity} Fine Propagator")
            ax[i, 1].set_xlabel("Time")
            ax[i, 1].set_ylabel(f"{quantity}")
            ax[i, 1].plot([], [])

        plt.pause(1)

    # ------------------------------------------------------------ work arrays
    y_f0 = np.zeros((n_time_windows + 1, len(y0)))  # matrix containing the fine solution for proc i+1 at t_0
    y_fend = np.zeros((n_time_windows + 1, len(y0)))  # matrix containing the fine solution for proc i at t_end
    y_event_f0 = np.zeros((n_time_windows + 1, 1))  # event quantity of the fine solution for proc i+1 at t_0
    y_event_fend = np.zeros((n_time_windows + 1, 1))  # event quantity of the fine solution for proc i at t_end
    err = np.zeros((NIter + 1, 1))

    y_cnew = np.zeros((n_time_windows + 1, len(y0)))  # matrix containing the new coarse solution
    y_tot = np.zeros((n_time_windows + 1, len(y0)))  # matrix containing the initial values
    sol_c = []
    sol_f = []

    # placeholders that are returned unchanged (never filled in this function)
    t_all = np.array([[]])
    y_all = np.array([[]])
    sol_c_all = np.array([[]])
    t_c_all = np.array([[]])

    y_tot[0, :] = y0

    if len(trange) == 2:  # same time step size for each
        equiTRange = np.linspace(trange[0], trange[1], n_time_windows + 1)
        window_idx = np.array((range(0, n_time_windows), range(1, n_time_windows + 1))).transpose()
        trange_time_windows = equiTRange[window_idx]
    else:
        trange_time_windows = np.array(list(zip(trange[:-1], trange[1:])))

    # index of the event quantity used for the convergence check
    # improvement: generalize convergence index as user input
    convergence_idx = 0

    k = 0

    # --------------------------------------------------------- parareal loop
    while ((k < 1 or err[k - 1] > 1)):

        k = k + 1

        if is_root:

            print("\nPARAREAL: START %d/%d ITERATION \n" % (k, NIter), flush=True)
            print('\nPARAREAL: start coarse grid\n', flush=True)

            sol_c = []
            sol_f = []

            if first_coarse_adaptive and k == 1:
                # first coarse solution is computed with adaptive solver
                print('Parareal: adaptive coarse integration in t = [%f, %f]' % (trange_time_windows[0, 0], trange_time_windows[-1, 1]), flush=True)

                sol_c_helper = globals()[solver_c](fun_c, np.array([trange_time_windows[0, 0], trange_time_windows[-1, 1]]), y_tot[0], **optionsCFirst)

                print(f'Parareal: adaptive coarse integration took {len(sol_c_helper.t)} steps', flush=True)

                # recreate sol_c: choose the step indices at which the adaptive run is cut into windows
                if rearrange_according_to == "time":
                    # NOTE(review): nfev is used as a per-step cost tracker here — confirm against the solver
                    t_tracker = sol_c_helper.nfev
                    total_time = np.sum(t_tracker)
                    normalized_t_tracker = np.cumsum(t_tracker)/total_time * n_time_windows
                    split_idx = np.array([np.argmax(normalized_t_tracker > i + 1) for i in range(n_time_windows - 1)])
                    # on purpose always calculates np.diff in every step
                    # avoid duplicate entries in split_idx
                    for i in range(len(split_idx) - 1):
                        if np.diff(split_idx)[i] <= 0:
                            split_idx[i+1] = split_idx[i+1] + 1 + abs(np.diff(split_idx)[i])

                else:
                    # equal number of steps per window
                    split_idx = len(sol_c_helper.t)/n_time_windows * (np.arange(n_time_windows - 1) + 1)

                if breakpoints:
                    # move the nearest split point onto every requested breakpoint
                    breakpoint_idx = []

                    for bp_time in breakpoints:
                        breakpoint_idx.append(np.where(sol_c_helper.t == bp_time))

                    breakpoint_idx = np.array(breakpoint_idx).flatten()
                    replacement_idx_list = []
                    for bp_idx in breakpoint_idx:
                        replacement_idx = np.argmin(np.abs(split_idx - bp_idx))

                        # if the nearest split point is already taken, use the next one
                        if replacement_idx not in [replacement_i for (_, replacement_i) in replacement_idx_list]:
                            replacement_idx_list.append((bp_idx, replacement_idx))
                        else:
                            replacement_idx_list.append((bp_idx, replacement_idx + 1))

                    for (bp_idx, replacement_idx) in replacement_idx_list:
                        split_idx[replacement_idx] = bp_idx

                # later coarse runs reuse the step sizes of the adaptive run
                optionsC["time_step_list"] = np.diff(sol_c_helper.t)
                split_points = split_idx.astype(int)
                t_helper = np.array_split(sol_c_helper.t, split_points)
                y_helper = np.array_split(sol_c_helper.y, split_points, axis=1)
                t_events_helper = np.array_split(sol_c_helper.t_events, split_points, axis=1)
                y_events_helper = np.array_split(sol_c_helper.y_events, split_points, axis=1)

                trange = [t_helper[i][0] for i in range(n_time_windows)] + [t_helper[-1][-1]]
                trange_time_windows = np.array(list(zip(trange[:-1], trange[1:])))

                # split solution into time windows
                for i in range(0, n_time_windows):
                    if i < n_time_windows - 1:
                        # each window also contains the first point of the next one as its end point
                        t = np.append(t_helper[i], t_helper[i + 1][0])
                        y = np.append(y_helper[i], y_helper[i + 1][:, 0:1], axis=1)
                        t_events = np.append(t_events_helper[i], t_events_helper[i + 1][:, 0:1], axis=1)
                        y_events = np.append(y_events_helper[i], y_events_helper[i + 1][:, 0:1], axis=1)
                    else:
                        t = t_helper[i]
                        y = y_helper[i]
                        t_events = t_events_helper[i]
                        y_events = y_events_helper[i]

                    sol_c_part = integrate._ivp.ivp.OdeResult(t=t, y=y, t_events=t_events, y_events=y_events)

                    sol_c.append(sol_c_part)

                print(f"Parareal: using time range {trange}", flush=True)

                for i in range(0, n_time_windows):
                    y_cold = deepcopy(y_cnew[i + 1, :])
                    y_cnew[i + 1, :] = sol_c[i].y[:, -1]
                    y_tot[i + 1, :] = y_cnew[i + 1, :] + y_fend[i + 1, :] - y_cold

            else:
                for i in range(0, n_time_windows):

                    print(f"Parareal: coarse integration in t = {trange_time_windows[i, :]}", flush=True)

                    sol_c.append(globals()[solver_c](fun_c, trange_time_windows[i, :], y_tot[i], **optionsC))

                    # ATTENTION deepcopy needed otherwise points to the same list object
                    y_cold = deepcopy(y_cnew[i + 1, :])

                    y_cnew[i + 1, :] = sol_c[i].y[:, -1]

                    # parareal update: new coarse + previous fine - previous coarse
                    y_tot[i + 1, :] = y_cnew[i + 1, :] + y_fend[i + 1, :] - y_cold

        if plotEach and is_root:
            for idx, quantity in enumerate(quantities):
                remove_lines(ax[idx, 0])
                loop_last = idx == len(quantities) - 1
                regenerate_plot(ax[idx, 0], idx, sol_c, n_time_windows, trange_time_windows, colors, output_folder, style="-o", savefig=loop_last)

        if mpi_parallel and first_coarse_adaptive and k == 1:
            # the adaptive run redefined the windows on rank 0; share them with all ranks
            trange = comm.bcast(trange, root=0)
            trange_time_windows = comm.bcast(trange_time_windows, root=0)

        if is_root:
            print('\nPARREAL: start fine grid\n', flush=True)
            sol_f = []

        if joblib_parallel and not mpi_parallel:
            sol_f = Parallel(n_jobs=NProc, prefer="threads")(
                delayed(globals()[solver_f])(fun_f, trange_time_windows[i, :], y_tot[i], **optionsF) for i in range(n_time_windows))
        elif mpi_parallel:
            # possible improvement: use capital letter MPI functions for Numpy arrays
            # one window per rank: rank r receives the initial value of window r
            y_init = comm.scatter(y_tot[:-1], root=0)
            # parallel solve
            f_sol_helper = globals()[solver_f](fun_f, trange_time_windows[rank, :], y_init, **optionsF)

            sol_f = comm.gather(f_sol_helper, root=0)
        else:
            # this could be a parallel loop!
            for i in range(0, n_time_windows):

                f_sol_helper = globals()[solver_f](fun_f, trange_time_windows[i, :], y_tot[i], **optionsF)
                sol_f.append(f_sol_helper)

        if is_root:
            if plotEach:
                for idx, quantity in enumerate(quantities):
                    remove_lines(ax[idx, 1])
                    loop_last = idx == len(quantities) - 1
                    regenerate_plot(ax[idx, 1], idx, sol_f, n_time_windows, trange_time_windows, colors, output_folder, savefig=loop_last)

            # reconstruct global t and y vectors
            t = np.array([trange[0]])

            y = np.array([y0])
            y = np.transpose(y)
            # NOTE(review): y_c and t_c are assembled but not part of the returned result
            y_c = np.transpose(np.array([y0]))
            t_c = np.array([trange[0]])
            t_events = sol_f[0].t_events[:]
            y_events = sol_f[0].y_events[:]
            t_events_c = sol_c[0].t_events[:]
            y_events_c = sol_c[0].y_events[:]

            for i in range(0, len(sol_f)):
                if i != 0:
                    t_events = np.concatenate((t_events, sol_f[i].t_events[:]), axis=1)
                    y_events = np.concatenate((y_events, sol_f[i].y_events[:]), axis=1)
                    t_events_c = np.concatenate((t_events_c, sol_c[i].t_events[:]), axis=1)
                    y_events_c = np.concatenate((y_events_c, sol_c[i].y_events[:]), axis=1)
                # skip the first point of each window, it duplicates the previous end point
                t = np.concatenate((t, sol_f[i].t[1:]))
                y = np.concatenate((y, sol_f[i].y[:, 1:]), axis=1)
                y_c = np.concatenate((y_c, sol_c[i].y[:, 1:]), axis=1)
                t_c = np.concatenate((t_c, sol_c[i].t[1:]))

            # save event solutions to files
            np.savetxt(os.path.join(output_folder, "t_events" + str(k) + ".txt"), t_events.T)
            np.savetxt(os.path.join(output_folder, "t_events_c" + str(k) + ".txt"), t_events_c.T)
            np.savetxt(os.path.join(output_folder, "y_events" + str(k) + ".txt"), y_events.T)
            np.savetxt(os.path.join(output_folder, "y_events_c" + str(k) + ".txt"), y_events_c.T)

            # compare solutions at start and end of each time interval
            for i in range(1, n_time_windows):
                y_fend[i, :] = sol_f[i - 1].y[:, -1]
                y_f0[i, :] = sol_f[i].y[:, 0]
                y_event_fend[i] = sol_f[i - 1].y_events[convergence_idx, -1]
                y_event_f0[i] = sol_f[i].y_events[convergence_idx, 0]

            y_fend[-1, :] = sol_f[-1].y[:, -1]
            y_event_fend[-1] = sol_f[-1].y_events[convergence_idx, -1]

            # err[k - 1] = para_AbsRel_norm(y_fend[0:-1, :], y_f0[0:-1, :], abs_tol, rel_tol)
            err[k - 1], win = para_AbsRel_norm_singleQOI(y_event_fend[0:-1], y_event_f0[0:-1], abs_tol, rel_tol)
            print(f"PARAREAL: ERROR after {k} ITERATIONS is {err[k - 1]} in windows {win} \n", flush=True)

            # the adaptive first coarse run is granted one extra iteration
            if (first_coarse_adaptive and k > NIter + 1) or (not first_coarse_adaptive and k > NIter):
                raise Exception("Parareal DID NOT CONVERGE after %d iterations with an error norm of %f\n" % (k, err[k - 1, 0]))

            if y_ref is not None and t_ref is not None:
                # convergence is decided by the distance to the reference solution instead

                par_interp = np.interp(t_ref, t_events[convergence_idx, :], y_events[convergence_idx, :])
                abs_err_to_ref = np.max(np.abs(par_interp - y_ref))
                print("PARAREAL: MAX ABSOLUTE ERROR TO REFERENCE SOLUTION: %f\n" % abs_err_to_ref, flush=True)

                # overwrite error with a flag value: 2 keeps iterating, 0.1 stops the loop
                err[k - 1] = 2

                if abs_err_to_ref/ref_err_tol < 1.1:
                    print("PARAREAL CONVERGED SUCCESSFULLY after %d iterations." % k, flush=True)
                    err[k - 1] = 0.1

            elif err[k - 1] < 1:
                print("PARAREAL CONVERGED SUCCESSFULLY after %d iterations with an error norm of %f\n" % (k, err[k - 1, 0]), flush=True)

        if mpi_parallel:
            # every rank evaluates the loop condition, so all of them need the current error
            err = comm.bcast(err, root=0)

    if is_root:
        ode_sol_obj = integrate._ivp.ivp.OdeResult(t=t, y=y, err=err, trange=trange, y_all=y_all, t_all=t_all, sol_c_all=sol_c_all, t_c_all=t_c_all, y_events=y_events, t_events=t_events)
        return ode_sol_obj
    else:
        return None

def para_AbsRel_norm_singleQOI(x, y, AbsTol, RelTol):
    """
    Largest scaled deviation between x and y for a single quantity of interest.
    Each entry of |x - y| is divided by AbsTol + |x| * RelTol.
    :param x: numpy array of values used for the comparison and for the scaling
    :param y: numpy array of values compared against x
    :param AbsTol: absolute tolerance
    :param RelTol: relative tolerance
    :return: (largest scaled deviation as given by the builtin max over the
             first axis, flat index of the largest scaled deviation)
    """
    scale = AbsTol + np.abs(x) * RelTol
    scaled_diff = np.abs(x - y) / scale
    worst_idx = np.argmax(scaled_diff)
    # builtin max on purpose: for column input it yields a length-1 array, not a scalar
    return max(scaled_diff), worst_idx

def para_AbsRel_norm(x, y, AbsTol, RelTol):
    """
    Maximum over the rows of the scaled root-mean-square deviation between x and y.
    Inspired by Hairer, Norsett, Wanner: Solving Ordinary Equations 1, page 167.
    :param x: 2D numpy array; also used for the scaling
    :param y: 2D numpy array compared against x
    :param AbsTol: absolute tolerance
    :param RelTol: relative tolerance
    :return: largest row-wise value of sqrt(sum(((x - y) / (AbsTol + |x| * RelTol))**2) / n_columns)
    """
    scale = AbsTol + np.abs(x) * RelTol
    scaled_diff = np.abs(x - y) / scale
    n_columns = x.shape[1]
    # 2-norm of every row, normalized by the square root of the row length
    row_norms = np.sqrt(np.square(scaled_diff).sum(axis=1)) / np.sqrt(n_columns)
    # infinity norm of the 2-norms
    return max(row_norms)