import os
import subprocess
import numpy as np
import tempfile

from scipy import integrate
import time

from getdp_read_preresolution import getdp_read_preresolution

def solve_ivp_getdp(fun, t_span, y0, **getdp_options):
    """
    Solve an initial value problem using GetDP.

    GetDP is run twice inside a temporary directory: first a pre-processing
    run (``-pre``) whose ``.pre`` file is used to check the number of degrees
    of freedom, then a processing run (``-solve``, or ``-restart`` when the
    start time is non-zero) that starts from a ``.res`` file holding ``y0``.

    :param fun: function to integrate (i.e., path to the GetDP .pro file)
    :param t_span: 2-member sequence with interval of integration (t_start, t_end)
    :param y0: numpy array, shape (n,), initial state; n must match the number
        of DoF read from the pre-processing output
    :param getdp_options: options for GetDP. Required: ``exe`` (GetDP
        executable) and ``mesh`` (Gmsh mesh file). Unless ``adaptive_time`` is
        truthy, one of ``time_step_list``, ``time_step`` or ``num_time_step``
        is required as well. Optional keys (default): ``complex_getdp``
        (False), ``type`` ("unknown", only used in the timing message),
        ``pre_processing`` ("#1"), ``postop`` ("#1"), ``verbose`` ("0"),
        ``nl_iteration`` (1), ``adaptive_time`` (0), ``rel_tol_adaptive``
        (1e-3), ``abs_tol_adaptive`` (1e-1), ``rel_tol_nl`` (1e-3),
        ``abs_tol_nl`` (1e-1), ``n_max_it_nl`` (20), ``relax_factor`` (0.7),
        ``t_step_init`` (1e-1), ``t_step_min`` (1e-12), ``t_step_max``
        (1e-12), ``track_times`` (False), ``write_solution_to_disk``
        ("all", "last" or "none"; default and fallback "all").
    :return: scipy ``OdeResult`` with ``t``/``y`` read from the .res file,
        ``t_events``/``y_events`` holding the tables read by ``read_pospath``,
        and ``nfev`` holding the tracked run times per time step (or None).
    :raises Exception: on missing/invalid inputs or when a GetDP run fails.
    """

    wall_time_start = time.time()

    pro_file = fun

    if not os.path.isfile(pro_file):
        raise Exception("GetDP pro file not found.")

    complex_getdp = getdp_options.get("complex_getdp", False)
    propagator_type = getdp_options.get("type", "unknown")

    if "exe" not in getdp_options:
        raise Exception("GetDP executable not provided.")
    getdp_exe = getdp_options["exe"]
    try:
        version_check = subprocess.run([getdp_exe, '--version'], stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL)
    except OSError as err:
        # a missing / non-executable path raises instead of returning a non-zero code
        raise Exception("GetDP executable cannot be called, please check path.") from err
    if version_check.returncode:
        raise Exception("GetDP executable cannot be called, please check path.")

    if "mesh" not in getdp_options:
        raise Exception("Please provide a Gmsh mesh.")
    mesh_file = getdp_options["mesh"]
    if not os.path.isfile(mesh_file):
        raise Exception("GetDP mesh not found.")

    pre_processing = getdp_options.get("pre_processing", "#1")
    # (this option used to be read from the non-existent key "post-postop")
    postop = getdp_options.get("postop", "#1")
    verbose = getdp_options.get("verbose", "0")
    nl_iteration = getdp_options.get("nl_iteration", 1)
    adaptive_time = getdp_options.get("adaptive_time", 0)
    rel_tol_adaptive = getdp_options.get("rel_tol_adaptive", 1e-3)
    abs_tol_adaptive = getdp_options.get("abs_tol_adaptive", 1e-1)
    rel_tol_nl = getdp_options.get("rel_tol_nl", 1e-3)
    abs_tol_nl = getdp_options.get("abs_tol_nl", 1e-1)
    n_max_it_nl = getdp_options.get("n_max_it_nl", 20)
    relax_factor = getdp_options.get("relax_factor", 0.7)

    if "t_step_init" in getdp_options:
        t_step_init = getdp_options["t_step_init"]
        # an initial step longer than the whole interval is replaced by a fifth of it
        if t_step_init > (t_span[1] - t_span[0]):
            t_step_init = (t_span[1] - t_span[0]) / 5
    else:
        t_step_init = 1e-1

    t_step_min = getdp_options.get("t_step_min", 1e-12)
    # NOTE(review): the default maximum step equals the default minimum step
    # (1e-12), which looks like a copy-paste default — confirm against the .pro files.
    t_step_max = getdp_options.get("t_step_max", 1e-12)

    # always defined (previously undefined when track_times was given but falsy)
    track_times_per_time_step = 1 if getdp_options.get("track_times", False) else 0

    # possible values: "all", "last", "none"; absent or unknown values mean "all"
    write_sol_disk = {"all": 1, "last": 2, "none": 0}
    write_solution_to_disk = write_sol_disk.get(
        getdp_options.get("write_solution_to_disk", "all"), write_sol_disk["all"])

    init = y0
    if not isinstance(init, np.ndarray):
        raise Exception("Initial value not provided as a numpy array.")

    time_start, time_end = t_span[0], t_span[1]

    # a non-zero start time continues from the written .res file via -restart
    do_restart = time_start != 0

    time_step_per_file = 0
    time_steps = None
    time_vals = None
    if adaptive_time:
        # fixed step size not used; -1 is passed as placeholder
        time_step = -1
    elif "time_step_list" in getdp_options:
        time_steps = getdp_options["time_step_list"]
        # cumulative offsets of the start of every step, beginning at 0
        time_vals = np.append(0, np.cumsum(time_steps)[:-1])

        time_step_per_file = 1
        time_step = -1
    elif "time_step" in getdp_options:
        time_step = getdp_options["time_step"]
    elif "num_time_step" in getdp_options:
        time_step = (time_end - time_start) / getdp_options["num_time_step"]
    else:
        raise Exception("Neither time step size nor number of time steps provided.")

    with tempfile.TemporaryDirectory() as tmp_dir:

        filename = os.path.splitext(os.path.basename(pro_file))[0]

        # every GetDP output file is named after the .pro file's stem inside tmp_dir
        tmpname = os.path.join(tmp_dir, filename)

        prefile = os.path.join(tmp_dir, filename + ".pre")
        resfile = os.path.join(tmp_dir, filename + ".res")
        pospath = os.path.join(tmp_dir, filename)

        timestep_file = ""
        if time_vals is not None:
            timestep_file = os.path.join(tmp_dir, "timestep.data")

            # str() of each element (not str(list)) so NumPy >= 2 scalars are
            # written as plain numbers instead of "np.float64(...)"
            values_text = ", ".join(str(v) for v in time_vals)
            steps_text = ", ".join(str(v) for v in time_steps)
            with open(timestep_file, 'w') as f:
                f.write(f"List_of_time_values = {{ {values_text} }}; \n")
                f.write(f"List_of_time_steps = {{ {steps_text} }};")

        exe_list = [getdp_exe, pro_file,
                    "-pre", str(pre_processing),
                    "-msh", mesh_file,
                    "-name", tmpname,
                    "-v", str(verbose)]

        if subprocess.run(exe_list).returncode:
            raise Exception("Pre-processing failed")

        # try to read number of DoF
        num_dof = getdp_read_preresolution(prefile)
        if num_dof != init.size:
            raise Exception("Initial value has wrong size")

        # create initial data
        create_resolution(resfile, time_start, init, complex_getdp)

        exe_list = [getdp_exe, pro_file,
                    "-msh", mesh_file,
                    "-name", tmpname,
                    "-v", str(verbose),
                    "-setnumber", "t_end", str(time_end),
                    "-setnumber", "t_step", str(time_step),
                    "-setnumber", "Flag_nl_iteration", str(nl_iteration),
                    "-setnumber", "Flag_write_solution_to_disk", str(write_solution_to_disk),
                    "-setnumber", "Flag_adaptive_time_stepping", str(adaptive_time),
                    "-setnumber", "rel_tol_adaptive_time_stepping", str(rel_tol_adaptive),
                    "-setnumber", "abs_tol_adaptive_time_stepping", str(abs_tol_adaptive),
                    "-setnumber", "t_step_init", str(t_step_init),
                    "-setnumber", "t_step_min", str(t_step_min),
                    "-setnumber", "t_step_max", str(t_step_max),
                    "-setnumber", "nl_relTol_norm", str(rel_tol_nl),
                    "-setnumber", "nl_absTol_norm", str(abs_tol_nl),
                    "-setnumber", "NMaxIt", str(n_max_it_nl),
                    "-setnumber", "relaxFactor", str(relax_factor),
                    "-setnumber", "Flag_trackTimesPerTimeStep", str(track_times_per_time_step),
                    "-setnumber", "time_step_per_file", str(time_step_per_file),
                    "-setstring", "timestep_file", timestep_file,
                    "-pos", str(postop),
                    "-mat_mumps_cntl_1", "0",
                    ]

        if do_restart:
            exe_list += [
                "-setnumber", "restartComputation", "1",
                "-restart",
                "-res", resfile,
            ]
        else:
            exe_list += [
                "-setnumber", "restartComputation", "0",
                "-solve", '#1',
            ]

        if subprocess.run(exe_list).returncode:
            raise Exception("Processing failed")

        # load everything into memory before the temporary directory is removed
        time_res, y_vals_res = read_resolution(resfile, num_dof)
        t_plot_res, y_plot_res = read_pospath(pospath)
        if track_times_per_time_step:
            t_track = np.loadtxt(os.path.join(tmp_dir, "tracked_run_times.txt"), usecols=4)
        else:
            t_track = None

        wall_time_end = time.time()
        time_needed = wall_time_end - wall_time_start
        print(f"solve_ivp_getdp: {propagator_type} solver from t = {time_start} s to t = {time_end} s took {time_needed} s")

        # return same as scipy would (OdeResult lives in a private SciPy module)
        ode_sol_obj = integrate._ivp.ivp.OdeResult(t=time_res, y=y_vals_res, t_events=t_plot_res, y_events=y_plot_res, nfev=t_track)

        return ode_sol_obj


