import numpy as np
import scipy.sparse
from common.assemble import mass_2d, stiff_2d, stiff_coeff_2d

# Local 4x4 element stiffness matrix of one square cell. The local dofs are
# ordered (lower-left, lower-right, upper-left, upper-right), which is the
# order of the global dof tuple built in stiff_coeffmatrix_2D. The entries
# coincide with the usual bilinear (Q1) Laplace element matrix on a square,
# which does not depend on the cell size.
local_stiff = np.array([[ 2./3., -1./6., -1./6., -1./3.],
                        [-1./6.,  2./3., -1./3., -1./6.],
                        [-1./6., -1./3.,  2./3., -1./6.],
                        [-1./3., -1./6., -1./6.,  2./3.]])


def stiff_coeffmatrix_2D(nx, ny, coeff_matrix):
    """Assemble the 2D stiffness matrix for a cell-wise constant coefficient.

    Every cell (yelem, xelem) contributes ``coeff_matrix[yelem, xelem] *
    local_stiff`` to the global matrix. Global dofs are numbered row by row,
    i.e. the node with x index ``ix`` and y index ``iy`` has the number
    ``ix + iy * (nx + 1)``.

    Important assumption for a correct assembly: hx == hy (equal step size in
    x and y direction), because ``local_stiff`` is the element matrix of a
    square cell.

    Parameters
    ----------
    nx, ny : int
        Number of elements in x / y direction; the number of dofs is
        ``(nx + 1) * (ny + 1)``.
    coeff_matrix : array-like, at least ``ny`` x ``nx``
        Discrete coefficient; entry ``[i, j]`` is the value in the i-th cell
        in y direction and the j-th cell in x direction. Only the leading
        ``ny`` x ``nx`` part is read.

    Returns
    -------
    scipy.sparse.csr_matrix
        Stiffness matrix of shape ``(ndofs, ndofs)`` without stored zeros.

    Raises
    ------
    IndexError
        If ``coeff_matrix`` is not two-dimensional or smaller than
        ``ny`` x ``nx``.
    """
    ndofs = (nx + 1) * (ny + 1)

    coeff = np.asarray(coeff_matrix, dtype=float)
    if coeff.ndim != 2 or coeff.shape[0] < ny or coeff.shape[1] < nx:
        raise IndexError(
            "coeff_matrix must be a 2D array of shape at least (ny, nx) = "
            "(%d, %d), got shape %s" % (ny, nx, coeff.shape)
        )
    # Flatten row-major so that element number e = yelem * nx + xelem.
    coeff = coeff[:ny, :nx].ravel()

    # x / y index of every element, in the same row-major element order.
    xelem = np.tile(np.arange(nx), ny)
    yelem = np.repeat(np.arange(ny), nx)

    # Global numbers of the four corner dofs of every element, shape (nelem, 4),
    # ordered like the rows/columns of local_stiff.
    lower_left = xelem + yelem * (nx + 1)
    corner_offsets = np.array([0, 1, nx + 1, nx + 2])
    globaldofs = lower_left[:, None] + corner_offsets[None, :]

    # One (row, col, value) triplet per element and local pair (ldof1, ldof2);
    # the position ldof1 * 4 + ldof2 in each row matches local_stiff.ravel().
    rows = np.repeat(globaldofs, 4, axis=1).ravel()
    cols = np.tile(globaldofs, (1, 4)).ravel()
    vals = (coeff[:, None] * local_stiff.ravel()[None, :]).ravel()

    # Converting COO to CSR sums the contributions of neighbouring elements
    # that share a dof pair, which replaces the entry-by-entry accumulation.
    stiff = scipy.sparse.coo_matrix((vals, (rows, cols)), shape=(ndofs, ndofs)).tocsr()
    # Entry-wise insertion into a DOK matrix never stores zeros; drop them here
    # as well so that the sparsity pattern does not depend on the assembly path.
    stiff.eliminate_zeros()
    return stiff


##########

