import dolfin as dl
import numpy as np
import random
import pandas as pd
import sys
import os
import timeit
import cpp_codes
import cceua
import forward_modeling
from openpyxl import Workbook


# The helper modules expose same-named entry points; rebind the names to the
# callables themselves so the rest of the script can call them directly.
cceua, forward_model = cceua.cceua, forward_modeling.forward_model

# Command-line arguments (all positional):
#   1 run_name         any string; names the output folder and the report file
#   2 dataset_index    row of ../data/soil_properties.xlsx to use (0 to 6)
#   3 hydraulic_model  "VGM" or "PDI"
#   4 scale_index      0 to 3; first index of the measured theta arrays (select 1 for dz = 0.25 cm)
#   5 data_type        "synthetic_noise", "synthetic_true", "point" or "horizontal"
#   6 weight_parameter weight of the flux misfit relative to the theta misfit in the cost
if len(sys.argv) < 7:
    sys.exit(f"usage: {sys.argv[0]} run_name dataset_index hydraulic_model "
             "scale_index data_type weight_parameter")

run_name = sys.argv[1]
dataset_index = int(sys.argv[2])
hydraulic_model = sys.argv[3]
scale_index = int(sys.argv[4])
data_type = sys.argv[5]
weight_parameter = float(sys.argv[6])

# Every later branch handles only these values; fail fast with a clear message
# instead of an obscure error (e.g. a NameError on an undefined bound) much later.
if hydraulic_model not in ("VGM", "PDI"):
    sys.exit(f"hydraulic_model must be 'VGM' or 'PDI', got {hydraulic_model!r}")
if data_type not in ("synthetic_noise", "synthetic_true", "point", "horizontal"):
    sys.exit("data_type must be 'synthetic_noise', 'synthetic_true', 'point' or "
             f"'horizontal', got {data_type!r}")

# Hydraulic parameters fitted to the measured water retention curve (WRC), and the
# per-soil properties (units as given in the spreadsheet column names).
df_WRC = pd.read_csv("../data/WRC_fitted_summary.csv")
soil_data = pd.read_excel("../data/soil_properties.xlsx")

dataset_list = soil_data["Soil"].tolist()
K_s_list = soil_data["K_s_cm_s-1"].tolist()
# NOTE: "time_inverval_min" (sic) must match the spreadsheet header exactly.
time_interval_list = soil_data["time_inverval_min"].tolist()
bulk_density_list = soil_data["Established_Bulk_Density_g_cm-3"].tolist()
bulk_density_WRC_list = soil_data["bulk_density_WRC_measurement_g_cm-3"].tolist()
particle_density_list = soil_data["Particle_Density_g_cm-3"].tolist()

# Select the soil given on the command line.
dataset = dataset_list[dataset_index]
K_s = K_s_list[dataset_index]
time_interval = time_interval_list[dataset_index]
bulk_density = bulk_density_list[dataset_index]
bulk_density_WRC = bulk_density_WRC_list[dataset_index]
particle_density = particle_density_list[dataset_index]

# Porosity from the established bulk density, and from the bulk density of the
# sample used for the WRC measurement.
porosity = 1.0 - bulk_density/particle_density
porosity_WRC = 1.0 - bulk_density_WRC/particle_density

print(f"Run_name is {run_name}, dataset is {dataset}, hydraulic_model is {hydraulic_model}.")
print(f"scale_index is {scale_index}, data_type is {data_type}.")

# WRC fit for this soil and hydraulic model; its first row seeds the initial guess.
df_WRC_2 = df_WRC[(df_WRC["dataset"] == dataset) & (df_WRC["hydraulic_model"] == hydraulic_model)]
if df_WRC_2.empty:
    raise ValueError(
        f"No row in ../data/WRC_fitted_summary.csv for dataset {dataset!r} "
        f"and hydraulic_model {hydraulic_model!r}."
    )


