# -*- coding: utf-8 -*-
"""
This file is part of PyFrac.

Created by Brice Lecampion -
Copyright (c) "ECOLE POLYTECHNIQUE FEDERALE DE LAUSANNE, Switzerland, Geo-Energy Laboratory", 2016-2022.
All rights reserved. See the LICENSE.TXT file for more details.
"""

# External imports
import copy
import numpy as np
import os

# local imports
from mesh_obj.mesh import CartesianMesh
from solid.solid_prop import MaterialProperties
from fluid.fluid_prop import FluidProperties
from properties import InjectionProperties, SimulationProperties
from fracture_obj.fracture import Fracture
from controller import Controller
from fracture_obj.fracture_initialization import Geometry, InitializationParameters
from utilities.utility import setup_logging_to_console
from utilities.postprocess_fracture import load_fractures

# Script switches read by the sections further down in this file.
run = True              # execute the simulation section
new_run = True          # True: build a fresh Fracture; False: restart from fractures loaded from disk
post_process = False    # execute the post-processing section
export_results = True   # within post-processing: write the JSON exports
plotting = False        # within post-processing: show the matplotlib figures

# setting up the verbosity level of the log at console
setup_logging_to_console(verbosity_level='debug')

# creating mesh
Lr=1 # 0.425 # size of domain (presumably the half-length of the square domain -- TODO confirm against CartesianMesh)

# Ncells is chosen such that 2 * Lr / Ncells is approximately alpha * estimated_Rtr.
# NOTE(review): estimated_Rtr looks like an a-priori estimate of a fracture radius -- confirm where it comes from.
estimated_Rtr = 0.756285
alpha = 2.5e-2

Ncells = round(2 * Lr / (alpha * estimated_Rtr))

Mesh = CartesianMesh(Lr, Lr, Ncells, Ncells)

# dimensionless simulation
# as per Peirce 2022
# we set all parameters to 1 such that t_m\tilde{m} = 1
#
t_s=10   # :: omega = ts/t_m\tilde{m}; as t_m\tilde{m} = 1, this is equivalent to setting omega
phi_v=1
# K' is derived from the target values of phi_v and t_s:
#   K' = (t_s * phi_v**(-65/9))**(1/26)
K_prime=(t_s*(phi_v**(-65/9.)))**(1./26.)  # we set Kprime for the target value of phi_v and t_s

# fluid properties
viscosity = 1. / 12  # chosen such that mu' = 12 * viscosity = 1
Fluid = FluidProperties(viscosity=viscosity)
# injection parameters
Q0 = 1.0  # injection rate

# Two-column rate history: columns (0.0, Q0) and (t_s, 0.).
# Presumably row 0 holds times and row 1 the rates, i.e. inject at Q0 from t = 0 and
# stop at t = t_s (the post-processing section reads injectionRate[0, 1] as t_s) -- TODO confirm.
Rate_history=np.asarray([[0.0, t_s],[Q0, 0.]])
Injection = InjectionProperties(Rate_history, Mesh)

# solid properties
Eprime = 1   # plane strain modulus
Cl = 0.5       # Carter's leak-off coefficient, C' = 2 * Cl = 1
K1c =  K_prime/np.sqrt(32./np.pi)         # toughness such that K' = sqrt(32 / pi) * K1c
# the minimum width for the contact algorithm, set as the fraction beta of estimated_Wtr
# NOTE(review): estimated_Wtr looks like an a-priori estimate of a fracture width -- confirm where it comes from.
estimated_Wtr = 0.7571
beta = 1e-6
min_width = beta * estimated_Wtr


# material properties
Solid = MaterialProperties(Mesh,
                           Eprime,
                           K1c,
                           Carters_coef=Cl,
                           minimum_width=min_width)

# end time of the simulation
# let's make it very large and stop manually
t_end=14  # to be adjusted depending on omega and phi_c
# start time of the simulation, from the M vertex
# we need to ensure that we are still in the M limit
# t_start is the larger of two values (with mu' = 12 * viscosity):
#   (i)  min(0.05, 0.05 * E'^(13/2) * mu'^(5/2) * Q0^(3/2) / K'^9)
#   (ii) (8 * 2 * Lr / Ncells)^(9/4) * mu'^(1/4) / (E'^(1/4) * Q0^(1/4))
t_start = max([min([0.05, 0.05 * Eprime ** (13/2) * (12 * viscosity) ** (5/2) * Q0 ** (3/2) / K_prime ** 9]),
               (8 * 2 * Lr / Ncells) ** (9/4) * (12 * viscosity) ** (1/4) / (Eprime ** (1/4) * Q0 ** (1/4))])

