import scipy.sparse
import scipy.integrate
import numpy as np

### This file contains functions to assemble various mass and stiffness matrices and other useful functions/algorithms.
# dof numbering in 2d: row by row along x (the x index runs fastest), rows ordered bottom-up
# example for nx=ny=2:  6  7  8
#                       3  4  5
#                       0  1  2


def mass_1d(domain,n):
    # 1d mass matrix M[i,j] = int phi_i phi_j dx for linear (hat) FEs on a uniform grid
    # n = number of elements (number of dofs: n+1), domain = [a,b] 1d domain
    # returns a (n+1)x(n+1) tridiagonal CSR matrix
    h = (domain[1] - domain[0]) / n
    # boundary nodes only see one element (h/3), interior nodes two (2h/3)
    main = np.full(n + 1, 2 / 3)
    main[0] = main[-1] = 1 / 3
    off = np.full(n, 1 / 6)
    return (h * scipy.sparse.diags([off, main, off], [-1, 0, 1])).tocsr()


def mass_bc_1d(domain,n):
    # 1d mass matrix for linear FEs restricted to the interior nodes (homogeneous Dirichlet BCs)
    # n = number of elements (number of dofs: n-1), domain = [a,b] 1d domain
    # returns a (n-1)x(n-1) tridiagonal CSR matrix
    h = (domain[1] - domain[0]) / n
    n_interior = n - 1
    main = np.full(n_interior, 2 / 3)
    off = np.full(n_interior - 1, 1 / 6)
    return (h * scipy.sparse.diags([off, main, off], [-1, 0, 1])).tocsr()


def stiff_1d(domain,n):
    # 1d stiffness matrix K[i,j] = int phi_i' phi_j' dx for linear (hat) FEs on a uniform grid
    # n = number of elements (number of dofs: n+1), domain = [a,b] 1d domain
    # returns a (n+1)x(n+1) tridiagonal CSR matrix
    h = (domain[1] - domain[0]) / n
    # boundary nodes touch one element (1/h), interior nodes two (2/h)
    main = np.full(n + 1, 2.0)
    main[[0, -1]] = 1.0
    off = -np.ones(n)
    return ((1 / h) * scipy.sparse.diags([off, main, off], [-1, 0, 1])).tocsr()


def stiff_bc_1d(domain,n):
    # 1d stiffness matrix for linear FEs restricted to the interior nodes (homogeneous Dirichlet BCs)
    # n = number of elements (number of dofs: n-1), domain = [a,b] 1d domain
    # returns a (n-1)x(n-1) tridiagonal CSR matrix
    h = (domain[1] - domain[0]) / n
    n_interior = n - 1
    main = np.full(n_interior, 2.0)
    off = np.full(n_interior - 1, -1.0)
    return ((1 / h) * scipy.sparse.diags([off, main, off], [-1, 0, 1])).tocsr()


def advec_1d(n,coeff):
    # 1d advection matrix A[i,j] = int coeff * phi_j' phi_i dx for linear FEs (independent of h)
    # n = number of elements (number of dofs: n+1), coeff = constant advection coefficient
    # returns a (n+1)x(n+1) tridiagonal CSR matrix
    main = np.zeros(n + 1)
    # only the two boundary nodes get a nonzero diagonal entry (-1/2 left, +1/2 right)
    main[0], main[-1] = -1, 1
    upper = np.ones(n)
    lower = -np.ones(n)
    return (0.5 * coeff * scipy.sparse.diags([lower, main, upper], [-1, 0, 1])).tocsr()


def advec_bc_1d(n,coeff):
    # 1d advection matrix for linear FEs restricted to the interior nodes (homogeneous Dirichlet BCs)
    # n = number of elements (number of dofs: n-1), coeff = constant advection coefficient
    # the result is skew-symmetric with zero diagonal
    size = n - 1
    main = np.zeros(size)
    upper = np.ones(size - 1)
    bands = [-upper, main, upper]
    return (0.5 * coeff * scipy.sparse.diags(bands, [-1, 0, 1])).tocsr()


