import pickle
import os
import time
import numpy as np
import scipy.integrate
import scipy.sparse.linalg
from compute_probabilities import compute_leverage_scores, compute_frob_prob
from random_basis_generation import random_basis_generation, random_basis_generation_split
from euler_method import euler_method_dense

### This script runs many iterations of the randomized basis generation and saves the results.
### In each iteration it builds a reduced (random) basis from the solutions of several local (transfer-operator)
### problems with random initial conditions, whose starting/ending points in time are drawn from a probability
### distribution (e.g. uniform, squared norms, or leverage scores). It then constructs the reduced model, computes the
### reduced solution, and records its errors with respect to the high-fidelity FE solution.


# Parameters identifying the test problem: advection velocity, diffusion
# coefficient and number of time steps. Together they form the problem tag
# that appears in the input/output file names.
# (nt is re-derived from the loaded time grid later; here it only names the problem.)
advec_x = 0.3
diff = 0.01
nt = 500

test_problem = 'advec_x={}_diff={}_nt={}'.format(advec_x, diff, nt)

# Load test problem, necessary FE data and high-fidelity solution.
# The file is opened in a `with` block so the handle is closed right after loading.
# NOTE(review): pickle.load can execute arbitrary code -- only load result files
# produced locally by this project's FE pre-processing, never untrusted ones.
with open(f'results/FE_data_problem={test_problem}.pickle', 'rb') as fe_data_file:
    FE_results = pickle.load(fe_data_file)

# High-fidelity solution; presumably one column per point of grid_t (TODO confirm).
FE_solution = FE_results['FE_solution']
# Sparse FE matrices; the "_bc" suffix presumably means boundary conditions are
# already applied (inferred from the names only -- confirm against the FE script).
mass_bc = FE_results['mass_bc']
stiff_bc = FE_results['stiff_bc']
advec_bc = FE_results['advec_bc']
stiffs_coeff_in_time_bc = FE_results['stiffs_coeff_in_time_bc']
rhs_matrix = FE_results['rhs_matrix']
alpha_matrix = FE_results['alpha_matrix']
grid_t = FE_results['grid_t']
# Step size from the first interval: assumes a uniform time grid.
ht = grid_t[1] - grid_t[0]
T_start = grid_t[0]
T_finish = grid_t[-1]
# Number of time steps as stored in the data (overrides the configured value,
# which was only used to build the file name).
nt = len(grid_t) - 1
u_0 = FE_results['u0_discrete']

# Precomputations: sparse LU factorizations that are handed to the random basis
# generation. The coefficients are assumed not to vary in time, so the combined
# operator (stiffs_coeff_in_time_bc + advec_bc) is fixed.

# Left-hand-side matrix of one implicit-Euler step with step size ht,
#   mass_bc + ht * (stiffs_coeff_in_time_bc + advec_bc),
# and a solver for it:
matrix_left_temp = mass_bc + ht * (stiffs_coeff_in_time_bc + advec_bc)
lhs_solve = scipy.sparse.linalg.factorized(matrix_left_temp.tocsc())

# Solver for the combined operator on its own:
stiff_solve = scipy.sparse.linalg.factorized(
    (stiffs_coeff_in_time_bc + advec_bc).tocsc()
)

# Sampling distributions for the random time points of the reduced basis:
#  - prob:   Frobenius-norm sampling over the columns of rhs_matrix (equal to the
#            rank-1 leverage scores here), used for the right-hand side;
#  - prob_2: None -- presumably interpreted as uniform sampling by
#            random_basis_generation* (for the constant advection); confirm there.
prob, prob_2 = compute_frob_prob(rhs_matrix), None

# Output directory for this test problem (created if it does not exist yet):
dirname = f'results_problem={test_problem}'
os.makedirs(dirname, exist_ok=True)

