import numpy as np
from common.euler_method import euler_method_sparse

### This function generates a random reduced basis, where basis functions are constructed from solutions
### of multiple transfer operators with random initial conditions and random (starting/ending) points in time
### chosen from a given probability distribution.


def random_basis_generation(rng, prob, mass, lhs_solves, stiffs_solve,grid_t, rhs_matrix, u0, n_rand, k, nt_rand, nt_init, tol, two_probs=False, prob_2 = None, n_rand_2=None):
    """Build a reduced basis from snapshots of randomly placed local solves.

    End time indices are drawn from ``prob`` (and optionally ``prob_2``). For
    each kept end index ``e``, a run of ``nt_rand`` steps is started at index
    ``e - nt_rand`` from a random, smoothed initial condition with the actual
    right-hand side. The states after ``k`` to ``nt_rand`` steps are stored as
    snapshots. The first ``nt_init`` states of a run starting from ``u0`` at
    ``grid_t[0]`` are appended. The basis is then taken from a truncated SVD
    of all snapshots.

    Parameters
    ----------
    rng : random generator (e.g. ``numpy.random.Generator``)
        Used for ``choice`` (end indices) and ``standard_normal`` (initial
        conditions).
    prob : array_like
        Probabilities over the indices of ``grid_t`` for the end points.
    mass : matrix
        Only ``mass.shape[1]`` (number of unknowns) is read here. It is
        otherwise passed through to ``euler_method_sparse``.
    lhs_solves : sequence
        Indexed by time step. ``lhs_solves[s:s + nt_rand]`` is passed to the
        run starting at index ``s``; the full sequence is passed to the
        ``u0`` run.
    stiffs_solve : sequence of callables
        Indexed by time step. ``stiffs_solve[s]`` is applied to the random
        initial vector of the run starting at index ``s`` to smooth it.
    grid_t : numpy.ndarray
        Time grid. The step used for the ``u0`` run is
        ``grid_t[1] - grid_t[0]``, so the grid is presumably uniform.
    rhs_matrix : 2-D array
        Right-hand side with columns indexed by time step. Columns
        ``s .. s + nt_rand`` are passed to the run starting at index ``s``.
    u0 : numpy.ndarray
        Actual initial condition.
    n_rand : int
        Number of end indices drawn from ``prob``.
    k : int
        First stored step of each random run; steps ``k .. nt_rand`` are
        kept.
    nt_rand : int
        Number of time steps of each random run.
    nt_init : int
        Number of leading states of the ``u0`` run added as snapshots. If it
        is 0, no such snapshots are added.
    tol : float
        Truncation tolerance. The basis size is the smallest ``r`` with
        ``sqrt(sum(s[r:]**2)) / sqrt(sum(s**2)) <= tol``, i.e. the relative
        Frobenius error of the rank-``r`` truncation. If no ``r`` satisfies
        this (e.g. ``tol < 0``), all computed left singular vectors are kept.
    two_probs : bool, optional
        If true, additionally draw ``n_rand_2`` end indices from ``prob_2``.
    prob_2 : array_like, optional
        Second probability vector over the indices of ``grid_t``.
    n_rand_2 : int, optional
        Number of end indices drawn from ``prob_2``.

    Returns
    -------
    rand_ends : numpy.ndarray
        ``grid_t`` values of the end indices actually used. Draws with
        index < ``nt_rand`` are discarded.
    singular_vals_rand : numpy.ndarray
        All singular values of the snapshot matrix, in descending order.
    red_basis_rand : numpy.ndarray
        The leading left singular vectors, determined up to sign.

    Raises
    ------
    ValueError
        If the snapshot matrix is identically zero, so no meaningful basis
        exists.
    """
    # Draw end time indices from the given probability distribution(s):
    rand_ints = rng.choice(len(grid_t), n_rand, p=prob)
    if two_probs:
        rand_ints_2 = rng.choice(len(grid_t), n_rand_2, p=prob_2)
        rand_ints = np.hstack((rand_ints, rand_ints_2))

    # Be sure that there is enough space to go backwards in time. Earlier end
    # points are meant to be included in the computation for u0 below (this
    # assumes nt_init is chosen at least as large as nt_rand).
    rand_ints = rand_ints[rand_ints >= nt_rand]
    start_ints = rand_ints - nt_rand
    rand_ends = grid_t[rand_ints]
    rand_starts = grid_t[start_ints]
    n_rand_new = len(rand_ends)
    n_dofs = mass.shape[1]

    # Random initial conditions, smoothed by the solver of each run's start index:
    u0_rand = rng.standard_normal(size=(n_dofs, n_rand_new))
    for i, s in enumerate(start_ints):
        u0_rand[:, i] = stiffs_solve[s](u0_rand[:, i])

    # Solve on each random window; keep the states after k .. nt_rand steps:
    aux = nt_rand - k + 1  # number of snapshots collected in each run
    n_cols_rand = n_rand_new * aux
    solutions_rand = np.zeros((n_dofs, n_cols_rand + nt_init))
    for i, (s, e) in enumerate(zip(start_ints, rand_ints)):
        solutions_rand[:, i * aux:(i + 1) * aux] = euler_method_sparse(
            u0_rand[:, i], rand_starts[i], rand_ends[i], nt_rand, mass,
            lhs_solves[s:e], rhs_matrix[:, s:e + 1])[:, k:]

    # Add representation for the actual initial condition. An explicit offset
    # is used because a "-nt_init:" slice would select everything when nt_init == 0.
    if nt_init > 0:
        ht = grid_t[1] - grid_t[0]
        solutions_rand[:, n_cols_rand:] = euler_method_sparse(
            u0, grid_t[0], grid_t[0] + nt_init * ht, nt_init, mass,
            lhs_solves, rhs_matrix)[:, :nt_init]

    # Thin SVD: only leading left singular vectors are needed, so avoid
    # forming a dense n_dofs x n_dofs U.
    U_rand, singular_vals_rand, _ = np.linalg.svd(solutions_rand, full_matrices=False)

    # Truncate: discarded[i] is the energy dropped when keeping i+1 modes.
    energy = singular_vals_rand**2
    total = np.sum(energy)
    if not total > 0:
        raise ValueError("snapshot matrix is zero; cannot build a reduced basis")
    discarded = np.append(np.cumsum(energy[::-1])[::-1][1:], 0.0)
    within_tol = np.sqrt(discarded) / np.sqrt(total) <= tol
    red_dim = int(np.argmax(within_tol)) + 1 if within_tol.any() else len(singular_vals_rand)
    red_basis_rand = U_rand[:, :red_dim]

    return rand_ends, singular_vals_rand, red_basis_rand