def FE_assemble(domain, nx, ny, T_start, T_finish,nt, alpha_t, alpha_x, alpha_y, SPE10_coeff_matrix):
    """Return a bundle of grids, (2D) FE matrices and the cell-wise coefficient.

    The diffusion coefficient is the background ``SPE10_coeff_matrix`` plus a
    sum of separable terms ``alpha_t[j](t) * alpha_x[j](x) * alpha_y[j](y)``.
    All returned matrices have the first ``nx + 1`` dofs (bottom boundary,
    homogeneous Dirichlet condition) removed.

    Parameters
    ----------
    domain : sequence of two (min, max) pairs
        ``domain[0]`` is the x interval, ``domain[1]`` the y interval.
    nx, ny : int
        Number of elements in x / y direction.
    T_start, T_finish : float
        End points of the time interval.
    nt : int
        Number of time steps; the time grid has ``nt + 1`` points.
    alpha_t, alpha_x, alpha_y : sequences of callables
        Time, x and y factors of the separable coefficient terms. ``alpha_t``
        needs at least one entry and determines how many terms are summed.
        ``alpha_x[j]`` / ``alpha_y[j]`` are called with 1D arrays of cell
        midpoints and are presumably expected to return one value per cell
        (verify against the callers).
    SPE10_coeff_matrix : numpy.ndarray
        Background coefficient with ``nx * ny`` entries, one per cell, indexed
        ``[cell in y, cell in x]``.

    Returns
    -------
    tuple
        ``(grid_x, grid_y, grid_t, mass, stiff, stiffs_coeff_in_time,
        alpha_matrix)``: the three 1D grids, the restricted mass and stiffness
        matrices from ``mass_2d`` / ``stiff_2d``, a list with one restricted
        coefficient stiffness matrix per time grid point, and an array of
        shape ``(nx * ny, nt + 1)`` whose column ``i`` is the total coefficient
        on every cell (row-major over y, x) at time ``grid_t[i]``.
    """

    def _drop_bottom_dofs(matrix):
        # homogeneous Dirichlet boundary conditions at bottom (first nx+1 DOFS)
        return matrix[nx+1:, :][:, nx+1:]

    grid_x = np.linspace(domain[0][0], domain[0][1], nx+1)
    grid_y = np.linspace(domain[1][0], domain[1][1], ny+1)
    grid_t = np.linspace(T_start, T_finish, nt+1)

    mass = _drop_bottom_dofs(mass_2d(domain, nx, ny))
    stiff = _drop_bottom_dofs(stiff_2d(domain, nx, ny))

    # background rough coefficient that is added to channels below in each time step:
    stiff_matrix_coeff = stiff_coeffmatrix_2D(nx, ny, SPE10_coeff_matrix)

    # One spatial matrix per separable term; these do not depend on time.
    stiffs_coeff_space = [stiff_coeff_2d(domain, nx, ny, alpha_x_i, alpha_y[i])
                          for i, alpha_x_i in enumerate(alpha_x)]

    n_terms = len(alpha_t)
    stiffs_coeff_in_time = []
    for t in grid_t:
        # Start from the first term so that no stored spatial matrix is modified.
        stiff_temp = alpha_t[0](t)*stiffs_coeff_space[0]
        for i in range(1, n_terms):
            stiff_temp += alpha_t[i](t)*stiffs_coeff_space[i]
        stiffs_coeff_in_time.append(_drop_bottom_dofs(stiff_temp+stiff_matrix_coeff))

    # Cell midpoints. Each direction uses its own step size, so the midpoints
    # are also right when hx != hy (for hx == hy nothing changes).
    hx = grid_x[1]-grid_x[0]
    hy = grid_y[1]-grid_y[0]
    grid_x_aux = np.linspace(domain[0][0] + hx/2, domain[0][1] - hx/2, nx)
    grid_y_aux = np.linspace(domain[1][0] + hy/2, domain[1][1] - hy/2, ny)
    X_aux, Y_aux = np.meshgrid(grid_x_aux, grid_y_aux)

    # The spatial factors and the background are the same for every time step:
    # evaluate them once instead of once per time step.
    x_mid = X_aux.ravel()
    y_mid = Y_aux.ravel()
    spatial_factors = [(alpha_x[j](x_mid), alpha_y[j](y_mid)) for j in range(n_terms)]
    background = SPE10_coeff_matrix.ravel()

    alpha_matrix = np.zeros((nx*ny, nt+1))
    for i in range(nt+1):
        alpha_matrix[:, i] += background
        for j, (x_factor, y_factor) in enumerate(spatial_factors):
            # Same left-to-right multiplication order as alpha_t * alpha_x * alpha_y.
            alpha_matrix[:, i] += alpha_t[j](grid_t[i])*x_factor*y_factor

    return grid_x, grid_y, grid_t, mass, stiff, stiffs_coeff_in_time, alpha_matrix