def mass_2d(domain, nx, ny):
    # 2d mass matrix for linear tensor-product FEs
    # nx/ny = number of elements in x/y direction (number of dofs: (nx+1)*(ny+1))
    # domain = [[a_x,b_x],[a_y,b_y]] 2d domain
    # built as a Kronecker product of 1d factors; the y-factor goes first so the x index runs fastest
    factor_x = mass_1d([domain[0][0], domain[0][1]], nx)
    factor_y = mass_1d([domain[1][0], domain[1][1]], ny)
    return scipy.sparse.kron(factor_y, factor_x).tocsr()


def mass_bc_2d(domain, nx, ny):
    # 2d mass matrix for linear tensor-product FEs with homogeneous Dirichlet boundary conditions
    # nx/ny = number of elements in x/y direction (number of dofs: (nx-1)*(ny-1), interior nodes)
    # domain = [[a_x,b_x],[a_y,b_y]] 2d domain
    # built as a Kronecker product of 1d factors; the y-factor goes first so the x index runs fastest
    factor_x = mass_bc_1d([domain[0][0], domain[0][1]], nx)
    factor_y = mass_bc_1d([domain[1][0], domain[1][1]], ny)
    return scipy.sparse.kron(factor_y, factor_x).tocsr()


def stiff_2d(domain, nx, ny):
    # 2d stiffness matrix (grad u, grad v) for linear tensor-product FEs
    # nx/ny = number of elements in x/y direction (number of dofs: (nx+1)*(ny+1))
    # domain = [[a_x,b_x],[a_y,b_y]] 2d domain
    dom_x = [domain[0][0], domain[0][1]]
    dom_y = [domain[1][0], domain[1][1]]
    kx, mx = stiff_1d(dom_x, nx), mass_1d(dom_x, nx)
    ky, my = stiff_1d(dom_y, ny), mass_1d(dom_y, ny)
    # grad u . grad v = u_y v_y + u_x v_x; each term is (1d stiffness) x (1d mass) in the other direction
    matrix = scipy.sparse.kron(ky, mx) + scipy.sparse.kron(my, kx)
    return matrix.tocsr()


def stiff_bc_2d(domain, nx, ny):
    # 2d stiffness matrix for linear tensor-product FEs with homogeneous Dirichlet boundary conditions
    # nx/ny = number of elements in x/y direction (number of dofs: (nx-1)*(ny-1), interior nodes)
    # domain = [[a_x,b_x],[a_y,b_y]] 2d domain
    dom_x = [domain[0][0], domain[0][1]]
    dom_y = [domain[1][0], domain[1][1]]
    kx, mx = stiff_bc_1d(dom_x, nx), mass_bc_1d(dom_x, nx)
    ky, my = stiff_bc_1d(dom_y, ny), mass_bc_1d(dom_y, ny)
    # grad u . grad v = u_y v_y + u_x v_x; each term is (1d stiffness) x (1d mass) in the other direction
    matrix = scipy.sparse.kron(ky, mx) + scipy.sparse.kron(my, kx)
    return matrix.tocsr()


def advec_2d(domain,nx,ny,coeff_x,coeff_y):
    # 2d advection matrix (coeff_x*u_x + coeff_y*u_y, v) for linear tensor-product FEs
    # nx/ny = number of elements in x/y direction (number of dofs: (nx+1)*(ny+1))
    # domain = [[a_x,b_x],[a_y,b_y]] 2d domain
    # coeff_x, coeff_y = constant advection coefficient in x/y direction
    interval_x = [domain[0][0], domain[0][1]]
    interval_y = [domain[1][0], domain[1][1]]
    # each directional term is (1d advection) x (1d mass) in the other direction
    x_term = scipy.sparse.kron(mass_1d(interval_y, ny), advec_1d(nx, coeff_x))
    y_term = scipy.sparse.kron(advec_1d(ny, coeff_y), mass_1d(interval_x, nx))
    return (x_term + y_term).tocsr()