### Variant of random_basis_generation that splits each local solve into
### (rhs f + zero local initial condition) + (rhs 0 + random local initial conditions):

def random_basis_generation_split(rng, prob, mass, lhs_solves, stiffs_solve,grid_t, rhs_matrix, u0, n_rand, n_initials, k, nt_rand, nt_init, tol, two_probs=False, prob_2 = None, n_rand_2=None):
    """Build a reduced basis from split local solves on random time windows.

    End time indices are drawn from ``prob`` (and optionally ``prob_2``). For
    each kept end index ``e``, the window starting at index ``e - nt_rand``
    is solved ``n_initials + 1`` times:

    * ``n_initials`` runs use a zero right-hand side and random, smoothed
      initial conditions;
    * one run uses the actual right-hand side and a zero initial condition.

    The states after ``k`` to ``nt_rand`` steps of every run are stored as
    snapshots. The first ``nt_init`` states of a run starting from ``u0`` at
    ``grid_t[0]`` are appended. The basis is then taken from a truncated SVD
    of all snapshots.

    Parameters
    ----------
    rng : random generator (e.g. ``numpy.random.Generator``)
        Used for ``choice`` (end indices) and ``standard_normal`` (initial
        conditions).
    prob : array_like
        Probabilities over the indices of ``grid_t`` for the end points.
    mass : matrix
        Only ``mass.shape[1]`` (number of unknowns) is read here. It is
        otherwise passed through to ``euler_method_sparse``.
    lhs_solves : sequence
        Indexed by time step. ``lhs_solves[s:s + nt_rand]`` is passed to the
        runs starting at index ``s``; the full sequence is passed to the
        ``u0`` run.
    stiffs_solve : sequence of callables
        Indexed by time step. ``stiffs_solve[s]`` is applied to every random
        initial vector of the window starting at index ``s`` to smooth it.
    grid_t : numpy.ndarray
        Time grid. The step used for the ``u0`` run is
        ``grid_t[1] - grid_t[0]``, so the grid is presumably uniform.
    rhs_matrix : 2-D array
        Right-hand side with columns indexed by time step. Columns
        ``s .. s + nt_rand`` are passed to the forced run starting at ``s``.
    u0 : numpy.ndarray
        Actual initial condition.
    n_rand : int
        Number of end indices drawn from ``prob``.
    n_initials : int
        Number of random initial conditions (zero-rhs runs) per window.
    k : int
        First stored step of each run; steps ``k .. nt_rand`` are kept.
    nt_rand : int
        Number of time steps of each local run.
    nt_init : int
        Number of leading states of the ``u0`` run added as snapshots. If it
        is 0, no such snapshots are added.
    tol : float
        Truncation tolerance. The basis size is the smallest ``r`` with
        ``sqrt(sum(s[r:]**2)) / sqrt(sum(s**2)) <= tol``, i.e. the relative
        Frobenius error of the rank-``r`` truncation. If no ``r`` satisfies
        this (e.g. ``tol < 0``), all computed left singular vectors are kept.
    two_probs : bool, optional
        If true, additionally draw ``n_rand_2`` end indices from ``prob_2``.
    prob_2 : array_like, optional
        Second probability vector over the indices of ``grid_t``.
    n_rand_2 : int, optional
        Number of end indices drawn from ``prob_2``.

    Returns
    -------
    rand_ends : numpy.ndarray
        ``grid_t`` values of the end indices actually used. Draws with
        index < ``nt_rand`` are discarded.
    singular_vals_rand : numpy.ndarray
        All singular values of the snapshot matrix, in descending order.
    red_basis_rand : numpy.ndarray
        The leading left singular vectors, determined up to sign.

    Raises
    ------
    ValueError
        If the snapshot matrix is identically zero, so no meaningful basis
        exists.
    """
    # Draw end time indices from the given probability distribution(s):
    rand_ints = rng.choice(len(grid_t), n_rand, p=prob)
    if two_probs:
        rand_ints_2 = rng.choice(len(grid_t), n_rand_2, p=prob_2)
        rand_ints = np.hstack((rand_ints, rand_ints_2))

    # Be sure that there is enough space to go backwards in time. Earlier end
    # points are meant to be included in the computation for u0 below (this
    # assumes nt_init is chosen at least as large as nt_rand).
    rand_ints = rand_ints[rand_ints >= nt_rand]
    start_ints = rand_ints - nt_rand
    rand_ends = grid_t[rand_ints]
    rand_starts = grid_t[start_ints]
    n_rand_new = len(rand_ends)
    n_dofs = mass.shape[1]

    # n_initials random initial conditions per window. Column i + j*n_rand_new
    # belongs to window i, so every column (not only the first n_rand_new) is
    # smoothed with that window's solver.
    u0_rand = rng.standard_normal(size=(n_dofs, n_rand_new * n_initials))
    for j in range(n_initials):
        for i, s in enumerate(start_ints):
            col = i + j * n_rand_new
            u0_rand[:, col] = stiffs_solve[s](u0_rand[:, col])

    # Solve on each random window, split into (f = 0, random u0) runs and one
    # (general f, zero u0) run; keep the states after k .. nt_rand steps.
    aux = nt_rand - k + 1  # number of snapshots collected in each run
    n_cols_rand = n_rand_new * (n_initials + 1) * aux
    solutions_rand = np.zeros((n_dofs, n_cols_rand + nt_init))
    rhs_zero = np.zeros((rhs_matrix.shape[0], nt_rand + 1))
    initial_zeros = np.zeros(n_dofs)
    for i, (s, e) in enumerate(zip(start_ints, rand_ints)):
        local_lhs = lhs_solves[s:e]
        for j in range(n_initials):
            col = i + j * n_rand_new
            solutions_rand[:, col * aux:(col + 1) * aux] = euler_method_sparse(
                u0_rand[:, col], rand_starts[i], rand_ends[i], nt_rand, mass,
                local_lhs, rhs_zero)[:, k:]
        col = i + n_rand_new * n_initials
        solutions_rand[:, col * aux:(col + 1) * aux] = euler_method_sparse(
            initial_zeros, rand_starts[i], rand_ends[i], nt_rand, mass,
            local_lhs, rhs_matrix[:, s:e + 1])[:, k:]

    # Add representation for the actual initial condition. An explicit offset
    # is used because a "-nt_init:" slice would select everything when nt_init == 0.
    if nt_init > 0:
        ht = grid_t[1] - grid_t[0]
        solutions_rand[:, n_cols_rand:] = euler_method_sparse(
            u0, grid_t[0], grid_t[0] + nt_init * ht, nt_init, mass,
            lhs_solves, rhs_matrix)[:, :nt_init]

    # Thin SVD: only leading left singular vectors are needed, so avoid
    # forming a dense n_dofs x n_dofs U.
    U_rand, singular_vals_rand, _ = np.linalg.svd(solutions_rand, full_matrices=False)

    # Truncate: discarded[i] is the energy dropped when keeping i+1 modes.
    energy = singular_vals_rand**2
    total = np.sum(energy)
    if not total > 0:
        raise ValueError("snapshot matrix is zero; cannot build a reduced basis")
    discarded = np.append(np.cumsum(energy[::-1])[::-1][1:], 0.0)
    within_tol = np.sqrt(discarded) / np.sqrt(total) <= tol
    red_dim = int(np.argmax(within_tol)) + 1 if within_tol.any() else len(singular_vals_rand)
    red_basis_rand = U_rand[:, :red_dim]

    return rand_ends, singular_vals_rand, red_basis_rand