import timeit

import numpy as np
import scipy
import scipy.linalg
import scipy.sparse

from pymor.basic import *

# variant 1: basis change
def create_orthogonal_basis(M):
    """Return a basis orthogonalized with respect to the inner product ``M``.

    Starts from the canonical unit vectors ``np.eye(dim)`` and runs pyMOR's
    ``gram_schmidt`` in place with ``product=M`` (which presumably also
    normalizes the vectors -- verify against the pyMOR version in use).

    Parameters
    ----------
    M
        Inner-product operator; must map a space to itself.

    Returns
    -------
    NumpyVectorArray holding ``M.source.dim`` basis vectors.

    Raises
    ------
    ValueError
        If ``M`` is not square (source and range dimensions differ).
    """
    # Explicit check instead of ``assert`` so it is not stripped under -O.
    if M.source.dim != M.range.dim:
        raise ValueError(
            'inner product operator must be square, got source dim {} and '
            'range dim {}'.format(M.source.dim, M.range.dim))
    basis = NumpyVectorArray(np.eye(M.source.dim))
    gram_schmidt(basis, copy=False, product=M)
    return basis

def operator_svd1(op, source_product, range_product):
    """SVD of ``op`` by changing to orthogonal bases (variant 1).

    Both spaces get a basis orthogonalized w.r.t. their inner product. The
    operator is expressed in those bases, a plain Euclidean SVD is taken, and
    the singular vectors are mapped back to the original coordinates.

    Returns ``(U, s, V)``: left singular vectors, the singular values as
    returned by ``np.linalg.svd``, and right singular vectors.
    """
    src_basis = create_orthogonal_basis(source_product)
    rng_basis = create_orthogonal_basis(range_product)

    # Representation of ``op`` in the (src_basis, rng_basis) coordinates.
    image = op.apply(src_basis)
    coeff_matrix = rng_basis.dot(range_product.apply(image))

    left, sigma, right_t = np.linalg.svd(coeff_matrix, full_matrices=False)

    # Back-transform the coefficient vectors into the original spaces.
    left_vectors = NumpyVectorArray(rng_basis.data.T.dot(left).T)
    right_vectors = NumpyVectorArray(src_basis.data.T.dot(right_t.T).T)

    return left_vectors, sigma, right_vectors

# variant 2: 
def operator_svd2(T, source_product, range_product):
    """SVD of ``T`` via the eigen-decomposition of ``T^* T`` (variant 2).

    The matrix of ``T^* T`` is assembled column by column by applying ``T``
    and then its adjoint (w.r.t. ``source_product``/``range_product``) to the
    unit vectors. Its eigenpairs give the squared singular values and the
    right singular vectors; the left singular vectors are ``T V`` normalized.

    Returns
    -------
    (U2, s2, V2)
        U2: left singular vectors, normalized in ``range_product``.
        s2: list of singular values, in descending order.
        V2: right singular vectors, normalized in ``source_product``.
    """
    TadjT = T.apply_adjoint(T.apply(NumpyVectorArray(np.eye(T.source.dim))),
                            source_product=source_product,
                            range_product=range_product)

    # The matrix of T^*T is not symmetric in the Euclidean sense, so the
    # general solver is used. Its spectrum is real in exact arithmetic, but
    # numpy returns complex arrays as soon as *any* eigenvalue picks up an
    # imaginary round-off part; keep only the real parts.
    w, v = np.linalg.eig(TadjT.data.T)
    w, v = w.real, v.real

    # At most min(dim source, dim range) singular values can be nonzero.
    k = min(T.source.dim, T.range.dim)
    # Stable sort, descending by eigenvalue.
    order = np.argsort(-w, kind='mergesort')[:k]
    # Clip tiny negative round-off so sqrt does not produce NaN.
    s2 = list(np.sqrt(np.maximum(w[order], 0.)))

    # Rows of the data array are the selected eigenvectors.
    V2 = NumpyVectorArray(np.ascontiguousarray(v[:, order].T))
    normssq = source_product.pairwise_apply2(V2, V2)
    V2.scal(1. / np.sqrt(normssq))

    U2 = T.apply(V2)
    normssq = range_product.pairwise_apply2(U2, U2)
    U2.scal(1. / np.sqrt(normssq))

    return U2, s2, V2