def advec_bc_2d(domain,nx,ny,coeff_x,coeff_y):
    # 2d advection matrix for linear tensor-product FEs with homogeneous Dirichlet boundary conditions
    # nx/ny = number of elements in x/y direction (number of dofs: (nx-1)*(ny-1), interior nodes)
    # domain = [[a_x,b_x],[a_y,b_y]] 2d domain
    # coeff_x, coeff_y = constant advection coefficient in x/y direction
    interval_x = [domain[0][0], domain[0][1]]
    interval_y = [domain[1][0], domain[1][1]]
    # each directional term is (1d advection) x (1d mass) in the other direction
    x_term = scipy.sparse.kron(mass_bc_1d(interval_y, ny), advec_bc_1d(nx, coeff_x))
    y_term = scipy.sparse.kron(advec_bc_1d(ny, coeff_y), mass_bc_1d(interval_x, nx))
    return (x_term + y_term).tocsr()



##########


def mass_coeff_1d(domain, n, alpha):
    # 1d weighted mass matrix M[i,j] = int alpha(x) phi_i(x) phi_j(x) dx for linear (hat) FEs
    # n = number of elements (number of dofs: n+1)
    # alpha(x) = coefficient (1D), evaluated pointwise by scipy.integrate.quad
    # domain = [a,b] 1d domain
    a = domain[0]
    h = (domain[1] - a) / n
    main = np.zeros(n+1)
    off = np.zeros(n)

    def integrate(f, left, right):
        # adaptive quadrature with tight tolerances over one element
        return scipy.integrate.quad(f, left, right, epsabs=1e-12, epsrel=1e-12)[0]

    for k in range(n):
        lo, hi = a + k*h, a + (k+1)*h
        # on [lo, hi] node k has the local shape (hi-x)/h and node k+1 the shape (x-lo)/h;
        # default arguments freeze this element's endpoints inside the lambdas
        main[k] += integrate(lambda x, hi=hi: alpha(x) * ((hi - x) / h)**2, lo, hi)
        main[k+1] += integrate(lambda x, lo=lo: alpha(x) * ((x - lo) / h)**2, lo, hi)
        off[k] += integrate(lambda x, lo=lo, hi=hi: alpha(x) * ((hi - x) / h) * ((x - lo) / h), lo, hi)

    return scipy.sparse.diags([off, main, off], [-1, 0, 1]).tocsr()


def stiff_coeff_1d(domain, n, alpha):
    # computes 1d stiffness matrix including coefficient, K[i,j] = int alpha(x) phi_i'(x) phi_j'(x) dx,
    # for linear FEs
    # n = number of elements (number of dofs: n+1)
    # alpha(x) = coefficient (1D), evaluated pointwise by scipy.integrate.quad
    # domain = [a,b] 1d domain
    h = (domain[1] - domain[0]) / n
    diag = np.zeros(n+1)
    diag_minor = np.zeros(n)
    # On element [x_i, x_{i+1}] the two hat-function derivatives are the constants -1/h and +1/h,
    # so all three local entries are +-(1/h**2) * int alpha dx: one quadrature per element suffices.
    inv_h_sq = (1 / h)**2
    def weighted_alpha(x): return alpha(x) * inv_h_sq

    for i in range(n):
        integral = scipy.integrate.quad(weighted_alpha, domain[0]+i*h, domain[0]+(i+1)*h,
                                        epsabs=1e-12, epsrel=1e-12)[0]
        diag[i] += integral        # (-1/h)*(-1/h) contribution of node i
        diag[i+1] += integral      # (+1/h)*(+1/h) contribution of node i+1
        diag_minor[i] -= integral  # (-1/h)*(+1/h) coupling between nodes i and i+1

    matrix = scipy.sparse.diags([diag_minor, diag, diag_minor], [-1, 0, 1])
    return matrix.tocsr()


