import numpy as np
import scipy.sparse.linalg
import math

import scipy.sparse
from mybmat import mybmat
# Monkeypatch: replace scipy.sparse.bmat for the whole process with ``mybmat``
# from the project-local ``mybmat`` module. operator_svd below calls
# scipy.sparse.bmat and therefore gets mybmat; any other code in this process
# that calls scipy.sparse.bmat after this module is imported is affected too.
# NOTE(review): the reason for the replacement is not visible here (presumably
# mybmat accepts block types the stock bmat does not) — confirm.
scipy.sparse.bmat = mybmat

def testlimit(failure_tolerance, dim_S, dim_R, num_testvecs, target_error, lambda_min):
    """
    failure_tolerance:  maximum probability for failure of algorithm
    dim_S: dimension of source space
    dim_R: dimension of range space
    num_testvecs: number of test vectors used
    target_error: desired maximal norm of tested operator
    lambda_min: smallest eigenvalue of matrix of inner product in source space
    """
    return scipy.special.erfinv( (failure_tolerance / min(dim_S, dim_R))**(1./num_testvecs)) * target_error * np.sqrt(2. * lambda_min)


def operator_svd(Top, source_inner, range_inner):
    """Singular triplets of ``Top`` from a dense eigendecomposition.

    Forms the block operator ``[[0, T*], [T, 0]]`` with
    ``T* = S^{-1} T^T R`` (``S = source_inner``, ``R = range_inner``), i.e. the
    adjoint of ``Top`` with respect to the two inner products, using
    ``scipy.sparse.bmat`` (replaced by ``mybmat`` at the top of this file).
    The operator is densified and diagonalized with ``np.linalg.eig``; the
    eigenpairs are ordered by ascending real part and the first
    ``min(Top.shape)`` of them are kept.

    Returns ``(range_vecs, svals, source_vecs)``: the range-space rows and the
    source-space rows of the kept eigenvectors (one per column), and the
    absolute values of the kept eigenvalues.
    """
    n_source = source_inner.shape[0]
    solve_source = scipy.sparse.linalg.factorized(source_inner)
    top_adjoint = solve_source(Top.T.dot(range_inner.todense()))

    block_op = scipy.sparse.bmat([[None, top_adjoint], [Top, None]]).tocsc()
    eigenvalues, eigenvectors = np.linalg.eig(block_op.todense())

    order = np.argsort(eigenvalues.real)
    eigenvalues = eigenvalues[order]
    eigenvectors = np.array(eigenvectors)[:, order]

    count = min(Top.shape)
    return (
        eigenvectors[n_source:, :count],
        np.abs(eigenvalues[:count]),
        eigenvectors[:n_source, :count],
    )

def adaptive_implicit_svd(Top, tolerance):
    """Singular triplets of ``Top`` whose value is >= ``tolerance``.

    ``Top`` is only accessed through ``.shape``, ``.matvec`` and ``.rmatvec``
    (e.g. a ``scipy.sparse.linalg.LinearOperator``). The block operator
    ``[[0, Top^T], [Top, 0]]`` (source coordinates first, ``Top^T`` applied
    via ``rmatvec``) is diagonalized with ARPACK (``eigs``, largest magnitude).
    The number of requested eigenpairs is doubled until the smallest one in
    magnitude drops below ``tolerance``. ARPACK cannot compute ``k >= N - 1``
    eigenpairs (``N`` = source dim + range dim). When that point is reached,
    the block operator is assembled densely and fully diagonalized instead.

    Note: ``tolerance`` is used both as ARPACK's relative accuracy ``tol`` and
    as the cut-off on the returned eigenvalues.

    Returns ``(range_vecs, svals, source_vecs)`` ordered by decreasing
    eigenvalue: the range-space and source-space parts of the eigenvectors
    (columns) and the real parts of the eigenvalues that are >= ``tolerance``.
    """
    n_source = Top.shape[1]
    n_total = Top.shape[0] + Top.shape[1]

    def local_apply_top(array):
        assert array.shape[0] == n_total
        result = np.zeros_like(array)
        result[n_source:] = Top.matvec(array[:n_source])
        result[:n_source] = Top.rmatvec(array[n_source:])
        return result

    my_lop = scipy.sparse.linalg.LinearOperator(
        dtype=float,  # np.float was removed in NumPy 1.24
        shape=(n_total, n_total),
        matvec=local_apply_top,
    )

    k = 10
    while True:
        if k >= n_total - 1:
            # eigs requires k < N - 1 (it raises TypeError for a LinearOperator
            # otherwise), so compute every eigenpair from the dense operator.
            dense = np.column_stack([local_apply_top(e) for e in np.eye(n_total)])
            eigvals, eigvec = np.linalg.eig(dense)
            break
        eigvals, eigvec = scipy.sparse.linalg.eigs(my_lop, k=k, tol=tolerance, which="LM")
        if np.min(np.abs(eigvals)) < tolerance:
            break
        k *= 2

    eigvals = eigvals.real
    indices = np.argsort(eigvals)[::-1]
    indices = indices[np.nonzero(eigvals[indices] >= tolerance)]

    return eigvec[n_source:, indices], eigvals[indices], eigvec[:n_source, indices]
    
