import scipy.sparse.linalg

from common.assemble import FE_assemble_incl_advec
from common.euler_method import euler_method_sparse
import numpy as np
import os
import pickle
import dill # supports pickling of functions

### We consider the system of ordinary differential equations: find u(t) in R^n such that
###     M u_t(t) + S(t) u(t) + A u(t) = f(t),   u(0) = u_0,
### where M is the FE mass matrix, S(t) the (time-dependent) stiffness matrix and A the advection matrix.
### This script computes and stores the high-fidelity FE solution together with the FE matrices.



# Define test problem data:

# Right-hand side (zero in the considered domain from t=0 to t=0.15, since rhs_t_1 vanishes for t < 0.5):
def rhs_t_1(t):
    """Temporal factor of the source term.

    A downward-opening parabola -50 (t - 0.5)(t - 2) restricted to the window
    [0.5, 2] (peak value 28.125 at t = 1.25) and zero outside it.
    Works on scalars and NumPy arrays alike.
    """
    window = (t >= 0.5) * (t <= 2)
    return -50 * (t - 0.5) * (t - 2) * window


def rhs_x_1(x):
    """Spatial factor in x: indicator of the interval [0.1, 0.2]."""
    return (0.1 <= x) * (x <= 0.2)


def rhs_y_1(y):
    """Spatial factor in y: indicator of the interval [0.1, 0.2]."""
    return (0.1 <= y) * (y <= 0.2)


# One factor per direction; presumably combined as a separable product by the
# assembly routine (TODO confirm against FE_assemble_incl_advec).
rhs_t, rhs_x, rhs_y = [rhs_t_1], [rhs_x_1], [rhs_y_1]

# Diffusion Coefficient:
def alpha_t_1(t):
    """Temporal factor of the diffusion coefficient: ones shaped like *t*."""
    return np.full_like(t, 1)


def alpha_x_1(x):
    """Spatial factor in x of the diffusion coefficient: ones shaped like *x*."""
    return np.full_like(x, 1)


def alpha_y_1(y):
    """Spatial factor in y of the diffusion coefficient: ones shaped like *y*."""
    return np.full_like(y, 1)


# Every factor is constant 1, i.e. plain (unit) diffusion.
alpha_t, alpha_x, alpha_y = [alpha_t_1], [alpha_x_1], [alpha_y_1]

# Advection coefficient: the x-component (advec_x) is varied in the
# parameter loop below; the y-component is fixed to zero.
advec_y = 0

# Rectangular 2D domain [0, 1] x [0, 0.3], discretized with nx elements in x
# and ny elements in y.
domain = [[0, 1], [0, 0.3]]
nx, ny = 300, 90

# Time interval [T_start, T_finish] divided into nt time steps.
T_start, T_finish = 0, 0.1
nt = 10

# Loop over the two initial conditions and two advection speeds; for each
# combination, assemble the FE system, solve it in time and store everything
# under results/.
os.makedirs('results', exist_ok=True)  # loop-invariant: create output dir once

for example_no in [1, 2]:
    # Initial condition: a nonnegative biquadratic bump supported on a
    # 0.1 x 0.1 square (peak value ~104 at the square's centre), placed at a
    # different location for each example.
    if example_no == 1:
        def u0(x, y):
            return (100000000/6*(x-0.05)*(x-0.15)*(y-0.15)*(y-0.25)
                    *(x>=0.05)*(x<=0.15)*(y>=0.15)*(y<=0.25))
    else:
        def u0(x, y):
            return (100000000/6*(x-0.4)*(x-0.5)*(y-0.05)*(y-0.15)
                    *(x>=0.4)*(x<=0.5)*(y>=0.05)*(y<=0.15))

    for advec_x in [0, 100]:
        test_problem = f'initial_example_{example_no}_advec_x={advec_x}'

        # Assemble FE matrices etc.:
        (grid_x, grid_y, grid_t, mass, mass_bc, stiff, stiff_bc, advec,
         advec_bc, stiffs_coeff_in_time, stiffs_coeff_in_time_bc, rhs_matrix,
         alpha_matrix) = FE_assemble_incl_advec(
            domain, nx, ny, T_start, T_finish, nt, alpha_t, alpha_x, alpha_y,
            advec_x, advec_y, rhs_t, rhs_x, rhs_y)
        ht = grid_t[1] - grid_t[0]  # time-step size (uniform grid assumed)

        # Discrete initial condition: evaluate u0 on the interior nodes only
        # (boundary nodes grid_*[0] and grid_*[-1] are dropped), flattened.
        X, Y = np.meshgrid(grid_x[1:-1], grid_y[1:-1])
        u_0 = u0(X, Y).flatten()

        # Pre-factorize the left-hand-side matrices M + ht*S(t_k) + ht*A for
        # steps k = 1..nt. euler_method_sparse presumably applies them as
        # implicit (backward) Euler solves -- TODO confirm.
        # The addition order (mass + stiffness, then advection) is kept as-is.
        lhs_solves = [
            scipy.sparse.linalg.factorized(
                (mass_bc + ht * stiffs_coeff_in_time_bc[k] + ht * advec_bc).tocsc())
            for k in range(1, nt + 1)
        ]

        # Compute high-fidelity FE solution:
        solution = euler_method_sparse(u_0, T_start, T_finish, nt, mass_bc,
                                       lhs_solves, rhs_matrix)

        # Store all results; `with` guarantees the files are flushed and
        # closed even if serialization fails.
        results = {
            'FE_solution': solution,
            'grid_t': grid_t,
            'grid_x': grid_x,
            'grid_y': grid_y,
            'mass': mass,
            'mass_bc': mass_bc,
            'stiff': stiff,
            'stiff_bc': stiff_bc,
            'advec': advec,
            'advec_bc': advec_bc,
            'stiffs_coeff_in_time': stiffs_coeff_in_time,
            'stiffs_coeff_in_time_bc': stiffs_coeff_in_time_bc,
            'rhs_matrix': rhs_matrix,
            'alpha_matrix': alpha_matrix,
            'u0_discrete': u_0,
        }
        with open('results/FE_data_problem='+test_problem+'.pickle', 'wb') as fh:
            pickle.dump(results, fh)
        # rhs_t holds plain functions, which dill (unlike pickle) can serialize.
        with open('results/rhs_t_problem='+test_problem+'.dill', 'wb') as fh:
            dill.dump(rhs_t, fh)