def mass_coeff_bc_1d(domain, n, alpha):
    # computes 1d mass matrix including coefficient for linear FEs with homogeneous Dirichlet boundary conditions
    # n = number of elements (number of dofs: n-1, i.e. only the interior nodes)
    # alpha(x) = coefficient (1D), evaluated pointwise by scipy.integrate.quad
    # domain = [a,b] 1d domain
    h = (domain[1] - domain[0]) / n
    diag = np.zeros(n+1)
    diag_minor = np.zeros(n)
    quad_opts = {"epsabs": 1e-12, "epsrel": 1e-12}

    for i in range(n):
        left = domain[0] + i*h
        right = domain[0] + (i+1)*h
        # on [left, right] node i has the local shape (right-x)/h and node i+1 the shape (x-left)/h;
        # default arguments bind this element's endpoints explicitly
        def func_down(x, right=right): return alpha(x) * ((right - x) / h)**2
        def func_up(x, left=left): return alpha(x) * ((x - left) / h)**2
        def func_mixed(x, left=left, right=right): return alpha(x) * ((right - x) / h) * ((x - left) / h)
        diag[i] += scipy.integrate.quad(func_down, left, right, **quad_opts)[0]
        diag[i+1] += scipy.integrate.quad(func_up, left, right, **quad_opts)[0]
        diag_minor[i] += scipy.integrate.quad(func_mixed, left, right, **quad_opts)[0]

    # keep only interior nodes 1..n-1: drop both boundary diagonal entries and the two couplings
    # that involve a boundary node
    diags = [diag_minor[1:-1], diag[1:-1], diag_minor[1:-1]]
    matrix = scipy.sparse.diags(diags, [-1, 0, 1])
    return matrix.tocsr()


def stiff_coeff_bc_1d(domain, n, alpha):
    # computes 1d stiffness matrix including coefficient for linear FEs with homogeneous Dirichlet boundary conditions
    # n = number of elements (number of dofs: n-1, i.e. only the interior nodes)
    # alpha(x) = coefficient (1D), evaluated pointwise by scipy.integrate.quad
    # domain = [a,b] 1d domain
    h = (domain[1] - domain[0]) / n
    diag = np.zeros(n+1)
    diag_minor = np.zeros(n)
    # On element [x_i, x_{i+1}] the two hat-function derivatives are the constants -1/h and +1/h,
    # so all three local entries are +-(1/h**2) * int alpha dx: one quadrature per element suffices.
    inv_h_sq = (1 / h)**2
    def weighted_alpha(x): return alpha(x) * inv_h_sq

    for i in range(n):
        integral = scipy.integrate.quad(weighted_alpha, domain[0]+i*h, domain[0]+(i+1)*h,
                                        epsabs=1e-12, epsrel=1e-12)[0]
        diag[i] += integral        # (-1/h)*(-1/h) contribution of node i
        diag[i+1] += integral      # (+1/h)*(+1/h) contribution of node i+1
        diag_minor[i] -= integral  # (-1/h)*(+1/h) coupling between nodes i and i+1

    # keep only interior nodes 1..n-1: drop both boundary diagonal entries and the two couplings
    # that involve a boundary node
    diags = [diag_minor[1:-1], diag[1:-1], diag_minor[1:-1]]
    matrix = scipy.sparse.diags(diags, [-1, 0, 1])
    return matrix.tocsr()