# Simulated duration T_max and the search bounds for theta_s and theta_r,
# chosen by data type and hydraulic model.
if data_type in ("synthetic_noise", "synthetic_true"):
    T_max = 220
    theta_s_lower_bound, theta_s_upper_bound = 0.34, 0.44
    if hydraulic_model == "VGM":
        theta_r_lower_bound, theta_r_upper_bound = 0.0, 0.1
    elif hydraulic_model == "PDI":
        # for PDI, theta_r may range up to the smallest admissible theta_s
        theta_r_lower_bound, theta_r_upper_bound = 0.0, theta_s_lower_bound
elif data_type in ("point", "horizontal"):
    # Measured data: "point" and "horizontal" share the same settings, with one
    # entry per soil selected by dataset_index.
    T_max_list = [10.5, 16.0, 220, 360, 200, 110, 750]
    T_max = T_max_list[dataset_index]
    # the theta_s window changes for soils and scale (width 0.10 except for AZ15)
    theta_s_lower_bound_list = [0.24, 0.31, 0.32, 0.35, 0.33, 0.40, 0.35]
    theta_s_upper_bound_list = [0.34, 0.41, 0.42, 0.45, 0.43, 0.65, 0.45]
    theta_s_lower_bound = theta_s_lower_bound_list[dataset_index]
    theta_s_upper_bound = theta_s_upper_bound_list[dataset_index]

    if hydraulic_model == "VGM":
        theta_r_lower_bound, theta_r_upper_bound = 0.0, 0.05
    elif hydraulic_model == "PDI":
        # alternative lower bound that was tried:
        #   0.5*df_WRC_2.iloc[:, 3].values[0]*(porosity/porosity_WRC)
        theta_r_lower_bound, theta_r_upper_bound = 0.0, theta_s_lower_bound


# Depths passed to forward_model, one per theta column kept below (eight in total);
# presumably cm, negative below the surface — TODO confirm in forward_modeling.
depth_list = [-8.5, -7.5, -6.5, -5.5, -4.5, -3.5, -2.5, -1.5]


# Load the observations the forward model is fitted to.
experimental_dir = "../data/experimental_data"
if data_type in ("synthetic_true", "synthetic_noise"):
    # Both synthetic sets share one time file; flux/theta differ by the file prefix.
    synthetic_tag = "true" if data_type == "synthetic_true" else "noisy"
    flux_data = np.load(f"{experimental_dir}/synthetic/{synthetic_tag}_flux_data_{hydraulic_model}.npy")
    theta_data = np.load(f"{experimental_dir}/synthetic/{synthetic_tag}_theta_data_{hydraulic_model}.npy")
    time_data = np.load(f"{experimental_dir}/synthetic/time_data_{hydraulic_model}.npy")
elif data_type in ("point", "horizontal"):
    # Keep the records up to T_max (inclusive) and, for the chosen scale, the first
    # eight columns (matching the eight entries of depth_list).
    n_records = int(T_max/time_interval) + 1
    theta_file = {"point": "theta_SWIR_point.npy", "horizontal": "theta_SWIR_1D.npy"}[data_type]
    theta_data = np.load(f"{experimental_dir}/{dataset}/{theta_file}")[scale_index, 0:n_records, 0:8]
    flux_data = np.load(f"{experimental_dir}/{dataset}/flux_data.npy")[0:n_records]

# The cost function compares against a flat vector of theta values.
theta_data = theta_data.flatten()


## Set up the initial guess and the search bounds for the optimization