def create_resolution(res_file, time, val_array, complex_getdp):
    """
    Create a GetDP resolution (.res) file holding a single solution.

    :param res_file: .res file path (overwritten)
    :param time: time value written to the solution header (only one step)
    :param val_array: array_like, shape (n,), current state
    :param complex_getdp: boolean, GetDP compiled with complex Petsc; if True,
        every value line holds the real part and the imaginary part
    :raises Exception: if any value is NaN
    """
    values = np.asarray(val_array)

    # isnan(...).any() also works for complex arrays
    if np.isnan(values).any():
        raise Exception("At least on value is NaN!")

    # atm: only one dofdata object treated
    header_string = "$ResFormat /* solve_ivp_getdp copyright E. Schnaubelt, CERN/TU Darmstadt \n" + \
        "1.1 0\n" + \
        "$EndResFormat\n" + \
        "$Solution  /* DofData 0 */ \n" + \
        "0 %.16f 0 0" % (time)  # dofdata-number time-value time-imag-value time-step-number

    footer_string = "$EndSolution"

    if complex_getdp:
        # two columns per line: real part, imaginary part (zero for real input);
        # complex input was previously written as "(a+bj)" next to a zero column
        data = np.c_[np.real(values), np.imag(values)]
    else:
        data = values

    with open(res_file, 'w') as f:
        np.savetxt(f, data, header=header_string, footer=footer_string, comments="")

