import scipy.sparse.linalg

from common.assemble import FE_assemble
from common.euler_method import euler_method_sparse
import numpy as np
import os
import pickle
import dill # supports pickling of functions

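### We consider the system of ordinary differential equations: find u(t) in R^n such that
### M u_t(t) + A(t) u(t) = f(t) and u(0) = u_0,
### where M is the FE mass matrix and A(t) the (time-dependent) FE stiffness matrix.
### In this file we compute and store the high-fidelity FE solution and the FE matrices.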


# Problem tag; it becomes part of the output file names written to results/.
test_problem = "stove"

# Define test problem data:

# Right hand side:
# The source is given as factors per term; presumably FE_assemble combines them
# as sum_i rhs_t[i](t) * rhs_x[i](x) * rhs_y[i](y) -- TODO confirm.
# NOTE(review): rhs_t is serialized with dill at the end of this script; check
# how that .dill file is loaded before restructuring these functions.
#
# Term 1: 10*(-(t-0.5)*(t-4)) on [0.5, 4] (a parabola vanishing at both ends),
# zero outside; spatial factors are indicators of [0.2, 0.3] in x and y.
def rhs_t_1(t): return 10*(-t**2+4.5*t-2)*(t>=0.5)*(t<=4)
def rhs_x_1(x): return (x>=0.2)*(x<=0.3)
def rhs_y_1(y): return (y>=0.2)*(y<=0.3)
# Term 2: 5*(-(t-3)*(t-7)) on [3, 7], zero outside; spatial support [0.45, 0.55] in x and y.
def rhs_t_2(t): return 5*(-t**2+10*t-21)*(t>=3)*(t<=7)
def rhs_x_2(x): return (x>=0.45)*(x<=0.55)
def rhs_y_2(y): return (y>=0.45)*(y<=0.55)
# Term 3: 10*(-(t-6)*(t-9)) on [6, 9], zero outside; spatial support [0.65, 0.8] in x and y.
def rhs_t_3(t): return 10*(-t**2+15*t-54)*(t>=6)*(t<=9)
def rhs_x_3(x): return (x>=0.65)*(x<=0.8)
def rhs_y_3(y): return (y>=0.65)*(y<=0.8)
# Factor lists: entry i holds the factors of term i+1.
rhs_t = [rhs_t_1,rhs_t_2,rhs_t_3]
rhs_x = [rhs_x_1,rhs_x_2,rhs_x_3]
rhs_y = [rhs_y_1,rhs_y_2,rhs_y_3]

# Coefficient:
def alpha_t_1(t):
    """Coefficient factor in t: ones with the shape and dtype of ``t``."""
    return np.ones_like(t)


def alpha_x_1(x):
    """Coefficient factor in x: ones with the shape and dtype of ``x``."""
    return np.ones_like(x)


def alpha_y_1(y):
    """Coefficient factor in y: ones with the shape and dtype of ``y``."""
    return np.ones_like(y)


# One factor per direction; a single term, so the coefficient is constant.
alpha_t, alpha_x, alpha_y = [alpha_t_1], [alpha_x_1], [alpha_y_1]

# Initial condition:
def u0(x, y):
    """Initial condition: sum of the modes sin(k*pi*x)*sin(k*pi*y) for k = 2, 3, 4.

    Evaluated elementwise, so ``x`` and ``y`` may be scalars or broadcastable
    NumPy arrays.
    """
    def mode(k):
        return np.sin(k * np.pi * x) * np.sin(k * np.pi * y)

    # Summed left to right, in the same order as the explicit three-term formula.
    return mode(2) + mode(3) + mode(4)

# 2D domain with nx/ny elements in x/y direction:
domain = [[0, 1], [0, 1]]  # presumably [[x_min, x_max], [y_min, y_max]]
nx, ny = 100, 100

# Time interval and number of time steps:
T_start, T_finish = 0, 10
nt = 300


# Assemble the FE grids, matrices and coefficient/source data:
(
    grid_x, grid_y, grid_t,
    mass, mass_bc,
    stiff, stiff_bc,
    stiffs_coeff_in_time, stiffs_coeff_in_time_bc,
    rhs_matrix, alpha_matrix,
) = FE_assemble(
    domain, nx, ny,
    T_start, T_finish, nt,
    alpha_t, alpha_x, alpha_y,
    rhs_t, rhs_x, rhs_y,
)
# Time step, read off the first two time nodes (assumes a uniform time grid).
ht = grid_t[1] - grid_t[0]

# Discrete initial condition: u0 evaluated on the grid without its first and last
# node in each direction, flattened in NumPy row-major order.
# NOTE(review): presumably these are the interior nodes matching the *_bc matrices;
# confirm the node ordering agrees with FE_assemble.
X, Y = np.meshgrid(grid_x[1:-1], grid_y[1:-1])
u_0 = u0(X, Y).flatten()

# Compute the high-fidelity FE solution.
# For every time level k = 1, ..., nt, pre-factorize the left-hand-side matrix
# mass_bc + ht * A(t_k), where A(t_k) = stiffs_coeff_in_time_bc[k]; level 0 is
# the initial time and needs no solve. factorized() is most efficient on CSC input.
# NOTE(review): presumably euler_method_sparse applies lhs_solves[k-1] in step k -- confirm.
lhs_solves = [
    scipy.sparse.linalg.factorized((mass_bc + ht * stiffs_coeff_in_time_bc[k]).tocsc())
    for k in range(1, nt + 1)
]

solution = euler_method_sparse(u_0, T_start, T_finish, nt, mass_bc, lhs_solves, rhs_matrix)

# Store all results under results/ (created if it does not exist yet).
os.makedirs('results', exist_ok=True)
results = {
    'FE_solution': solution,
    'grid_t': grid_t,
    'grid_x': grid_x,
    'grid_y': grid_y,
    'mass': mass,
    'mass_bc': mass_bc,
    'stiff': stiff,
    'stiff_bc': stiff_bc,
    'stiffs_coeff_in_time': stiffs_coeff_in_time,
    'stiffs_coeff_in_time_bc': stiffs_coeff_in_time_bc,
    'rhs_matrix': rhs_matrix,
    'alpha_matrix': alpha_matrix,
    'u0_discrete': u_0,
}
# Context managers flush and close each file deterministically instead of
# leaving that to garbage collection.
with open(f'results/FE_data_problem={test_problem}.pickle', 'wb') as fh:
    pickle.dump(results, fh)
# rhs_t is a list of Python functions; dill is used for it because it
# supports pickling functions.
with open(f'results/rhs_t_problem={test_problem}.dill', 'wb') as fh:
    dill.dump(rhs_t, fh)