# variant 3
def operator_svd3(T, source_product, range_product):
    """SVD of ``T`` via the eigen-decomposition of ``T T^*`` (variant 3).

    The matrix of ``T T^*`` is assembled by applying the adjoint of ``T``
    (w.r.t. ``source_product``/``range_product``) and then ``T`` to the range
    unit vectors. Its eigenpairs give the squared singular values and the left
    singular vectors; the right singular vectors are ``T^* U`` normalized.

    Returns
    -------
    (U3, s3, V3)
        U3: left singular vectors, normalized in ``range_product``.
        s3: list of singular values, in descending order.
        V3: right singular vectors, normalized in ``source_product``.
    """
    TTadj = T.apply(T.apply_adjoint(NumpyVectorArray(np.eye(T.range.dim)),
                                    source_product=source_product,
                                    range_product=range_product))

    # Real spectrum in exact arithmetic; drop imaginary round-off (numpy
    # returns complex arrays if any single eigenvalue has one).
    w, v = np.linalg.eig(TTadj.data.T)
    w, v = w.real, v.real

    # T T^* has dim_range eigenvalues but rank <= min(dim_source, dim_range).
    # Keeping the (numerically zero) extra ones would yield NaN singular
    # values and a division by zero when normalizing V3 below.
    k = min(T.source.dim, T.range.dim)
    # Stable sort, descending by eigenvalue.
    order = np.argsort(-w, kind='mergesort')[:k]
    s3 = list(np.sqrt(np.maximum(w[order], 0.)))

    U3 = NumpyVectorArray(np.ascontiguousarray(v[:, order].T))
    normssq = range_product.pairwise_apply2(U3, U3)
    U3.scal(1. / np.sqrt(normssq))

    V3 = T.apply_adjoint(U3,
                         source_product=source_product,
                         range_product=range_product)
    normssq = source_product.pairwise_apply2(V3, V3)
    V3.scal(1. / np.sqrt(normssq))

    return U3, s3, V3

# variant 4
def operator_svd4(T, source_product, range_product):
    """SVD of ``T`` via a generalized symmetric eigenproblem (variant 4).

    Solves ``A v = lambda M v`` with ``A`` the matrix obtained by applying
    ``T`` and then its adjoint taken with only ``range_product`` (presumably
    ``T^T R T``), and ``M`` the matrix of ``source_product``. The eigenvalues
    are the squared singular values; the eigenvectors are the right singular
    vectors.

    Returns
    -------
    (U2, s2, V2)
        U2: left singular vectors, normalized in ``range_product``.
        s2: list of singular values, in descending order.
        V2: right singular vectors, normalized in ``source_product``.
    """
    TadjT = T.apply_adjoint(T.apply(NumpyVectorArray(np.eye(T.source.dim))),
                            range_product=range_product)

    A = TadjT.data.T
    # Symmetric in exact arithmetic (assuming the omitted source_product means
    # the Euclidean product); remove round-off asymmetry before eigh.
    A = 0.5 * (A + A.T)

    smat = source_product._matrix
    if scipy.sparse.issparse(smat):
        smat = smat.toarray()
    # eigh handles the symmetric-definite generalized problem: real
    # eigenvalues (ascending) and M-orthonormal eigenvectors. The general
    # scipy.linalg.eig would return complex output here.
    w, v = scipy.linalg.eigh(A, b=smat)

    k = min(T.source.dim, T.range.dim)
    # Stable sort, descending by eigenvalue.
    order = np.argsort(-w, kind='mergesort')[:k]
    # Clip tiny negative round-off so sqrt does not produce NaN.
    s2 = list(np.sqrt(np.maximum(w[order], 0.)))

    V2 = NumpyVectorArray(np.ascontiguousarray(v[:, order].T))
    # Already M-normalized by eigh; kept for consistency with the other variants.
    normssq = source_product.pairwise_apply2(V2, V2)
    V2.scal(1. / np.sqrt(normssq))

    U2 = T.apply(V2)
    normssq = range_product.pairwise_apply2(U2, U2)
    U2.scal(1. / np.sqrt(normssq))

    return U2, s2, V2

