import pickle
import time

import matplotlib.pyplot as plt
import numpy as np
import scipy.linalg
import scipy.sparse.linalg

from common.assemble import gram_schmidt_ortho
from common.euler_method import euler_method_sparse


test_problem = 'stove'

# Load test problem, necessary FE data and high-fidelity solution.
# The file is opened in a `with` block so the handle is closed right after
# unpickling instead of being left to the garbage collector.
# NOTE(review): pickle.load can execute arbitrary code stored in the file;
# only load result files produced by a trusted local run.
with open(f'results/FE_data_problem={test_problem}.pickle', 'rb') as fe_file:
    FE_results = pickle.load(fe_file)

FE_solution = FE_results['FE_solution']
# presumably the mass/stiffness matrices with boundary conditions applied
# (sparse; they are later combined and converted with .tocsc()) - TODO confirm
mass_bc = FE_results['mass_bc']
stiff_bc = FE_results['stiff_bc']
# Indexed by time step below (one stiffness matrix per point of grid_t).
stiffs_coeff_in_time_bc = FE_results['stiffs_coeff_in_time_bc']
# Sliced by time index along its second axis further down.
rhs_matrix = FE_results['rhs_matrix']

# Time grid. The step size is read off the first interval only, which
# assumes grid_t is uniformly spaced - TODO confirm.
grid_t = FE_results['grid_t']
ht = grid_t[1]-grid_t[0]
T_start = grid_t[0]
T_finish = grid_t[-1]
nt = len(grid_t) - 1  # number of time steps = number of grid points minus one
u_0 = FE_results['u0_discrete']

# Precomputed array whose projection error onto the random bases is measured
# below. NOTE(review): the file name suggests it was generated with nt=15;
# verify it matches the current test problem.
harmonic_funcs = np.load('results/harmonic_funcs_nt=15.npy')

# Precomputations: factorize every sparse system once so the returned solver
# callables can be reused for repeated right-hand sides.

# stiffs_solve[k] solves with stiffs_coeff_in_time_bc[k], k = 0..nt
# (the t=0 matrix is included here).
stiffs_solve = [
    scipy.sparse.linalg.factorized(stiffs_coeff_in_time_bc[k].tocsc())
    for k in range(nt + 1)
]

# lhs_solves[k-1] solves with mass_bc + ht * stiffs_coeff_in_time_bc[k],
# k = 1..nt (no entry for t=0, so indices are shifted by one).
lhs_solves = [
    scipy.sparse.linalg.factorized(
        (mass_bc + ht * stiffs_coeff_in_time_bc[k]).tocsc()
    )
    for k in range(1, nt + 1)
]

###############

# Time-index window for the random solves: random initial data is placed at
# grid_t[start_int] and advanced to grid_t[end_int] in nt_rand steps.
start_int, end_int = 10, 25
nt_rand = end_int - start_int
# Number of random initial conditions, i.e. the largest basis size tested.
max_basis_size = 15

# Seeded random number generator, so the drawn samples are reproducible:
random_seed = 0
rng = np.random.default_rng(random_seed)

# Quantities that do not depend on the iteration or on the basis size are
# formed once here instead of inside the innermost loop (2000 x 15 times).
mass_dense = mass_bc.toarray()                 # dense right-hand matrix for eigh
mass_harmonic = mass_bc.dot(harmonic_funcs)    # mass_bc applied to harmonic_funcs
top_index = mass_bc.shape[0] - 1               # index of the largest eigenvalue

for iteration in range(2000):
    print(f'### Iteration {iteration+1}')
    # Draw max_basis_size random initial conditions and smooth each one by
    # solving with the stiffness matrix at time index start_int:
    u0_rand = rng.standard_normal(size=(mass_bc.shape[1], max_basis_size))
    for i in range(max_basis_size):
        u0_rand[:, i] = stiffs_solve[start_int](u0_rand[:, i])

    # Advance every random initial condition from grid_t[start_int] to
    # grid_t[end_int] and keep only the last column of the returned solution:
    random_basis = np.zeros((mass_bc.shape[1], max_basis_size))
    for i in range(max_basis_size):
        random_basis[:, i] = euler_method_sparse(
            u0_rand[:, i],
            grid_t[start_int],
            grid_t[end_int],
            nt_rand,
            mass_bc,
            lhs_solves[start_int:end_int],
            rhs_matrix[:, start_int:end_int + 1],
        )[:, -1]

    # The return value is discarded, so this relies on gram_schmidt_ortho
    # orthonormalizing random_basis in place w.r.t. mass_bc - TODO confirm
    # against common.assemble.
    gram_schmidt_ortho(random_basis, mass_bc)

    proj_errors = np.zeros(max_basis_size+1)
    # Hard-coded error for the empty basis (index 0).
    # NOTE(review): presumably taken from a separate run; its origin is not
    # visible in this script - confirm it matches the current data.
    proj_errors[0] = 5.08303496e-04

    for i in range(max_basis_size):
        print(f'Basis size {i+1}')
        basis_temp = random_basis[:, :(i+1)]
        # Remainder of harmonic_funcs after subtracting its mass_bc-weighted
        # component in the span of the first i+1 basis vectors.
        matrix_temp = harmonic_funcs - basis_temp.dot(basis_temp.T.dot(mass_harmonic))
        matrix_left = matrix_temp.T.dot(mass_bc.dot(matrix_temp))
        tic = time.time()
        # Largest eigenvalue of the generalized problem
        # matrix_left x = lambda * mass_bc x (eigh sorts ascending, so the
        # last index is the maximum). The result is a length-1 array.
        max_eigval = scipy.linalg.eigh(
            matrix_left,
            mass_dense,
            eigvals_only=True,
            subset_by_index=[top_index, top_index],
        )
        toc = time.time()
        max_sval = np.sqrt(max_eigval)
        print(f'Max sval: {max_sval} computed in {(toc-tic)/60} minutes.')
        # Store the scalar explicitly: assigning a length-1 array to a single
        # element is deprecated in recent NumPy versions.
        proj_errors[i+1] = max_sval[0]

    np.save(f'results/proj_errors_{iteration}.npy', proj_errors)