def stiff_coeff_2d(domain, nx, ny, alpha_x, alpha_y):
    # 2d weighted stiffness matrix (alpha grad u, grad v) for linear tensor-product FEs
    # nx/ny = number of elements in x/y direction (number of dofs: (nx+1)*(ny+1))
    # alpha(x,y) = alpha_x(x)*alpha_y(y) separable 2D coefficient
    # domain = [[a_x,b_x],[a_y,b_y]] 2d domain
    dom_x = [domain[0][0], domain[0][1]]
    dom_y = [domain[1][0], domain[1][1]]
    kx, mx = stiff_coeff_1d(dom_x, nx, alpha_x), mass_coeff_1d(dom_x, nx, alpha_x)
    ky, my = stiff_coeff_1d(dom_y, ny, alpha_y), mass_coeff_1d(dom_y, ny, alpha_y)
    # for separable alpha, each directional term is a Kronecker product of weighted 1d matrices
    # (y-factor first so that the x index runs fastest)
    matrix = scipy.sparse.kron(ky, mx) + scipy.sparse.kron(my, kx)
    return matrix.tocsr()


def stiff_coeff_bc_2d(domain, nx, ny, alpha_x, alpha_y):
    # computes 2d stiffness matrix including coefficient for linear FEs with homogeneous Dirichlet boundary conditions
    # nx/ny = number of elements in x/y direction (number of dofs: (nx-1)*(ny-1), interior nodes only)
    # alpha(x,y)=alpha_x(x)*alpha_y(y) 2D coefficient in tensor format
    # domain = [[a_x,b_x],[a_y,b_y]] 2d domain
    domain_x = [domain[0][0], domain[0][1]]
    domain_y = [domain[1][0], domain[1][1]]
    xmatrix_stiff = stiff_coeff_bc_1d(domain_x, nx, alpha_x)
    xmatrix_mass = mass_coeff_bc_1d(domain_x, nx, alpha_x)
    ymatrix_stiff = stiff_coeff_bc_1d(domain_y, ny, alpha_y)
    ymatrix_mass = mass_coeff_bc_1d(domain_y, ny, alpha_y)
    # (alpha grad u, grad v) = (alpha u_y, v_y) + (alpha u_x, v_x); for separable alpha each term is a
    # Kronecker product of weighted 1d matrices (y-factor first so that the x index runs fastest)
    matrix = scipy.sparse.kron(ymatrix_stiff,xmatrix_mass)+scipy.sparse.kron(ymatrix_mass,xmatrix_stiff)
    return matrix.tocsr()


##########


def assemble_rhs_1d(domain, n, rhs):
    # load vector (rhs, phi_i) for the interior hat functions (homogeneous Dirichlet boundary conditions)
    # n = number of elements (number of dofs: n-1)
    # rhs(x) = rhs data function (1D), evaluated pointwise by scipy.integrate.quad
    # domain = [a,b] 1d domain
    a = domain[0]
    h = (domain[1] - a) / n
    load = np.zeros(n + 1)

    def integrate(f, left, right):
        # adaptive quadrature with tight tolerances over one element
        return scipy.integrate.quad(f, left, right, epsabs=1e-12, epsrel=1e-12)[0]

    for k in range(n):
        lo, hi = a + k*h, a + (k+1)*h
        # node k sees the falling shape (hi-x)/h, node k+1 the rising shape (x-lo)/h
        load[k] += integrate(lambda x, hi=hi: rhs(x) * ((hi - x) / h), lo, hi)
        load[k+1] += integrate(lambda x, lo=lo: rhs(x) * ((x - lo) / h), lo, hi)

    # the two boundary nodes are not dofs
    return load[1:-1]


def assemble_rhs_2d(domain,nx,ny,rhs_x,rhs_y):
    # load vectors (f_k, phi) for the interior 2D hat functions (homogeneous Dirichlet boundary conditions),
    # one for each separable summand f_k(x,y) = rhs_x[k](x)*rhs_y[k](y)
    # nx/ny = number of elements in x/y direction (number of dofs: (nx-1)*(ny-1))
    # domain = [[a_x,b_x],[a_y,b_y]] 2d domain
    # rhs_x/rhs_y = lists of 1d callables (indexed in parallel)
    # returns an array of shape ((nx-1)*(ny-1), len(rhs_x)), one load vector per column
    n_terms = len(rhs_x)
    rhs_vectors = np.zeros(((nx-1)*(ny-1), n_terms))
    interval_x = [domain[0][0], domain[0][1]]
    interval_y = [domain[1][0], domain[1][1]]
    for k in range(n_terms):
        load_x = assemble_rhs_1d(interval_x, nx, rhs_x[k])
        load_y = assemble_rhs_1d(interval_y, ny, rhs_y[k])
        # y-factor first so that the x index runs fastest (same ordering as the 2d matrices)
        rhs_vectors[:, k] = np.kron(load_y, load_x)
    return rhs_vectors