def operator_svd2(Top, source_inner, range_inner):
    """Singular values of ``Top`` from a dense generalized eigenproblem.

    Solves ``(Top^T R Top) x = mu S x`` with ``R = range_inner`` and
    ``S = source_inner``; the singular values are ``sqrt(|Re(mu)|)``.

    Returns ``(None, svals, v)``. ``svals`` is sorted in decreasing order and
    the columns of ``v`` (generalized eigenvectors) are permuted the same way,
    so ``v[:, i]`` belongs to ``svals[i]``. No range-space vectors are
    computed, hence the leading ``None``.
    """
    import scipy.linalg

    mat_left = Top.T.dot(range_inner.dot(Top))
    mat_right = source_inner.todense()
    eigvals, v = scipy.linalg.eig(mat_left, mat_right)
    svals = np.sqrt(np.abs(eigvals.real))
    # Sort values and eigenvectors together; the previous in-place
    # ``eigvals[::-1].sort()`` reordered only the values.
    order = np.argsort(svals)[::-1]
    return None, svals[order], v[:, order]

def gram_schmidt(basis, inner, start):
    """Orthonormalize columns ``start:`` of ``basis`` in place w.r.t. ``inner``.

    Columns ``0 .. start-1`` are assumed to be orthonormal already with respect
    to the inner-product matrix ``inner`` (not checked). Each remaining column
    is orthogonalized against all preceding columns, updating it after every
    projection. The pass is repeated as long as it still shrinks the column's
    ``inner``-norm by more than 10%. The column is then scaled to unit
    ``inner``-norm.

    A column that lies (numerically) in the span of the preceding ones ends up
    with a (near-)zero norm, so the final scaling produces inf/nan or very
    large entries.

    Returns None; ``basis`` is modified in place.
    """
    assert len(basis.shape) == 2
    for i in range(start, basis.shape[1]):
        # np.float was removed in NumPy 1.24; np.inf is the same value.
        oldnorm = np.inf
        newnorm = np.sqrt(basis[:, i].T.dot(inner.dot(basis[:, i])))
        # Re-orthogonalize while a pass still removes a significant part.
        while newnorm < 0.9 * oldnorm:
            oldnorm = newnorm
            for j in range(i):
                basis[:, i] -= basis[:, j] * basis[:, j].T.dot(inner.dot(basis[:, i]))
            newnorm = np.sqrt(basis[:, i].T.dot(inner.dot(basis[:, i])))

        basis[:, i] *= 1. / newnorm

