import numpy as np
import scipy.sparse.linalg
import sys

from assembly import *

def xyslice(nx, ny, xrang, yrang):
    """
    helper function,
    gives dofs in a range

    Returns the global dof numbers of the grid nodes whose x index lies in
    ``range(*xrang)`` and whose y index lies in ``range(*yrang)``, for a node
    grid with ``nx + 1`` nodes per row (dof = x + y * (nx + 1)).  Dofs are
    ordered row by row (y outer, x inner).  ``ny`` is not used; it is kept so
    all dof helpers share the same signature.

    The result always has an integer dtype, so it is usable as an index
    array even when the selection is empty.
    """
    row_length = nx + 1
    return np.array(
        [xpos + ypos * row_length
         for ypos in range(*yrang)
         for xpos in range(*xrang)],
        dtype=int)

def yslice(nx, ny, pos):
    """
    helper function,
    returns the dofs of the node column with x index ``pos``,
    ordered from bottom (y = 0) to top (y = ny)
    """
    column = (pos, pos + 1)
    every_row = (0, ny + 1)
    return xyslice(nx, ny, column, every_row)

def ldofs(nx, ny):
    """dofs of the left boundary column (x index 0), bottom to top"""
    leftmost_column = 0
    return yslice(nx, ny, leftmost_column)

def rdofs(nx, ny):
    """dofs of the right boundary column (x index nx), bottom to top"""
    rightmost_column = nx
    return yslice(nx, ny, rightmost_column)