##########


def FE_assemble(domain,nx,ny,T_start,T_finish,nt,alpha_t,alpha_x,alpha_y,rhs_t,rhs_x,rhs_y):
    # returns a bundle of (2D) FE matrices etc
    # domain = [[a_x,b_x],[a_y,b_y]] 2d domain, nx/ny = number of elements in x/y direction
    # T_start/T_finish = time interval, nt = number of time steps (nt+1 time points in grid_t)
    # alpha_t/alpha_x/alpha_y = lists of callables forming the separable diffusion coefficient
    #     alpha(t,x,y) = sum_j alpha_t[j](t)*alpha_x[j](x)*alpha_y[j](y)   (lists assumed equally long)
    # rhs_t/rhs_x/rhs_y = lists of callables forming the separable right hand side
    #     f(t,x,y) = sum_j rhs_t[j](t)*rhs_x[j](x)*rhs_y[j](y)            (lists assumed equally long)
    # returns (in this order):
    #     grid_x, grid_y, grid_t = node coordinates in x, y and time
    #     mass, mass_bc, stiff, stiff_bc = constant-coefficient matrices (full / Dirichlet interior)
    #     stiffs_coeff_in_time(_bc) = lists with one alpha-weighted stiffness matrix per time point
    #     rhs_matrix = interior load vectors, one column per time point
    #     alpha_matrix = alpha evaluated at all (nx+1)*(ny+1) grid nodes, one column per time point
    mass = mass_2d(domain,nx,ny)
    mass_bc = mass_bc_2d(domain,nx,ny)
    stiff = stiff_2d(domain,nx,ny)
    stiff_bc = stiff_bc_2d(domain,nx,ny)
    grid_t = np.linspace(T_start,T_finish,nt+1)

    # right hand side: column i = sum_j rhs_t[j](t_i) * (spatial load vector j), as one matrix product
    rhs_vecs = assemble_rhs_2d(domain,nx,ny,rhs_x,rhs_y)
    n_rhs = rhs_vecs.shape[1]
    rhs_time = np.array([[rhs_t[j](t) for t in grid_t] for j in range(n_rhs)], dtype=float)
    rhs_matrix = rhs_vecs @ rhs_time.reshape(n_rhs, len(grid_t))

    # one weighted stiffness matrix per separable spatial term
    stiffs_coeff_space = []
    stiffs_coeff_space_bc = []
    for i in range(len(alpha_x)):
        stiffs_coeff_space.append(stiff_coeff_2d(domain,nx,ny,alpha_x[i],alpha_y[i]))
        stiffs_coeff_space_bc.append(stiff_coeff_bc_2d(domain,nx,ny,alpha_x[i],alpha_y[i]))

    # combine the spatial terms with their time weights at every time point
    n_alpha = len(alpha_t)
    stiffs_coeff_in_time = []
    stiffs_coeff_in_time_bc = []
    for t in grid_t:
        weights = [a_t(t) for a_t in alpha_t]
        stiff_t = weights[0]*stiffs_coeff_space[0]
        stiff_bc_t = weights[0]*stiffs_coeff_space_bc[0]
        for j in range(1, n_alpha):
            stiff_t = stiff_t + weights[j]*stiffs_coeff_space[j]
            stiff_bc_t = stiff_bc_t + weights[j]*stiffs_coeff_space_bc[j]
        stiffs_coeff_in_time.append(stiff_t)
        stiffs_coeff_in_time_bc.append(stiff_bc_t)

    grid_x = np.linspace(domain[0][0],domain[0][1],nx+1)
    grid_y = np.linspace(domain[1][0],domain[1][1],ny+1)
    # meshgrid + ravel gives the x index running fastest, matching the matrix dof numbering
    X,Y = np.meshgrid(grid_x,grid_y)
    x_nodes, y_nodes = X.ravel(), Y.ravel()
    # the spatial factors do not depend on t: evaluate them once instead of once per time point
    alpha_space = [alpha_x[j](x_nodes)*alpha_y[j](y_nodes) for j in range(n_alpha)]
    alpha_matrix = np.zeros(((nx+1)*(ny+1),nt+1))
    for i, t in enumerate(grid_t):
        for j in range(n_alpha):
            alpha_matrix[:,i] += alpha_t[j](t)*alpha_space[j]
    return grid_x,grid_y,grid_t,mass,mass_bc,stiff,stiff_bc,stiffs_coeff_in_time,stiffs_coeff_in_time_bc,rhs_matrix, alpha_matrix