def read_resolution(res_file, num_dof):
    """
    Parse a GetDP resolution (.res) file into time values and states.

    Every ``$Solution`` section is collected in file order. Only the first
    column of each value line is kept. The time-step number in a solution
    header may repeat but must never decrease.

    :param res_file: .res file path
    :param num_dof: number of degrees of freedom (value lines per solution)
    :return: tuple ``(t, y)`` where ``t`` is a 1-D array of times and ``y`` is
        the transposed array of states, shape (num_dof, number of solutions)
    :raises Exception: on an unsupported format version, a decreasing
        time-step number, or a NaN value
    """
    stamps = []
    states = []
    last_step = 0

    with open(res_file, 'r') as handle:
        for line in handle:
            if "$Solution" in line:
                # header: dofdata-number time-value time-imag-value time-step-number
                header = handle.readline().split()
                _dofdata = int(header[0])  # parsed but otherwise unused
                stamp = float(header[1])
                step = float(header[3])

                if step < last_step:
                    raise Exception(f"GetDP runner: error reading file {res_file}. Time step {step} is stored after {last_step}.")
                if step > last_step:
                    last_step = step
                # an equal step number (step read before) is accepted silently

                stamps.append(stamp)
                states.append([float(handle.readline().split()[0]) for _ in range(num_dof)])

            elif "$ResFormat" in line:
                version = handle.readline().split()[0]
                if version != "1.1":
                    raise Exception(f"GetDP runner. unknown res file format version: {version}")

    values = np.array(states).transpose()

    if np.isnan(np.min(values)):
        raise Exception("At least on value is NaN!")

    return np.array(stamps), values
    
def read_pospath(pos_path):
    """
    Read pos files saved as "Format TimeTable" in GetDP.

    For each known quantity the file ``f"{pos_path}_{quantity}.pos"`` is read
    when it exists; its time column and its last column are collected.

    :param pos_path: path prefix of the .pos files (without "_<quantity>.pos")
    :return: tuple ``(t, y)`` of 2-D arrays with one row per file found, in the
        order temperature, field, voltage, resistiveHeating, magneticEnergy;
        both are empty arrays of shape (0, 0) when no such file exists
    """
    # (quantity, rows to skip, column of the time value, column of the value)
    tables = (
        ("temperature", 0, 0, -1),
        ("field", 0, 1, -1),
        ("voltage", 0, 0, -1),
        ("resistiveHeating", 0, 0, -1),
        ("magneticEnergy", 0, 0, -1),
    )

    t_list = []
    y_list = []

    for quantity, skiprows, t_col, y_col in tables:
        pos_file = f"{pos_path}_{quantity}.pos"
        if not os.path.isfile(pos_file):
            continue
        # ndmin=2 keeps a single-row table (one time step) two-dimensional;
        # otherwise np.loadtxt returns 1-D and the column indexing fails
        pos_array = np.loadtxt(pos_file, skiprows=skiprows, ndmin=2)
        t_list.append(pos_array[:, t_col])
        y_list.append(pos_array[:, y_col])

    if not t_list:
        # np.vstack([]) would raise; report "no tables" as empty arrays instead
        return np.empty((0, 0)), np.empty((0, 0))

    t = np.vstack(t_list)
    y = np.vstack(y_list)
    return t, y