def cdofs(nx, ny):
    """dofs of the central node column (x index nx // 2), bottom to top

    Raises ValueError if nx is odd, because then no node column lies exactly
    in the middle of the domain.
    """
    if nx % 2 != 0:
        raise ValueError("cdofs needs an even nx, got nx={}".format(nx))
    # floor division keeps the column index an int on Python 3 as well
    return yslice(nx, ny, nx // 2)


def transferoperator_interfaces(h_inv, nx, ny):
    """Compute the transfer operator from the outer interfaces to the center column.

    The discrete operator ``assembly(nx, ny)`` is restricted to the dofs that
    are not on the left (x index 0) or right (x index nx) node columns.  For
    every left/right boundary dof, the homogeneous problem with that Dirichlet
    data is solved.  The solution is then restricted to the central node
    column (``cdofs``), so nx must be even.

    Parameters
    ----------
    h_inv : number
        Passed as ``1. / h_inv`` to ``assembly_l2_1d`` to build the interface
        inner products (presumably the inverse mesh width -- TODO confirm).
    nx, ny : int
        Number of grid intervals in x and y direction.

    Returns
    -------
    (transferoperator, source_product, range_product)
        transferoperator : ndarray, shape (ny + 1, 2 * (ny + 1))
            Maps Dirichlet values (left column first, then right column) to
            the solution values on the central column.
        source_product : scipy.sparse csc matrix
            Block diagonal with two copies of ``assembly_l2_1d(ny, 1. / h_inv)``,
            one for each outer interface.
        range_product :
            ``assembly_l2_1d(ny, 1. / h_inv)`` for the central column.
    """
    import time

    # sys.stdout.write keeps this module valid on both Python 2 and 3
    sys.stdout.write(
        "computing transfer operator for nx={} ny={} ".format(nx, ny))
    sys.stdout.flush()

    numdofs = (nx + 1) * (ny + 1)
    dirichletdofs = np.concatenate((ldofs(nx, ny), rdofs(nx, ny)))
    innerdofs = np.setdiff1d(np.arange(numdofs), dirichletdofs)

    # mapping from global dof numbers to positions in the reduced system
    newdof = np.zeros((numdofs,), dtype=int)
    newdof[innerdofs] = np.arange(len(innerdofs))

    # central interface dofs, expressed in the reduced numbering
    centraldofs = newdof[cdofs(nx, ny)]

    fulloperator = assembly(nx, ny)
    # system matrix without the Dirichlet rows/columns
    inner_block = fulloperator[:, innerdofs][innerdofs, :]

    start = time.time()
    solve = scipy.sparse.linalg.factorized(inner_block)
    elapsed = time.time() - start
    print("factorization of {} matrix in {}".format(inner_block.shape, elapsed))

    # coupling of inner dofs to Dirichlet dofs (moved to the right-hand side)
    rhsop = fulloperator[:, dirichletdofs][innerdofs, :]

    # u_inner = -A_II^{-1} A_ID g, restricted to the central column
    transferoperator = -np.asarray(solve(rhsop.toarray()))[centraldofs, :]

    # inner products on the interfaces
    slice_inner = assembly_l2_1d(ny, 1. / h_inv)
    source_inner = scipy.sparse.bmat(
        [[slice_inner, None], [None, slice_inner]],
        format="csc", dtype=np.float64)

    sys.stdout.write("done\n")
    return (transferoperator, source_inner, slice_inner)

def average_remover(size):
    """Return the ``size x size`` centering matrix ``I - ones / size``.

    Left-multiplying a vector (or each column of a matrix) by it subtracts
    the mean of its entries.
    """
    uniform_weights = np.ones((size, size)) / float(size)
    return np.eye(size) - uniform_weights

def _boundary_l2_product(nx, ny, tsize):
    """Assemble the L2 product on the four outer node lines, in global dof numbering.

    Bottom/top rows use ``assembly_l2_1d(nx, tsize / nx)`` and left/right
    columns use ``assembly_l2_1d(ny, tsize / ny)``.  Corner dofs receive
    contributions from both adjacent edges.  Returns a CSC matrix of size
    numdofs x numdofs, where numdofs = (nx + 1) * (ny + 1).
    """
    numdofs = (nx + 1) * (ny + 1)
    row = nx + 1
    product = scipy.sparse.dok_matrix((numdofs, numdofs), dtype=np.float64)
    # bottom (y index 0) and top (y index ny) node rows
    product[:row, :row] += assembly_l2_1d(nx, tsize / nx)
    product[ny * row:, ny * row:] += assembly_l2_1d(nx, tsize / nx)
    # left (x index 0) and right (x index nx) node columns
    product[::row, ::row] += assembly_l2_1d(ny, tsize / ny)
    product[nx::row, nx::row] += assembly_l2_1d(ny, tsize / ny)
    return product.tocsc()


def transferoperator_volumes(nx, ny):
    """Compute the transfer operator from the outer boundary to the central third.

    Dirichlet data is prescribed on all outer boundary nodes of the
    (nx x ny) grid.  The problem ``assembly(nx, ny)`` is solved for each
    boundary dof, and the solution is restricted to the nodes of the central
    third of the domain (x index nx//3 .. 2*nx//3, likewise in y).  Each
    column is then centered, so constant functions are removed from the
    range.

    Parameters
    ----------
    nx, ny : int
        Number of grid intervals; both must be divisible by 3.

    Returns
    -------
    (transferoperator, source_product, range_product)
        transferoperator : ndarray
            Rows: central-third dofs; columns: boundary dofs in ascending
            global numbering.
        source_product : scipy.sparse csc matrix
            Boundary L2 product restricted to the boundary dofs.
        range_product :
            ``assembly(nx // 3, ny // 3)``.

    Raises
    ------
    ValueError
        If nx or ny is not divisible by 3.
    """
    import time

    if nx % 3 != 0 or ny % 3 != 0:
        raise ValueError(
            "nx and ny must be divisible by 3, got nx={} ny={}".format(nx, ny))

    # sys.stdout.write keeps this module valid on both Python 2 and 3
    sys.stdout.write(
        "computing volume transfer operator for nx={} ny={} ".format(nx, ny))
    sys.stdout.flush()

    numdofs = (nx + 1) * (ny + 1)
    innerdofs = xyslice(nx, ny, (1, nx), (1, ny))
    dirichletdofs = np.setdiff1d(np.arange(numdofs), innerdofs)
    assert np.all(dirichletdofs == np.sort(dirichletdofs))

    # mapping from global dof numbers to positions in the reduced system
    newdof = np.zeros((numdofs,), dtype=int)
    newdof[innerdofs] = np.arange(len(innerdofs))

    # floor division keeps the indices ints on Python 3 as well
    third_x, third_y = nx // 3, ny // 3

    # central third (including its own boundary nodes), reduced numbering
    centraldofs = newdof[xyslice(
        nx, ny, (third_x, 2 * third_x + 1), (third_y, 2 * third_y + 1))]

    # cross-check: the inner dofs form an (nx - 1) x (ny - 1) node grid,
    # so the same set can be sliced from that grid directly
    centraldofs_check = xyslice(
        nx - 2, ny - 2, (third_x - 1, 2 * third_x), (third_y - 1, 2 * third_y))
    assert np.all(centraldofs == centraldofs_check)

    fulloperator = assembly(nx, ny)
    # system matrix without the Dirichlet rows/columns
    inner_block = fulloperator[:, innerdofs][innerdofs, :]

    start = time.time()
    solve = scipy.sparse.linalg.factorized(inner_block)
    elapsed = time.time() - start
    print("factorization of {} matrix in {}".format(inner_block.shape, elapsed))

    # coupling of inner dofs to Dirichlet dofs (moved to the right-hand side)
    rhsop = fulloperator[:, dirichletdofs][innerdofs, :]

    # u_inner = -A_II^{-1} A_ID g, restricted to the central third
    transferoperator = -np.asarray(solve(rhsop.toarray()))[centraldofs, :]

    # subtract the mean over the central dofs from every column
    centering = average_remover(transferoperator.shape[0])
    transferoperator = centering.dot(transferoperator)

    # 6. is the hard-coded edge length used for the 1D element sizes
    # (tsize / nx, tsize / ny); presumably the physical domain size -- TODO confirm
    source_inner = _boundary_l2_product(nx, ny, 6.)
    source_inner = source_inner[:, dirichletdofs][dirichletdofs, :]
    range_inner = assembly(third_x, third_y)

    sys.stdout.write("done\n")
    return (transferoperator, source_inner, range_inner)


if __name__ == "__main__":
    # quick self-checks of the dof-indexing helpers
    expected_block = np.array([5, 6, 9, 10])
    assert np.all(xyslice(3, 3, (1, 3), (1, 3)) == expected_block)
    column_three = yslice(5, 7, 3)
    assert np.all(column_three == xyslice(5, 7, (3, 4), (0, 7 + 1)))