# Parameters shared by both models: theta_r, theta_s, alpha, n, K_s; PDI adds K_sf.
# alpha, n, K_s and K_sf are searched in transformed form:
#   alpha_t = log10(1/alpha), n_t = log10(n), K_s_t = log10(K_s), K_sf_t = log10(K_sf)
initial_parameters = []
if hydraulic_model in ("VGM", "PDI"):
    initial_parameters += [
        df_WRC_2.iloc[:, 3].values[0],  # theta_r from the WRC fit
        df_WRC_2.iloc[:, 4].values[0],  # theta_s from the WRC fit
        np.log10(1/(2*df_WRC_2.iloc[:, 5].values[0])),  # alpha: half of the WRC-fitted alpha
        np.log10(df_WRC_2.iloc[:, 6].values[0]),  # n from the WRC fit
        np.log10(K_s/10),  # K_s divided by 10 because K_s is very high for the experiment
    ]
    lower_bound = [theta_r_lower_bound, theta_s_lower_bound, np.log10(1/1.0), np.log10(1.0001), np.log10(K_s/(10**4))]
    upper_bound = [theta_r_upper_bound, theta_s_upper_bound, np.log10(1/0.00001), np.log10(7.5), np.log10(K_s*(10**2))]

    if hydraulic_model == "PDI":
        initial_parameters.append(np.log10(K_s/10**4))  # K_sf
        lower_bound.append(np.log10(K_s/10**8))
        upper_bound.append(np.log10(K_s*10))

    x_0 = np.array(initial_parameters)
    lower_bound = np.array(lower_bound)
    upper_bound = np.array(upper_bound)


# Excel report: row 1 holds the column titles and objective_function writes one row
# per evaluation below it. Columns W onward also store the search bounds in physical
# units (row 2 = lower bound, row 3 = upper bound).
wb = Workbook()
ws = wb.active

report_columns = [
    ("A", "run_name"),
    ("B", "dataset"),  # dataset name
    ("C", "hydraulic_model"),
    ("D", "data_type"),
    ("E", "scale_index"),
    ("F", "ID"),
    ("G", "weight_parameter"),
    # other info
    ("H", "T_max"),
    ("I", "time_interval"),
    ("J", "simulation_time"),
    ("K", "line_search"),
    ("L", "Newton"),
    ("M", "water_reach"),
    ("N", "theta_error"),
    ("O", "flux_error"),
    ("P", "cost"),
    # parameters
    ("Q", "theta_r"),
    ("R", "theta_s"),
    ("S", "alpha"),
    ("T", "n"),
    ("U", "K_s"),
]
if hydraulic_model == "PDI":
    report_columns.append(("V", "K_sf"))
for column_letter, column_title in report_columns:
    ws[f"{column_letter}1"] = column_title

# alpha is searched as log10(1/alpha), so its physical lower bound comes from the
# transformed upper bound and vice versa; the other log10 bounds map directly.
bound_columns = [
    ("W", "theta_r_bounds", lower_bound[0], upper_bound[0]),
    ("X", "theta_s_bounds", lower_bound[1], upper_bound[1]),
    ("Y", "alpha_bounds", 1.0/(10**upper_bound[2]), 1.0/(10**lower_bound[2])),
    ("Z", "n_bounds", 10**lower_bound[3], 10**upper_bound[3]),
    ("AA", "K_s_bounds", 10**lower_bound[4], 10**upper_bound[4]),
]
if hydraulic_model == "PDI":
    # NOTE: the K_sf bounds are labelled "Ksnc_bounds" in the report.
    bound_columns.append(("AB", "Ksnc_bounds", 10**lower_bound[5], 10**upper_bound[5]))
for column_letter, column_title, low_value, high_value in bound_columns:
    ws[f"{column_letter}1"] = column_title
    ws[f"{column_letter}2"] = low_value
    ws[f"{column_letter}3"] = high_value



# Output folder for the Excel report saved by objective_function:
#   ../data/inverse_modeling/<hydraulic_model>/<dataset>/<run_name>
# makedirs also creates any missing parent folders (including
# ../data/inverse_modeling itself) and is a no-op when the folder already exists.
os.makedirs(f'../data/inverse_modeling/{hydraulic_model}/{dataset}/{run_name}', exist_ok=True)