# --- Experiment configuration ------------------------------------------------
# Draw random time points from two probability distributions (prob and prob_2)
# simultaneously?
two_probs = True
# Number of time points drawn from `prob` (= number of random initial conditions):
n_rand = 10
# Values swept over (one run each) for the number of points drawn from `prob_2`:
n_rand_2_list = list(range(0, 45, 5))
# Number of time steps of each local computation, and the time step from which
# snapshots start being collected:
nt_rand = 30
k = 24
# Number of time steps (including t=0) computed to represent the initial data u_0:
nt_init = nt_rand
# Tolerance that determines the number of reduced (random) basis functions:
tol = 10 ** -8
# Number of random iterations per run:
iterations = 10000
# Split each local computation into (general f, zero u_0) and (f = 0, random u_0)?
# If so, n_initials random initial conditions are used per local computation.
split = False
n_initials = 1

# --- Helpers for the main experiment loop -----------------------------------


def _columnwise_norms(gram_matrix, vectors):
    """Return ``sqrt(v_j^T G v_j)`` for every column ``v_j`` of ``vectors``.

    Args:
        gram_matrix: square matrix ``G`` with a ``.dot`` method (dense or
            scipy.sparse), e.g. a mass or stiffness matrix.
        vectors: 2-D array whose columns are the vectors to measure.

    Returns:
        1-D array with one norm per column. Tiny negative quadratic forms caused
        by round-off are clipped to zero so the square root never yields NaN.
    """
    vectors = np.asarray(vectors)
    quad_forms = np.sum(vectors * np.asarray(gram_matrix.dot(vectors)), axis=0)
    return np.sqrt(np.maximum(quad_forms, 0.0))


def _l2_in_time(times, values):
    """Return the L2(times[0], times[-1]) norm of the piecewise-linear interpolant of ``values``.

    On [t_i, t_{i+1}] the interpolant is linear between a = values[i] and
    b = values[i+1], so its square integrates exactly to
    (t_{i+1} - t_i) * (a**2 + a*b + b**2) / 3. This replaces an adaptive
    ``scipy.integrate.quad`` call, which can stop at its default limit of 50
    subintervals (the interpolant has a kink at every grid point) before reaching
    the requested 1e-12 tolerance. Its ``full_output=1`` also suppressed the
    resulting warning.
    """
    times = np.asarray(times, dtype=float)
    values = np.asarray(values, dtype=float)
    left, right = values[:-1], values[1:]
    return np.sqrt(np.sum(np.diff(times) * (left * left + left * right + right * right)) / 3.0)


def _results_path(n_iterations, n_rand_2):
    """Return the pickle path for ``n_iterations`` iterations of the current run.

    The file-name patterns are byte-identical to the historical ones so that
    existing results can still be located. NOTE: the split variant uses
    ``nrand_coeff`` while the non-split one uses ``n_rand_coeff``; this
    inconsistency is kept on purpose for compatibility with existing files.
    """
    if split:
        prefix = f'split_ninitials={n_initials}_'
        coeff_key = 'nrand_coeff'
    else:
        prefix = ''
        coeff_key = 'n_rand_coeff'
    if two_probs:
        sizes = f'nrand_rhs={n_rand}_{coeff_key}={n_rand_2}'
    else:
        sizes = f'nrand={n_rand}'
    return dirname + f'/{prefix}iterations={n_iterations}_{sizes}_ninit=ntrand={nt_rand}_k={k}_tol={tol}.pickle'


def _save_results(run_results, n_done, n_rand_2):
    """Pickle the first ``n_done`` rows of every result array of the current run."""
    snapshot = {key: values[:n_done] for key, values in run_results.items()}
    with open(_results_path(n_done, n_rand_2), 'wb') as results_file:
        pickle.dump(snapshot, results_file)


# Loop-invariant quantities (identical for every run and iteration): the
# time-independent combined operator and the norms of the high-fidelity solution.
full_operator = stiffs_coeff_in_time_bc + advec_bc
FE_solution_array = np.asarray(FE_solution)
L2_norms_FE = _columnwise_norms(mass_bc, FE_solution_array)
H1_norms_FE = _columnwise_norms(stiff_bc, FE_solution_array)
max_L2_norm_FE = np.max(L2_norms_FE)
L2_H1_norm_FE = _l2_in_time(grid_t, H1_norms_FE)

# Every run restarts the generator from the same seed, so runs differ only in n_rand_2.
random_seed = 0
# Without a second distribution n_rand_2 is never used, so every run would be
# identical (and write the same file); a single run suffices in that case.
n_rand_2_values = n_rand_2_list if two_probs else [0]