def range_generation(top, range_inner, source_inner, max_size=13, num_testvecs=10):
    """Randomized range finder with a fixed number of steps, recording error history.

    In each of ``max_size`` steps, the current error quantities are recorded
    first. Then one random image ``top @ normal`` is appended to the basis and
    orthonormalized in ``range_inner`` via ``gram_schmidt``. Finally, that
    basis is projected out of the operator remainder and out of the test
    vectors.

    ``top`` is used as a dense array (``np.copy``, ``.dot``, ``.T``);
    ``source_inner`` must provide ``.todense()``.

    Returns ``(opnorms, testvecnorms)``:
    ``opnorms[i]`` is the operator norm (source norm given by
    ``source_inner``, range norm by ``range_inner``) of ``top`` with the first
    ``i`` basis vectors projected out. ``testvecnorms[i]`` holds the
    ``range_inner``-norms of the ``num_testvecs`` random test vectors
    ``top @ normal``, projected the same way.
    """
    import scipy.linalg

    # np.random is what the deprecated, since-removed ``scipy.random`` alias
    # referred to, so the random stream is unchanged.
    testvecs = top.dot(np.random.normal(size=(top.shape[1], num_testvecs)))
    source_gram = source_inner.todense()  # loop-invariant

    basis = np.zeros((range_inner.shape[0], 0))
    remains = np.copy(top)
    opnorms = []
    testvecnorms = []
    for _ in range(max_size):
        # Largest generalized eigenvalue of (remains^T R remains, S) is the
        # squared operator norm of the remainder.
        inners = remains.T.dot(range_inner.dot(remains))
        opnorms.append(np.sqrt(np.max(scipy.linalg.eigvalsh(inners, source_gram))))
        testvec_norms = np.sqrt(np.diag(testvecs.T.dot(range_inner.dot(testvecs))))
        assert len(testvec_norms.shape) == 1
        testvecnorms.append(testvec_norms)

        # basis extension
        oldbasissize = basis.shape[1]
        basis = np.hstack((basis, top.dot(np.random.normal(size=(top.shape[1], 1)))))
        gram_schmidt(basis, range_inner, oldbasissize)

        remains -= basis.dot(basis.T.dot(range_inner.dot(remains)))
        testvecs -= basis.dot(basis.T.dot(range_inner.dot(testvecs)))

    return opnorms, testvecnorms


def adaptive_range_generation(t_operator, range_inner, source_inner, num_testvecs, target_error):
    """
    Adaptive randomized range finder with a probabilistic a posteriori stop test.

    The basis grows by one random image ``t_operator.matvec(normal)`` at a time,
    orthonormalized in ``range_inner`` via ``gram_schmidt``. Growth stops when
    the ``range_inner``-norms of ``num_testvecs`` random test vectors
    ``t_operator.matmat(normal)``, with the basis projected out, are all at most
    the threshold returned by ``testlimit`` (failure tolerance 1e-15). It also
    stops at ``min(t_operator.shape)`` basis vectors: the image of the operator
    cannot have a larger dimension, so further vectors would be linearly
    dependent and ``gram_schmidt`` would divide by ~0.

    @param t_operator   operator to approximate image of; used through
                        ``.shape``, ``.matvec`` and ``.matmat`` (e.g. a
                        ``scipy.sparse.linalg.LinearOperator``)
    @param range_inner  inner product in range space, as a matrix providing
                        ``.dot`` and ``.shape``
    @param source_inner inner product in source space; its smallest-magnitude
                        eigenvalue is computed with ``scipy.sparse.linalg.eigs``
    @param num_testvecs number of test vectors to use
    @param target_error desired maximal norm of the tested (remaining) operator
    @return dense array whose columns are the ``range_inner``-orthonormal basis
    """
    # np.random is what the deprecated, since-removed ``scipy.random`` alias
    # referred to, so the random stream is unchanged.
    testvecs = t_operator.matmat(np.random.normal(size=(t_operator.shape[1], num_testvecs)))
    # NOTE(review): eigs uses k=6 by default, which needs a source dimension
    # of at least 8; which="SM" can also converge slowly — consider shift-invert.
    lambda_min = np.min(scipy.sparse.linalg.eigs(source_inner, return_eigenvectors=False, which="SM")).real
    print("lambda_min is {}".format(lambda_min))

    basis = np.zeros((range_inner.shape[0], 0))
    max_basis_size = min(t_operator.shape)
    xi = testlimit(
        failure_tolerance=1e-15,
        dim_S=t_operator.shape[1],
        dim_R=t_operator.shape[0],
        num_testvecs=num_testvecs,
        target_error=target_error,
        lambda_min=lambda_min
        )
    while (basis.shape[1] < max_basis_size
           and np.sqrt(np.max(np.diag(testvecs.T.dot(range_inner.dot(testvecs))))) > xi):
        basis = np.hstack((basis, t_operator.matvec(np.random.normal(size=(t_operator.shape[1], 1)))))
        gram_schmidt(basis, range_inner, basis.shape[1] - 1)
        testvecs -= basis.dot(basis.T.dot(range_inner.dot(testvecs)))

    return basis

