# run the forward model
import dolfin as dl
import numpy as np
import timeit
import cpp_codes

# Re-export the C++ sources of the hydraulic functions (water content theta,
# its derivative dtheta, conductivity K and its derivative dK) for the VGM
# and PDI models; they are compiled inside forward_model via
# dl.compile_cpp_code.
(cpp_code_theta_VGM,
 cpp_code_dtheta_VGM,
 cpp_code_K_VGM,
 cpp_code_dK_VGM) = (cpp_codes.cpp_code_theta_VGM,
                     cpp_codes.cpp_code_dtheta_VGM,
                     cpp_codes.cpp_code_K_VGM,
                     cpp_codes.cpp_code_dK_VGM)

(cpp_code_theta_PDI,
 cpp_code_dtheta_PDI,
 cpp_code_K_PDI,
 cpp_code_dK_PDI) = (cpp_codes.cpp_code_theta_PDI,
                     cpp_codes.cpp_code_dtheta_PDI,
                     cpp_codes.cpp_code_K_PDI,
                     cpp_codes.cpp_code_dK_PDI)


# define the forward model
def forward_model(hydraulic_parameters, hydraulic_model, T_max, time_interval, depth_list,
                  *, verbose=False):
    """Solve the 1-D Richards equation in a soil column and sample theta.

    The column spans z in [-X, 0] cm (X = 10).  The bottom has a fixed
    pressure head (Dirichlet, psi = -0.01), and the top has a prescribed
    flux (currently zero).  Each time step's nonlinear system is solved by
    Newton's method with an Armijo line search (or by the modified Picard
    scheme if ``method`` is switched in the body).  Adaptive time stepping
    always lands exactly on the output times.

    Parameters
    ----------
    hydraulic_parameters : sequence of float
        ``(theta_r, theta_s, alpha, n, K_s)`` for ``"VGM"``;
        ``(theta_r, theta_s, alpha, n, K_s, K_sf)`` for ``"PDI"``.
        ``K_s`` and ``K_sf`` are multiplied by 60 (per-second -> per-minute).
    hydraulic_model : {"VGM", "PDI"}
        Which set of compiled hydraulic functions to use.
    T_max : float
        Final simulation time (minutes).
    time_interval : float
        Spacing of the output times.  ``T_max / time_interval`` is expected
        to be (close to) an integer; the output times are
        ``linspace(0, T_max, N_t + 1)``.
    depth_list : iterable of float
        Depths (cm, in [-10, 0]) at which theta is reported.  Each must
        coincide with a mesh node (node spacing 10/400 = 0.025 cm).
    verbose : bool, optional
        Print per-step solver statistics and the total wall time.

    Returns
    -------
    (estimated_theta, estimated_flux) : tuple of numpy.ndarray
        ``estimated_theta`` holds the water content at each output time and
        requested depth, flattened time-major (length
        ``(N_t + 1) * len(depth_list)``).
        ``estimated_flux`` is the change in column water storage (integral
        of theta over depth, cm) relative to t = 0 at each output time.
        On solver failure a sentinel pair is returned instead:
        ``(1000, 1000)`` if the line search failed, ``(2000, 2000)`` if
        there were too many nonlinear iterations, or ``(3000, 3000)`` if
        water reached the top of the column.

    Raises
    ------
    ValueError
        If ``hydraulic_model`` is unknown, the time settings give no output
        interval, or a requested depth is not a mesh node.
    """
    # ---- spatial discretisation ------------------------------------------
    N_s = 400  # number of cells -> N_s + 1 = 401 nodes
    X = 10.0  # soil depth in cm
    mesh = dl.IntervalMesh(N_s, -X, 0)
    degree = 1  # linear basis functions
    Vh = dl.FunctionSpace(mesh, "Lagrange", degree)

    # ---- temporal discretisation (output times) --------------------------
    if T_max <= 0 or time_interval <= 0:
        raise ValueError("T_max and time_interval must be positive.")
    T = T_max  # minutes
    # The epsilon absorbs floating-point round-off in the division, e.g.
    # int(0.3 / 0.1) == 2 although 0.3 spans three intervals of 0.1.
    N_t = int(T_max / time_interval + 1e-9)  # the number of temporal intervals
    if N_t < 1:
        raise ValueError("T_max must span at least one time_interval.")
    mesh_m = dl.IntervalMesh(N_t, 0, T)
    Vm = dl.FunctionSpace(mesh_m, 'Lagrange', 1)

    simulation_times = np.linspace(0.0, T, num=N_t + 1)

    # ---- initial condition -----------------------------------------------
    psi_ini_value = -10**6.8  # cm
    psi_ini = dl.Expression('psi_ini', degree=1, psi_ini=psi_ini_value)
    psi_ini = dl.interpolate(psi_ini, Vh)

    # ---- hydraulic models (compiled C++ expressions of psi) --------------
    if hydraulic_model == "VGM":
        cpp_sources = (cpp_code_theta_VGM, cpp_code_dtheta_VGM,
                       cpp_code_K_VGM, cpp_code_dK_VGM)
    elif hydraulic_model == "PDI":
        cpp_sources = (cpp_code_theta_PDI, cpp_code_dtheta_PDI,
                       cpp_code_K_PDI, cpp_code_dK_PDI)
    else:
        raise ValueError(
            f"Unknown hydraulic_model {hydraulic_model!r}; expected 'VGM' or 'PDI'.")

    def make_model(cpp_code, quantity):
        """Compile ``cpp_code`` and wrap class ``<quantity>_<model>`` fed by psi_ini."""
        module = dl.compile_cpp_code(cpp_code)
        cpp_class = getattr(module, f"{quantity}_{hydraulic_model}")
        return dl.CompiledExpression(cpp_class(psi_ini.cpp_object()), degree=1)

    theta_model, dtheta_model, K_model, dK_model = (
        make_model(code, quantity)
        for code, quantity in zip(cpp_sources, ("theta", "dtheta", "K", "dK"))
    )
    hydraulic_models = (theta_model, dtheta_model, K_model, dK_model)

    def set_parameters(model):
        """Copy the hydraulic parameters onto one compiled expression."""
        model.theta_r = hydraulic_parameters[0]
        model.theta_s = hydraulic_parameters[1]
        model.alpha = hydraulic_parameters[2]
        model.n = hydraulic_parameters[3]
        model.K_s = hydraulic_parameters[4]*60  # change the unit from seconds to min
        if hydraulic_model == "PDI":
            model.K_sf = hydraulic_parameters[5]*60  # change the unit from seconds to min
            model.slope = -1.5
        model.tau = 0.5

    def input_psi(psi):
        """Make every hydraulic model evaluate at the pressure head ``psi``."""
        for hm in hydraulic_models:
            hm.u = psi

    for hm in hydraulic_models:
        set_parameters(hm)

    # Gauss quadrature with 1 point
    dx_1 = dl.dx(metadata={"quadrature_degree": 1})

    def boundary_bottom(x, on_boundary):
        tol = 1E-14
        return abs(x[0] + X) < tol

    # lower boundary condition (fixed pressure head), imposed on the initial state
    psi_lb = dl.Constant(-0.01)
    bc_lb = dl.DirichletBC(Vh, psi_lb, boundary_bottom)
    bc_lb.apply(psi_ini.vector())

    # homogeneous BC for the Newton / modified Picard increment
    bc_dpsi = dl.DirichletBC(Vh, dl.Constant(0.0), boundary_bottom)

    # flux at the top boundary as a function of time (currently zero)
    m = dl.interpolate(dl.Expression('0.0', T=T, degree=5), Vm)

    e = dl.Constant((1, ))  # one-component vector for gravitational term

    # test and trial functions
    psi_trial = dl.TrialFunction(Vh)
    psi_test = dl.TestFunction(Vh)
    dpsi_trial = dl.TrialFunction(Vh)

    M_form = psi_trial*psi_test*dl.dx  # mass matrix
    N_form = psi_test*dl.ds  # vector for the Neumann boundary condition
    G_form = dl.inner(psi_trial*e, dl.grad(psi_test))*dx_1  # gravitational flow

    G = dl.assemble(G_form)
    # ds covers both ends of the interval; zero the bottom (Dirichlet) entry so
    # the prescribed flux only enters through the top.
    N_ub = dl.assemble(N_form)
    bc_dpsi.apply(N_ub)

    # Workspace matrix (mass-matrix sparsity) that holds diag(M_lumped * dtheta).
    M_2 = dl.assemble(M_form)

    def lump_matrix(form):
        """Return the row-sum lumped (diagonal) version of a bilinear form."""
        mass_action_form = dl.action(form, dl.Constant(1))
        mass_lumped = dl.assemble(form)
        mass_lumped.zero()
        mass_lumped.set_diagonal(dl.assemble(mass_action_form))
        return mass_lumped

    M_lumped = lump_matrix(M_form)

    # ---- solver settings -------------------------------------------------
    tau_a = 1.0E-7  # absolute tol
    tau_r = 1.0E-7  # relative tol

    # Adaptive time stepping driven by the nonlinear iteration count.
    # NOTE(review): dt is not clamped to any [min_dt, max_dt] range; confirm
    # whether a bound (e.g. 1e-4 .. 5.0) was intended.
    ini_dt = T/1000
    lower_iteration = 5  # grow dt when fewer iterations than this were needed
    upper_iteration = 9  # shrink dt when more iterations than this were needed
    lower_multiplier = 1.3
    upper_multiplier = 0.7
    max_iteration = 100  # give up on a step after this many nonlinear iterations

    # Armijo line search settings
    armijo_alpha = 1.0E-4
    max_backtracks = 100

    method = "Newton"
    # method = "modified Picard"
    line_search = method == "Newton"

    # ---- state -----------------------------------------------------------
    theta = [dl.Function(Vh) for _ in range(N_t + 1)]  # theta at the output times

    psi_k = dl.Function(Vh)  # psi at the k-th nonlinear iteration
    theta_k = dl.Function(Vh)
    dtheta_k = dl.Function(Vh)
    K_k = dl.Function(Vh)
    dK_k = dl.Function(Vh)
    psi_k_prev = dl.Function(Vh)  # last accepted iterate, for the line search
    theta_prev = dl.Function(Vh)  # theta at the previous time step
    dpsi = dl.Function(Vh)  # nonlinear increment

    # The models still evaluate at psi_ini here (input_psi not yet called).
    psi_k.assign(psi_ini)
    theta[0].assign(dl.interpolate(theta_model, Vh))
    theta_k.assign(theta[0])
    dtheta_k.assign(dl.interpolate(dtheta_model, Vh))
    K_k.assign(dl.interpolate(K_model, Vh))
    dK_k.assign(dl.interpolate(dK_model, Vh))
    theta_prev.assign(theta[0])

    dt = dl.Constant(0)

    # The forms reference the Functions above, so re-assembling them picks up
    # the current iterate without rebuilding the UFL expressions.
    gravity_form = dt*dl.inner(K_k*e, dl.grad(psi_test))*dx_1  # gravitational flow
    diffusion_form = dt*dl.inner(K_k*dl.grad(psi_k), dl.grad(psi_test))*dx_1
    jac_K_form = dl.inner(K_k*dl.grad(dpsi_trial), dl.grad(psi_test))*dx_1
    jac_dK_form = dl.inner(dpsi_trial*dl.grad(psi_k), dl.grad(psi_test))*dx_1

    def residual(M_0, b_2):
        """Assemble F(psi_k) with the Dirichlet row zeroed; return (F, max|F|)."""
        M_1 = M_lumped*theta_k.vector()  # current water content
        b_1 = dl.assemble(gravity_form)
        F = dl.assemble(diffusion_form) + (M_1 - M_0) + b_1 - b_2
        bc_dpsi.apply(F)
        return F, np.max(np.abs(F.get_local()))

    # Water-content threshold at the top node (last vertex, assuming vertices
    # run from -X to 0) above which the run is rejected.
    theta_top_limit = theta[0].compute_vertex_values(mesh)[-1] + 0.1

    start_time = timeit.default_timer()

    time_index = 0
    dt_new = ini_dt
    dt_old = ini_dt
    t = 0

    # Step until every output time is reached.  A step is shortened when
    # needed so that it lands exactly on the next output time.
    while time_index < N_t:
        # Newton iteration with Armijo line search (Kelley 2018, newton_LU /
        # newton_armijo) or modified Picard (Kelley 1995 ch. 4; Scudeler 2016;
        # Lehmann 1998), both using a direct linear solve.
        iteration = 0
        total_line = 0

        dt_new = min(dt_new, simulation_times[time_index + 1] - t)
        t += dt_new
        dt.assign(dt_new)

        # flux at the top boundary (positive inward; rainfall is positive)
        q = dl.Constant(m(dl.Point(t)))

        M_0 = M_lumped*theta_prev.vector()  # previous time step component
        b_2 = dt*q*N_ub  # flux from the boundary

        # evaluate F at the initial iterate and set the stopping tolerance
        F, error_old = residual(M_0, b_2)
        tol = tau_r*error_old + tau_a

        while error_old > tol:
            if method == "Newton":
                # Full Jacobian: K-term plus dK/dpsi contributions of the
                # diffusive and gravitational fluxes (columns scaled by dK).
                H_2 = dl.assemble(jac_dK_form)
                G_2 = G.copy()
                dl.as_backend_type(H_2).mat().diagonalScale(R=dK_k.vector().vec())
                dl.as_backend_type(G_2).mat().diagonalScale(R=dK_k.vector().vec())
                H = dt*(dl.assemble(jac_K_form) + H_2 + G_2)
            else:
                # Modified Picard: only the K-term of the flux is linearised.
                H = dt*dl.assemble(jac_K_form)
            M_2.zero()
            M_2.set_diagonal(M_lumped*dtheta_k.vector())
            J = H + M_2

            # solve J dpsi = -F for the increment
            bc_dpsi.apply(J, F)
            dl.solve(J, dpsi.vector(), -F)

            # Armijo line search on the max-norm of the residual
            lam = 1.0
            descent = False
            no_backtrack = 0
            psi_k_prev.assign(psi_k)
            while not descent and no_backtrack < max_backtracks:
                psi_k.vector().axpy(lam, dpsi.vector())
                # input the updated psi into the soil hydraulic models
                input_psi(psi_k)

                # update the coefficient functions at the trial iterate
                theta_k.assign(dl.interpolate(theta_model, Vh))
                dtheta_k.assign(dl.interpolate(dtheta_model, Vh))
                K_k.assign(dl.interpolate(K_model, Vh))
                if method == "Newton":
                    dK_k.assign(dl.interpolate(dK_model, Vh))

                F, error_new = residual(M_0, b_2)

                # Without line search every step is accepted; previously this
                # path left the step flagged as failed.
                if not line_search or error_new < error_old - error_old*armijo_alpha*lam:
                    error_old = error_new
                    descent = True
                else:
                    no_backtrack += 1
                    lam *= 0.5
                    psi_k.assign(psi_k_prev)  # reset psi_k

            if not descent:
                print("The redisual did not decline after 100 line searches.")
                return 1000, 1000

            if iteration > max_iteration:
                print(f"Newton did not converge after {iteration} iterations.")
                return 2000, 2000

            total_line += no_backtrack
            iteration += 1

        if verbose:
            print("At t = ", f"{t:.4f}", ", the ", method, " iteration took ", iteration, "iterations with ", total_line, " Armijo line search")

        theta_prev.assign(theta_k)

        if abs(t - simulation_times[time_index + 1]) < 0.000001:
            theta[time_index + 1].assign(theta_k)
            time_index += 1
            dt_new = dt_old  # reset the time step shortened to hit the output time

        # check if the water reaches the top (contradicts the upper boundary condition)
        if theta_k.compute_vertex_values(mesh)[-1] > theta_top_limit:
            print("Water reached the top of the box")
            return 3000, 3000

        # update dt from the nonlinear iteration count
        dt_old = dt_new
        if iteration < lower_iteration:
            dt_new = dt_old*lower_multiplier
        elif iteration > upper_iteration:
            dt_new = dt_old*upper_multiplier
        else:
            dt_new = dt_old

    time_fwd = timeit.default_timer() - start_time
    if verbose:
        print("It took", time_fwd, " seconds to solve the forward problem.")

    # ---- post-processing -------------------------------------------------
    # rows: output times; columns: nodes in reversed vertex order (surface first)
    theta_FEniCS = np.array([th.compute_vertex_values(mesh)[::-1] for th in theta])
    z_FEniCS = np.flip(mesh.coordinates().flatten())

    index_list = []
    for depth in depth_list:
        idx = int(np.argmin(np.abs(z_FEniCS - depth)))
        # Nodes are 0.025 cm apart, so this tolerance only absorbs round-off.
        if abs(z_FEniCS[idx] - depth) > 1e-8:
            raise ValueError(f"Depth {depth} cm does not coincide with a mesh node.")
        index_list.append(idx)

    estimated_theta = theta_FEniCS[:, index_list].flatten()

    # water stored in the column at each output time (trapezoidal rule in z)
    dz = -np.diff(z_FEniCS)  # positive: z decreases along the columns
    water_mass = 0.5*(theta_FEniCS[:, :-1] + theta_FEniCS[:, 1:]) @ dz
    estimated_flux = water_mass - water_mass[0]

    return estimated_theta, estimated_flux