def operator_svd5(Top, source_inner, range_inner):
    """Return the singular values of ``Top`` w.r.t. the given inner products.

    Builds the block matrix ``B = [[0, T^*], [T, 0]]``, where
    ``T^* = M^{-1} T^T R`` is the adjoint of ``T`` with respect to the source
    inner product ``M`` and the range inner product ``R``. The eigenvalues of
    ``B`` are ``+sigma_i`` and ``-sigma_i`` plus ``|dim_source - dim_range|``
    zeros, so the ``min(dim_source, dim_range)`` largest eigenvalues are the
    singular values.

    Parameters
    ----------
    Top, source_inner, range_inner
        Objects exposing their matrix as ``_matrix`` (dense array or
        scipy.sparse matrix).

    Returns
    -------
    1-D array of the ``min(dim_source, dim_range)`` singular values,
    in descending order.
    """
    def _dense(mat):
        # Accept scipy.sparse matrices as well as dense arrays/matrices.
        return mat.toarray() if hasattr(mat, 'toarray') else np.asarray(mat)

    T = _dense(Top._matrix)
    M = _dense(source_inner._matrix)
    R = _dense(range_inner._matrix)
    dim_range, dim_source = T.shape

    Tadj = np.linalg.solve(M, T.T.dot(R))

    size = dim_source + dim_range
    block = np.zeros((size, size), dtype=np.result_type(T, Tadj))
    block[:dim_source, dim_source:] = Tadj
    block[dim_source:, :dim_source] = T

    w = np.linalg.eigvals(block)
    k = min(dim_source, dim_range)
    # eig returns the +/- pairs in no guaranteed order, so select the k largest
    # (real) eigenvalues explicitly; abs() removes round-off signs on zeros.
    largest = np.sort(w.real)[::-1][:k]
    return np.abs(largest)


def random_inner_product(dim):
    """Return a random ``dim x dim`` symmetric positive semi-definite operator.

    A matrix with entries drawn uniformly from [-1, 1) is symmetrized to
    ``S`` and then squared; ``S S = S^T S`` is positive semi-definite (and
    definite whenever ``S`` is nonsingular, which holds almost surely).
    """
    half = np.random.random((dim, dim)) * 2 - 1
    symmetric = np.add(half, half.T)
    return NumpyMatrixOperator(np.dot(symmetric, symmetric))

def check_svd(fun, dim_source=None, dim_range=None):
    """Return the relative reconstruction error of an SVD routine.

    Draws random inner products ``Ms`` (source) and ``Mr`` (range) and a
    random operator ``T``, then calls ``fun(T, Ms, Mr)``. It compares ``T``
    with ``U^T diag(s) V Ms`` in the Frobenius norm.

    Parameters
    ----------
    fun
        One of the ``operator_svd*`` variants returning ``(U, s, V)``.
    dim_source, dim_range
        Problem dimensions. When omitted they fall back to the module-level
        globals of the same name, which the ``__main__`` block defines.

    Raises
    ------
    ValueError
        If a dimension is neither passed nor defined as a module global.
    """
    module_globals = globals()
    if dim_source is None:
        dim_source = module_globals.get('dim_source')
    if dim_range is None:
        dim_range = module_globals.get('dim_range')
    if dim_source is None or dim_range is None:
        raise ValueError('check_svd needs dim_source and dim_range '
                         '(pass them or define them as module globals)')

    Ms = random_inner_product(dim_source)
    Mr = random_inner_product(dim_range)
    T = NumpyMatrixOperator(np.random.random((dim_range, dim_source))*2-1)
    U, s, V = fun(T, Ms, Mr)
    # Multiplying by Ms turns V's rows into the functionals x -> (v_i, x)_Ms.
    T_shouldbe = U.data.T.dot(np.diag(s).dot(V.data.dot(Ms._matrix)))

    return np.linalg.norm(T._matrix - T_shouldbe) / np.linalg.norm(T._matrix)

if __name__ == "__main__":
    # Problem size; also read by check_svd through the module globals.
    dim_source = 700
    dim_range = 40
    # Passed into the timeit setup only; not used for a pass/fail check.
    tolerance = 1e-10*dim_source*dim_range
    # timeit runs the setup in a fresh namespace that re-imports this file as
    # ``operator_svd`` (assumes it is saved as operator_svd.py).
    setup = """
from operator_svd import random_inner_product
from pymor.basic import NumpyMatrixOperator
import numpy as np
from operator_svd import operator_svd1, operator_svd2, operator_svd3, operator_svd4
dim_source = {}
dim_range = {}
Ms = random_inner_product(dim_source)
Mr = random_inner_product(dim_range)
T = NumpyMatrixOperator(np.random.random((dim_range, dim_source))*2-1)
tolerance = {}
""".format(dim_source, dim_range, tolerance)

    # Variants 1 and 2 are currently disabled; append them here to include
    # them in the benchmark and the accuracy check.
    variants = [operator_svd3, operator_svd4]

    times = [timeit.timeit("{}(T, Ms, Mr)".format(variant.__name__),
                           setup=setup, number=2)
             for variant in variants]

    # print() with a single argument gives the same output on Python 2 and 3
    # and keeps the module importable under Python 3.
    print(times)

    for variant in variants:
        print(check_svd(variant))