# Output folder name encodes the number of cells, the domain size and the minimum width.
baseName = "MtoK_closure_w_1_phi_01-Nelts-"+str(Ncells)+"-Lr-"+str(Lr)+"-Wmin1-"+str(min_width) #"Radial_closure_om_1_phi_01-fine"
foldername= "./" + baseName + "/"

# simulation properties
simulProp = SimulationProperties()
simulProp.finalTime = t_end                           # the time at which the simulation stops
simulProp.saveTSJump, simulProp.plotTSJump = 1, 5     # save 1 and plot after every 5 time steps
simulProp.set_outputFolder(foldername)   # the disk address where the files are saved
simulProp.frontAdvancing = 'implicit'               # front advancing scheme set to 'implicit'
simulProp.plotVar = ['footprint', 'regime']
simulProp.plotVar = ['w']                           # NOTE: this assignment overrides the one on the previous line
simulProp.enableRemeshing = False
simulProp.force_time_schedule = False
# Requested solution times, in three consecutive ranges:
#   [t_start + 0.01, 10) with step 0.05, [10, 10.8) with step 0.005, [10.8, t_end) with step 0.005.
# NOTE(review): the hard-coded 10 equals t_s for the current settings -- adjust both together.
simulProp.set_solTimeSeries(np.concatenate((np.arange(t_start+0.01,10,0.05),np.arange(10,10.8,0.005),np.arange(10.8,t_end,0.005))))
simulProp.EHL_iter_lin_solve = False

if run:
    # Initial state: radial geometry, 'M' regime, evaluated at t_start.
    Fr_geometry = Geometry('radial')
    init_param = InitializationParameters(Fr_geometry, regime='M', time=t_start)

    if not new_run:
        # Restart: reload every saved fracture of a previous run, take over its
        # stored properties and continue from a copy of the last saved fracture.
        Fr_list, properties = load_fractures(
            address="/home/MtoK_closure_w_1_phi_01-Nelts-106-Lr-1-Wmin1-7.570999999999999e-07",
            step_size=1,
        )
        Solid, Fluid, Injection, simulProp = properties
        Fr = copy.deepcopy(Fr_list[-1])
    else:
        # Fresh run: build the fracture object from the initialization parameters.
        Fr = Fracture(Mesh, init_param, Solid, Fluid, Injection, simulProp)

    # Hand everything to the controller and run the simulation.
    controller = Controller(Fr, Solid, Fluid, Injection, simulProp)
    controller.run()