def objective_function(parameters, ID):
    """Evaluate one SCE-UA candidate: run the forward model, score it, and log it.

    parameters : array-like
        Candidate in the transformed search space: theta_r, theta_s,
        log10(1/alpha), log10(n), log10(K_s) and, for PDI only, log10(K_sf).
        Must support ``.copy()``.
    ID : int
        Number of evaluations recorded so far; this evaluation is numbered ID + 1.

    Returns ``(cost, evaluation number)``. The cost is
    ``theta_error + weight_parameter * flux_error``, where each error is the mean of
    the squared residuals divided by the median of the corresponding observations.

    Side effects: writes this evaluation to row ID + 2 of the report worksheet and
    saves the whole workbook after every call.
    """
    evaluation_id = ID + 1

    # Map the search-space vector back to physical values:
    # alpha = 1/10**alpha_t, and n, K_s (and K_sf for PDI) = 10**x_t.
    log10_positions = {"VGM": (3, 4), "PDI": (3, 4, 5)}.get(hydraulic_model, ())
    hydraulic_parameters = parameters.copy()
    if log10_positions:
        hydraulic_parameters[2] = 1.0/(10**parameters[2])
    for position in log10_positions:
        hydraulic_parameters[position] = 10**parameters[position]

    print("ID is ", evaluation_id, "; Hydraulic parameters are ",
          ["{0:0.32f}".format(value) for value in hydraulic_parameters])

    tic = timeit.default_timer()
    estimated_theta, estimated_flux = forward_model(hydraulic_parameters, hydraulic_model, T_max, time_interval, depth_list)
    simulation_time = timeit.default_timer() - tic

    # Measurement-error standard deviations are unknown, so each residual is
    # normalized by the median of its observations instead.
    theta_error = (1.0/len(theta_data))*np.sum(((estimated_theta - theta_data)/np.median(theta_data))**2)
    flux_error = (1.0/len(flux_data))*np.sum(((estimated_flux - flux_data)/np.median(flux_data))**2)
    # Optional (disabled) regularization toward the lab WRC-fitted n (n is index 3):
    #   cost += (np.log10(hydraulic_parameters[3]) - np.log10(df_WRC_2.iloc[:, 6].values[0]))**2
    cost = theta_error + weight_parameter*flux_error

    # NOTE(review): forward_model presumably fills theta with 1000 / 2000 / 3000 to
    # flag a line-search failure, a Newton failure, or "water reached" — confirm in
    # forward_modeling; the checks below compare the mean exactly against these values.
    mean_theta = np.mean(estimated_theta)
    print("Mean of theta is ", mean_theta)

    row = evaluation_id + 1  # row 1 holds the headers
    record = {
        "A": run_name,
        "B": dataset,  # dataset name
        "C": hydraulic_model,
        "D": data_type,
        "E": scale_index,
        "F": evaluation_id,
        "G": weight_parameter,
        # other info
        "H": T_max,
        "I": time_interval,
        "J": simulation_time,
        "K": "Failed" if mean_theta == 1000.0 else "Success",
        "L": "Failed" if mean_theta == 2000.0 else "Success",
        "M": "Reached" if mean_theta == 3000.0 else "Not_Reached",
        "N": theta_error,
        "O": flux_error,
        "P": cost,
        # physical parameter values
        "Q": hydraulic_parameters[0],
        "R": hydraulic_parameters[1],
        "S": hydraulic_parameters[2],
        "T": hydraulic_parameters[3],
        "U": hydraulic_parameters[4],
    }
    if hydraulic_model == "PDI":
        record["V"] = hydraulic_parameters[5]
    for column, value in record.items():
        ws[f"{column}{row}"] = value

    wb.save(f'../data/inverse_modeling/{hydraulic_model}/{dataset}/{run_name}/{run_name}_global_report.xlsx')

    return cost, evaluation_id


# SCE-UA (Shuffled Complex Evolution) parameters (see Duan, 1994).
# The code was based on the Python spotpy package with a few modifications
# (https://github.com/thouska/spotpy).

# Step 0: Define parameters
random_seed = 0
flag_0 = 0  # if 1, the initial guess x_0 replaces the first random sample
p = 4  # number of complexes; depends on the difficulty of the optimization problem
n = x_0.shape[0]  # number of parameters to be optimized
m = 2 * n + 1  # number of points in each complex
q = n + 1  # number of points in a subcomplex (the subcomplex is a simplex)
alpha = 1  # number of consecutive offspring generated by each subcomplex (not used)
beta = 2 * n + 1  # number of evolution steps taken by each complex
p_min = p  # minimum number of complexes required in the population (not used)
s = p * m  # total number of sample points in the population