def FE_assemble_incl_advec(domain,nx,ny,T_start,T_finish,nt,alpha_t,alpha_x,alpha_y,advec_x,advec_y,rhs_t,rhs_x,rhs_y):
    # returns a bundle of (2D) FE matrices etc, additionally including constant-coefficient advection matrices
    # advec_x/advec_y = constant advection coefficients in x/y direction
    # all other parameters are passed unchanged to FE_assemble
    # returns the FE_assemble bundle with advec, advec_bc (full / Dirichlet interior advection matrices)
    # inserted after stiff_bc
    # delegating avoids keeping a second copy of the diffusion/rhs assembly in sync
    (grid_x, grid_y, grid_t, mass, mass_bc, stiff, stiff_bc,
     stiffs_coeff_in_time, stiffs_coeff_in_time_bc, rhs_matrix, alpha_matrix) = FE_assemble(
        domain, nx, ny, T_start, T_finish, nt, alpha_t, alpha_x, alpha_y, rhs_t, rhs_x, rhs_y)
    advec = advec_2d(domain,nx,ny,advec_x,advec_y)
    advec_bc = advec_bc_2d(domain,nx,ny,advec_x,advec_y)
    return grid_x,grid_y,grid_t,mass,mass_bc,stiff,stiff_bc,advec,advec_bc,stiffs_coeff_in_time,stiffs_coeff_in_time_bc,rhs_matrix, alpha_matrix


##########


def gram_schmidt_ortho(basis, product):
    # orthonormalises the columns of `basis` IN PLACE with respect to the inner product <u,v> = u^T product v
    # basis = float array whose columns are the vectors (modified in place)
    # product = symmetric positive definite matrix (dense or sparse, anything with .dot)
    # each projection uses the already-updated column (modified Gram-Schmidt); the projection pass is
    # repeated while it still shrinks the column norm by more than 10% (re-orthogonalisation)
    # returns basis (the same object) for convenience
    # raises ValueError if a column is (numerically exactly) dependent on the previous ones
    for i in range(basis.shape[1]):
        # np.float was removed in NumPy 1.24; np.inf is the portable infinity
        oldnorm = np.inf
        newnorm = np.sqrt(basis[:, i].T.dot(product.dot(basis[:, i])))
        while newnorm < 0.9 * oldnorm:
            oldnorm = newnorm
            for j in range(i):
                basis[:, i] -= basis[:, j] * basis[:, j].T.dot(product.dot(basis[:, i]))
            newnorm = np.sqrt(basis[:, i].T.dot(product.dot(basis[:, i])))

        # a zero (or NaN) norm would otherwise silently fill the column with NaN
        if not newnorm > 0:
            raise ValueError(f"column {i} is linearly dependent on the previous columns (zero norm)")
        basis[:, i] *= 1. / newnorm
    return basis