for n_rand_2 in n_rand_2_values:
    rng = np.random.default_rng(seed=random_seed)
    # Trailing arguments selecting the second distribution (only with two_probs):
    extra_args = (two_probs, prob_2, n_rand_2) if two_probs else ()
    full_n_rand = n_rand + n_rand_2 if two_probs else n_rand

    # Per-iteration results of this run (key order matches the saved files).
    run_results = {
        'rel_L2_errors_over_time': np.zeros((iterations, len(grid_t))),
        'rel_L2_errors_max_time': np.zeros(iterations),
        'rel_C0_L2_errors': np.zeros(iterations),
        'rel_L2_H1_errors': np.zeros(iterations),
        'red_bases_sizes': np.zeros(iterations),
        'chosen_time_points': np.zeros((iterations, full_n_rand)),
    }

    tic = time.time()
    for iteration in range(iterations):
        n_done = iteration + 1
        if n_done % 10 == 0:
            print(f'Performing iteration: {n_done}. Run time: {(time.time()-tic)/60} minutes.')

        # Compute reduced (random) basis:
        if split:
            rand_time_points, _, red_basis_rand = random_basis_generation_split(
                rng, prob, mass_bc, lhs_solve, stiff_solve, grid_t, rhs_matrix, u_0,
                n_rand, n_initials, k, nt_rand, nt_init, tol, *extra_args)
        else:
            rand_time_points, _, red_basis_rand = random_basis_generation(
                rng, prob, mass_bc, lhs_solve, stiff_solve, grid_t, rhs_matrix, u_0,
                n_rand, k, nt_rand, nt_init, tol, *extra_args)

        # Galerkin-project the model onto the reduced basis (time-independent coefficients):
        mass_red = red_basis_rand.T.dot(mass_bc.dot(red_basis_rand))
        stiff_red = red_basis_rand.T.dot(full_operator.dot(red_basis_rand))
        rhs_matrix_red = red_basis_rand.T.dot(rhs_matrix)
        # Mass-orthogonal projection of u_0 onto the reduced space:
        u0_red = np.linalg.solve(mass_red, red_basis_rand.T.dot(mass_bc.dot(u_0)))
        # Reduced solution on the full time interval with implicit Euler,
        # lifted back to the FE space for comparison with the FE solution:
        red_solution = euler_method_dense(u0_red, T_start, T_finish, nt, mass_red, stiff_red, rhs_matrix_red)
        red_solution = red_basis_rand.dot(red_solution)

        # L2 and H1 errors at every time point:
        errors = FE_solution_array - np.asarray(red_solution)
        L2_errors_over_time = _columnwise_norms(mass_bc, errors)
        H1_errors_over_time = _columnwise_norms(stiff_bc, errors)

        # Relative L2 error over time; time points where the FE solution vanishes
        # (e.g. u_0 = 0) get 0 instead of a division by zero.
        L2_errors_over_time_relative = np.divide(
            L2_errors_over_time, L2_norms_FE,
            out=np.zeros_like(L2_errors_over_time), where=L2_norms_FE != 0)

        run_results['rel_L2_errors_over_time'][iteration, :] = L2_errors_over_time_relative
        run_results['rel_L2_errors_max_time'][iteration] = np.max(L2_errors_over_time_relative)
        run_results['rel_C0_L2_errors'][iteration] = np.max(L2_errors_over_time) / max_L2_norm_FE
        # L2(H1) error of the linearly-in-time interpolated H1 error, relative to the FE solution:
        run_results['rel_L2_H1_errors'][iteration] = _l2_in_time(grid_t, H1_errors_over_time) / L2_H1_norm_FE
        run_results['red_bases_sizes'][iteration] = red_basis_rand.shape[1]
        # Rows are zero-initialized, so fewer returned time points are zero-padded.
        run_results['chosen_time_points'][iteration, :len(rand_time_points)] = rand_time_points

        # Intermediate checkpoint (the final iteration is saved below anyway):
        if n_done % 500 == 0 and n_done < iterations:
            _save_results(run_results, n_done, n_rand_2)

    # Save all collected results of this run:
    _save_results(run_results, iterations, n_rand_2)
    toc = time.time()
    print(f'{iterations} iterations done in {(toc-tic)/60} minutes.')