bound = upper_bound - lower_bound  # width of the search box for each parameter

# Convergence criteria
max_evaluations = 10000  # maximum number of objective-function evaluations
stop_evolution = 10  # evolution loops to run before the objective-function criterion is evaluated
epsilon = 0.0001  # threshold on the normalized geometric range of the population
epsilon2 = 0.1  # threshold (percent) on the change of the best f over stop_evolution loops

np.random.seed(random_seed)

# Step 1: Generate the initial sample uniformly inside [lower_bound, upper_bound]
x = lower_bound + np.random.rand(s, n) * bound

if flag_0 == 1:
    x[0, :] = x_0

# Step 2: Evaluate and rank points
number_call = 0
ID = 0
f_x = np.zeros((s, 1))
for i in range(s):
    # evaluated one point at a time (not vectorized); objective_function returns the updated ID
    f_x[i, :], ID = objective_function(x[i, :], ID)
    number_call += 1
    print(f"{i + 1}th sample was computed.")

# population: parameters in columns 0..n-1, f(x) in the last column (index n)
X = np.hstack((x, f_x))
# sort the population in order of increasing function values
X = X[np.argsort(X[:, n])]

x_best = X[0, 0:-1]
f_best = X[0, n]

x_worst = X[-1, 0:-1]
f_worst = X[-1, n]

# standard deviation of each parameter (diagnostic only; not used below)
x_n_sigma = np.std(X[:, 0:-1], axis=0)
# range of each parameter relative to the width of its bounds
range_n = (np.amax(X[:, 0:-1], axis=0) - np.amin(X[:, 0:-1], axis=0))/bound
# geometric mean of the relative ranges
geo_range_n = np.exp(np.mean(np.log(range_n)))

print("The initial loop: 0")
print(f"The initial guess of x is {x_0}.")
print(f"The best f is {f_best}.")
print(f"The best x is {x_best}.")
print(f"The worst f is {f_worst}.")
print(f"The worst x is {x_worst}.")

# Check convergence (see Torczon_1989 for other convergence criteria). The
# comparisons mirror the evolution loop, which runs only while
# number_call < max_evaluations and geo_range_n > epsilon.
if number_call >= max_evaluations:
    print("Optimization search terminated because the limit")
    print(f"on the maximum number of function evaluations {max_evaluations}.")
    print(f"The search was stopped at trial number {number_call} of the initial loop.")

if geo_range_n <= epsilon:
    print("The population has converged to a prescribed small parameter space.")

# Steps 3-5: evolve and shuffle the complexes until a stopping criterion is met:
#   - number_call reaches max_evaluations,
#   - the normalized geometric range of the population drops to epsilon or below,
#   - the best f changed by epsilon2 percent or less over the last stop_evolution loops.
number_loop = 0
criteria = []  # best f after each evolution loop
criteria_change = 1e+5  # placeholder until stop_evolution loops have run
proceed = True  # extra stop switch; never set to False in this script