####################
# Post_Processing the results #
####################
if post_process:
    # Post-processing of a saved run: loads the stored fractures, derives time
    # series (efficiency, radius, inlet pressure/width, closure radius) and
    # extracts horizontal/vertical slices of width and net pressure.

    myJsonName="./MtoK_closure_w_1_phi_01-Nelts-106-Lr-1-Wmin1-7.570999999999999e-07/simulation__2023-10-11__16_44_14/Exported-results.json"
    # The star imports are kept as they are: the remainder of the post-processing
    # section uses names they provide (e.g. get_fracture_variable,
    # plot_fracture_list_slice, PlotProperties).
    from utilities.visualization import *
    from utilities.postprocess_fracture import *

    # Times at which full footprints are exported (closest saved step is used).
    footprint_times = [1.3, 1.45, 1.6, 1.75, 1.8]

    # # loading simulation results
    # Fr_list, properties = load_fractures(foldername)
    # time_srs = get_fracture_variable(Fr_list,'time')

    # `properties` unpacks as (Solid, Fluid, Injection, simulProp), as done in the
    # restart branch of the run section.
    Fr_list, properties = load_fractures(address="./MtoK_closure_w_1_phi_01-Nelts-106-Lr-1-Wmin1-7.570999999999999e-07/",
                                         sim_name="simulation__2023-10-11__16_44_14")

    # Efficiency: fracture volume (tip cells weighted by their fill fraction plus
    # channel cells) divided by the injected volume, for every saved fracture.
    time_srs = []
    variable_list = []
    for fr_k in Fr_list:
        v_tot = fr_k.mesh.EltArea * (np.sum(fr_k.w[fr_k.EltTip] * fr_k.FillF) + np.sum(fr_k.w[fr_k.EltChannel]))
        variable_list.append(v_tot / fr_k.injectedVol)
        time_srs.append(fr_k.time)

    eta = [time_srs, variable_list]

    d_list = get_fracture_variable(Fr_list, variable='d_mean', return_time=False)

    # net pressure and opening at the first source element
    inlet_elt = properties[2].sourceElem[0]
    p_list = get_fracture_variable(Fr_list, variable='net pressure', return_time=False)
    p_inlet = [p_k[inlet_elt] for p_k in p_list]

    w_list = get_fracture_variable(Fr_list, variable='width', return_time=False)
    w_inlet = [w_k[inlet_elt] for w_k in w_list]

    # Index of the saved step at t_s (read from the stored rate history) and the
    # corresponding radius R_s. The closest saved time is used instead of an exact
    # floating-point match: this gives the same index whenever an exact match
    # exists, and no longer fails when the saved time differs by round-off.
    time_srs_a = np.asarray(time_srs)
    t_s_loaded = properties[2].injectionRate[0, 1]
    i_ts = int(np.argmin(np.abs(time_srs_a - t_s_loaded)))
    R_s = d_list[i_ts]
    R_a = np.asarray(d_list).max()
    R_list = np.asarray(d_list) / R_s

    # post-process of closure radius
    # injection point is at 0,0 here, so the distance of a closed cell to the
    # injection point is the norm of its centre coordinates
    min_rc = np.zeros(len(Fr_list))
    mean_rc = np.zeros(len(Fr_list))
    for k, test in enumerate(Fr_list):
        me = test.mesh
        c_c = me.CenterCoor[test.closed]  # centres of the closed cells
        r_c = np.asarray([np.linalg.norm(centre) for centre in c_c])
        if r_c.size > 0:
            # entries stay 0 for fractures without any closed cell
            min_rc[k] = r_c.min() #/ R_s
            # mean over the closed cells within a third of a cell diagonal of the minimum
            mean_rc[k] = np.mean(np.extract(r_c <= min_rc[k] + me.cellDiag / 3., r_c)) #/R_s

    # Slices of width ('w') and net pressure ('pn') along the horizontal and the
    # vertical line through the origin; all four calls share the same options and
    # the same `ext_pnts` array.
    ext_pnts = np.empty((2, 2), dtype=np.float64)
    slice_options = dict(projection='2D',
                         plot_cell_center=True,
                         extreme_points=ext_pnts,
                         export2Json=True)

    fracture_list_slice_w_h = plot_fracture_list_slice(Fr_list,
                                                       variable='w',
                                                       orientation='horizontal',
                                                       point1=[-Lr, 0.0],
                                                       point2=[Lr, 0.],
                                                       **slice_options)

    fracture_list_slice_w_v = plot_fracture_list_slice(Fr_list,
                                                       variable='w',
                                                       orientation='vertical',
                                                       point1=[0.0, -Lr],
                                                       point2=[0., Lr],
                                                       **slice_options)

    fracture_list_slice_p_h = plot_fracture_list_slice(Fr_list,
                                                       variable='pn',
                                                       orientation='horizontal',
                                                       point1=[-Lr, 0.0],
                                                       point2=[Lr, 0.],
                                                       **slice_options)

    fracture_list_slice_p_v = plot_fracture_list_slice(Fr_list,
                                                       variable='pn',
                                                       orientation='vertical',
                                                       point1=[0.0, -Lr],
                                                       point2=[0., Lr],
                                                       **slice_options)

    if export_results:
        from utilities.postprocess_fracture import append_to_json_file

        # ---- Time series and scalars -> myJsonName ---- #
        # The first write starts from a clean file.
        append_to_json_file(myJsonName, time_srs, 'append2keyASnewlist', key='Time',
                            delete_existing_filename=True)

        series_to_export = (('Efficiency', eta),
                            ('Radius', d_list),
                            ('p inlet', p_inlet),
                            ('w inlet', w_inlet),
                            ('Closure Radius min', min_rc.tolist()),
                            ('Closure Radius mean', mean_rc.tolist()))
        for json_key, series in series_to_export:
            append_to_json_file(myJsonName, series, 'append2keyASnewlist', key=json_key)

        append_to_json_file(myJsonName, properties[0].wc, 'append2keyASnewlist', key='wc')
        append_to_json_file(myJsonName, Fr_list[-1].mesh.hx, 'append2keyASnewlist', key='h')

        # ---- Slices of width and pressure -> myJsonName ---- #
        slices_to_export = (('intersectionHslice width', fracture_list_slice_w_h),
                            ('intersectionVslice width', fracture_list_slice_w_v),
                            ('intersectionHslice p', fracture_list_slice_p_h),
                            ('intersectionVslice p', fracture_list_slice_p_v))
        for json_key, slice_data in slices_to_export:
            append_to_json_file(myJsonName, {json_key: slice_data}, 'extend_dictionary')

        # ---- Pick the saved fractures closest to the requested footprint times ---- #
        n_footprints = len(footprint_times)
        idx_fractures = init_list_of_objects(n_footprints)      # indices into Fr_list
        Fr_list_fractures = init_list_of_objects(n_footprints)  # copies of the selected fractures

        saved_times = np.asarray(time_srs)
        for k in range(len(Fr_list_fractures)):
            nearest = int(np.abs(saved_times - footprint_times[k]).argmin())
            selected = copy.deepcopy(Fr_list[nearest])
            # a mesh stored as an integer is resolved as an index into Fr_list
            if isinstance(selected.mesh, int):
                selected.mesh = Fr_list[selected.mesh].mesh
            idx_fractures[k] = nearest
            Fr_list_fractures[k] = selected

        # exact times of the selected (closest) time steps
        time_srs_full = get_fracture_variable(Fr_list_fractures, variable='time')

        # full-field net pressure and width of the selected fractures
        pressure = get_fracture_variable(Fr_list_fractures, variable='pn')
        width = get_fracture_variable(Fr_list_fractures, variable='w')

        mesh_list = []                                                 # flattened mesh description per fracture
        closed_cells = init_list_of_objects(len(Fr_list_fractures))    # closed cells per fracture

        # footprints of the selected fractures
        fp_list = get_fracture_fp(Fr_list_fractures)

        for num, fr_sel in enumerate(Fr_list_fractures):
            msh = fr_sel.mesh
            # layout: [NumberOfElts, hx, hy, nx, ny, Connectivity..., VertexCoor..., CenterCoor...]
            export_mesh = np.array([msh.NumberOfElts])
            for chunk in ([msh.hx, msh.hy, msh.nx, msh.ny],
                          msh.Connectivity.flatten(),
                          msh.VertexCoor.flatten(),
                          msh.CenterCoor.flatten()):
                export_mesh = np.append(export_mesh, chunk)
            mesh_list.append(list(export_mesh))

            # [-1.] marks a fracture without closed cells
            if len(fr_sel.closed) == 0:
                closed_cells[num] = [-1.]
            else:
                closed_cells[num] = list(fr_sel.closed.astype(float))

        # ---- Full fractures -> separate JSON file ---- #
        # Note: please do not change the string in the end (Mathematica code is based on it!)
        full_Name = myJsonName.split(".json")[0] + '_fractures.json'

        # times first, starting from a clean file
        append_to_json_file(full_Name, time_srs_full,
                            'append2keyASnewlist',
                            key='time',
                            delete_existing_filename=True)

        # width
        append_to_json_file(full_Name, {'w': [w_k.tolist() for w_k in width]}, 'extend_dictionary')
        # net pressure
        append_to_json_file(full_Name, {'pn': [p_k.tolist() for p_k in pressure]}, 'extend_dictionary')
        # mesh data
        append_to_json_file(full_Name, {'mesh_info': mesh_list}, 'extend_dictionary')
        # closed cells
        append_to_json_file(full_Name, {'closed': closed_cells}, 'extend_dictionary')
        # footprints, stored as [x coordinates, y coordinates]
        append_to_json_file(full_Name,
                            {'fp': [[fp[:, 0].tolist(), fp[:, 1].tolist()] for fp in fp_list]},
                            'extend_dictionary')

        del width, pressure

    if plotting:
        # pyplot is imported before its first use; previously the first two
        # plt.show() calls only worked if a star import happened to provide `plt`.
        import matplotlib.pyplot as plt

        # plotting efficiency
        plot_prop = PlotProperties(  # graph_scaling='loglog',
                                   line_style='.')
        label = LabelProperties('efficiency')
        label.legend = 'fracturing efficiency'
        Fig_eff = plot_fracture_list(Fr_list,
                                     variable='efficiency',
                                     plot_prop=plot_prop,
                                     labels=label)
        plt.show(block=True)

        # plotting radius
        label = LabelProperties('d_mean')
        label.legend = 'radius'
        Fig_r = plot_fracture_list(Fr_list,
                                   variable='d_mean',
                                   plot_prop=plot_prop,
                                   labels=label)
        plt.show(block=True)

        # radius (scaled by R_s) over the whole run, and closure radius
        # (mean: line, min: dots) from the step at t_s onwards
        fig, ax = plt.subplots()
        ax.plot(time_srs, R_list, 'r')
        ax.plot(time_srs[i_ts:], mean_rc[i_ts:], 'b')
        ax.plot(time_srs[i_ts:], min_rc[i_ts:], '.b')
        plt.xlabel("Time (s)")
        plt.ylabel("R/R_s (m)")
        # ax.legend(["analytical solution","numerics"])
        plt.show()

        # net pressure at the inlet
        fig, ax = plt.subplots()
        ax.plot(time_srs[0:], p_inlet[0:], 'r')
        plt.xlabel("Time (s)")
        plt.ylabel("p(0,t) (Pa)")
        # ax.legend(["analytical solution","numerics"])
        plt.show()