while number_call < max_evaluations and geo_range_n > epsilon and criteria_change > epsilon2 and proceed:
    number_loop += 1

    for i in range(p):  # loop for each complex (m points for each complex)
        # Step 3: Partition into complexes
        # complex i takes every p-th point of the sorted population, starting at rank i
        x_complex = X[i::p, 0:n]
        f_complex = X[i::p, n][:, None]

        # Step 4: Evolve each complex
        # evolve the subpopulation (complex) using the CCE algorithm (beta times)
        for loop in range(beta):

            # Select q points to define a simplex by sampling the complex according to a
            # linear probability distribution (better-ranked points are more likely)
            indexes_sampled = np.array([0] * q)
            indexes_sampled[0] = 0  # always pick the best one
            for j in range(1, q):

                for k in range(1000):
                    # the equation below is a root of a quadratic equation with respect
                    # to the index, given a uniformly drawn cumulative probability
                    index_sampled = int(np.floor(
                                        m + 0.5 - np.sqrt((m + 0.5)**2 - m * (m + 1) * np.random.random())))
                    # redraw if this point has already been chosen
                    flag = (indexes_sampled[0:j] == index_sampled).sum()
                    if flag == 0:
                        break
                indexes_sampled[j] = index_sampled

            indexes_sampled.sort()

            # construct the simplex
            x_sampled = x_complex[indexes_sampled, :]
            f_sampled = f_complex[indexes_sampled, :]

            # run the modified Nelder-Mead simplex method (steps II to VI of the CCE algorithm)
            x_sampled_new, f_sampled_new, number_call, ID = cceua(x_sampled, f_sampled, objective_function, lower_bound, upper_bound, number_call, ID)

            # replace the worst point (the last one, since the complex is kept sorted)
            x_sampled[-1, :] = x_sampled_new
            f_sampled[-1, :] = f_sampled_new

            # put the updated simplex back into the complex (sub-population)
            x_complex[indexes_sampled, :] = x_sampled
            f_complex[indexes_sampled, :] = f_sampled

            # sort the complex by function value
            idx = np.argsort(f_complex.flatten())
            x_complex = x_complex[idx, :]

            f_complex = np.sort(f_complex, axis=0)

        X[i::p, 0:n] = x_complex
        X[i::p, n] = f_complex.flatten()

    # Step 5: shuffle the complexes (re-rank the whole population)
    X = X[np.argsort(X[:, n])]
    x_evolved = X[:, 0:n]
    f_evolved = X[:, n]

    # record the best and worst points
    x_best = X[0, 0:-1]
    f_best = X[0, n]

    x_worst = X[-1, 0:-1]
    f_worst = X[-1, n]

    # standard deviation of each parameter (diagnostic only)
    x_n_sigma = np.std(X[:, 0:-1], axis=0)
    # range of each parameter relative to the width of its bounds
    range_n = (np.amax(X[:, 0:-1], axis=0) - np.amin(X[:, 0:-1], axis=0))/bound
    # geometric mean of the relative ranges
    geo_range_n = np.exp(np.mean(np.log(range_n)))

    criteria.append(f_best)

    print(f"Evolution loop: {number_loop}; Function call: {number_call}")
    print(f"The best f is {f_best}.")
    print(f"The best x is {x_best}.")
    print(f"The worst f is {f_worst}.")
    print(f"The worst x is {x_worst}.")

    # Check convergence (see Torczon_1989 for other convergence criteria); the
    # comparisons match the while-condition so a message is printed whenever it stops.
    if number_call >= max_evaluations:
        print("Optimization search terminated because the limit")
        print(f"on the maximum number of function evaluations {max_evaluations}.")

    if geo_range_n <= epsilon:
        print("The population has converged to a prescribed small parameter space.")

    if number_loop >= stop_evolution:
        print("Objective function convergence criteria is now being updated.")
        # percent change of the best f over the last stop_evolution loops
        absolute_change = np.abs(criteria[number_loop - 1] - criteria[number_loop - stop_evolution])*100
        denominator = np.mean(np.abs(criteria[(number_loop - stop_evolution):number_loop]))
        if denominator == 0.0:
            # best f has been exactly 0 for the whole window: treat as converged
            # instead of dividing by zero (which would give nan)
            criteria_change = 0.0
        else:
            criteria_change = absolute_change / denominator

        print("Updated convergence criteria: %f" % criteria_change)
        if criteria_change <= epsilon2:
            print("The best point has improved in last %d loops by less than the user-specified threshold %f"
                  % (stop_evolution, epsilon2))
            print("Convergence has acheved based on objective function criteria!")



# end of the outer loops
print("Search was stopped at trial number: %d" % number_call)
print("Normalized geometric range: %f" % geo_range_n)
print("The best point has improved in the last %d loops by %f percent" % (stop_evolution, criteria_change))
