import numpy as np
import scipy.linalg
import scipy.sparse
import scipy.sparse.linalg
import scipy.special

from mybmat import mybmat
# Monkeypatch: rebind scipy.sparse.bmat to the project-local mybmat, so every
# later lookup of scipy.sparse.bmat in this process (e.g. in operator_svd
# below) resolves to mybmat instead of SciPy's own implementation.
scipy.sparse.bmat = mybmat

def testlimit(failure_tolerance, dim_S, dim_R, num_testvecs, target_error, lambda_min):
    """
    failure_tolerance:  maximum probability for failure of algorithm
    dim_S: dimension of source space
    dim_R: dimension of range space
    num_testvecs: number of test vectors used
    target_error: desired maximal norm of tested operator
    lambda_min: smallest eigenvalue of matrix of inner product in source space
    """
    return scipy.special.erfinv( (failure_tolerance / min(dim_S, dim_R))**(1./num_testvecs)) * target_error * np.sqrt(2. * lambda_min)


def operator_svd(Top, source_inner, range_inner):
    """
    Eigen-decompose the block matrix [[0, Tadj], [Top, 0]], where
    Tadj = source_inner^{-1} (Top^T range_inner) is obtained from a sparse
    factorization of source_inner.

    Top: operator matrix; its transpose is multiplied with the dense range_inner
    source_inner: sparse matrix accepted by scipy.sparse.linalg.factorized
    range_inner: matrix providing .todense()

    Returns a 3-tuple built from every second eigenpair of the block matrix:
    absolute eigenvalues, the first source_inner.shape[0] rows of the
    eigenvectors, and the remaining rows.

    NOTE(review): taking every second eigenpair presumably relies on the
    eigenvalues arriving as adjacent +/- pairs; np.linalg.eig does not promise
    any ordering -- confirm before relying on the result.
    """
    source_dim = source_inner.shape[0]
    solve_source = scipy.sparse.linalg.factorized(source_inner)
    adjoint = solve_source(Top.T.dot(range_inner.todense()))

    # scipy.sparse.bmat is looked up at call time on purpose (it may be rebound).
    block_operator = scipy.sparse.bmat([[None, adjoint], [Top, None]]).tocsc()
    eigenvalues, eigenvectors = np.linalg.eig(block_operator.todense())

    singular_values = np.abs(eigenvalues[::2])
    source_part = eigenvectors[:source_dim, ::2]
    range_part = eigenvectors[source_dim:, ::2]
    return singular_values, source_part, range_part

def operator_svd2(Top, source_inner, range_inner):
    """
    Compute square roots of the generalized eigenvalues of the pencil
    (Top^T range_inner Top, source_inner), sorted in descending order.

    Top: operator matrix
    source_inner: matrix providing .todense() (right-hand side of the pencil)
    range_inner: matrix providing .dot()

    Returns (values, None, None); the two None entries keep the tuple shape
    parallel to operator_svd, no vectors are computed here.
    """
    mat_left = Top.T.dot(range_inner.dot(Top))
    mat_right = source_inner.todense()
    eigvals = scipy.linalg.eigvals(mat_left, mat_right)
    # eigvals returns complex values; keep the magnitude of the real part
    # before taking the root.
    singular_values = np.sqrt(np.abs(eigvals.real))
    # Sorting the reversed view in place leaves the array in descending order.
    singular_values[::-1].sort()
    return singular_values, None, None

def gram_schmidt(basis, inner, start):
    """
    Orthonormalize, in place, the columns basis[:, start:] with respect to the
    inner product matrix `inner`.

    basis: 2-D array whose columns are the vectors; modified in place
    inner: inner product matrix (anything providing .dot())
    start: index of the first column to process

    Each processed column is projected against ALL earlier columns, including
    those before `start`, which are left untouched. The projection formula
    only removes their component correctly if they already have unit norm.

    Returns None.
    """
    assert len(basis.shape) == 2
    for i in range(start, basis.shape[1]):
        # np.float was removed from NumPy (1.24); np.inf is the portable spelling.
        oldnorm = np.inf
        newnorm = np.sqrt(basis[:,i].T.dot(inner.dot(basis[:,i])))
        # Re-orthogonalize until a pass no longer shrinks the norm by more
        # than 10%, guarding against loss of orthogonality from cancellation.
        while newnorm < 0.9 * oldnorm:
            oldnorm = newnorm
            for j in range(i):
                # Column i changes inside this loop, so the inner product
                # must be recomputed for every j.
                basis[:,i] -= basis[:,j] * basis[:,j].T.dot(inner.dot(basis[:,i]))
            newnorm = np.sqrt(basis[:,i].T.dot(inner.dot(basis[:,i])))

        basis[:,i] *= 1./newnorm

def range_generation(top, range_inner, source_inner, max_size=13, num_testvecs=10):
    """
    Grow a random basis one vector at a time and record, before each
    extension, how much of `top` is still unresolved.

    top: operator matrix (copied, not modified)
    range_inner: inner product matrix in the range space (provides .dot())
    source_inner: inner product matrix in the source space (provides .todense())
    max_size: number of basis extensions to perform
    num_testvecs: number of random test vectors to track

    Returns (opnorms, testvecnorms), two lists of length max_size:
    opnorms[i] is the square root of the largest generalized eigenvalue of
    (remains^T range_inner remains, source_inner) for the current residual
    operator; testvecnorms[i] is the 1-D array of range_inner-norms of the
    residual test vectors.
    """
    # scipy.random was only an alias of numpy.random and is gone from current
    # SciPy; np.random draws from the same global generator.
    testvecs = top.dot(np.random.normal(size=(top.shape[1], num_testvecs)))

    basis = np.zeros((range_inner.shape[0], 0))
    remains = np.copy(top)
    # Loop-invariant: densify the source inner product once.
    source_dense = source_inner.todense()
    opnorms = []
    testvecnorms = []
    for _ in range(max_size):
        inners = remains.T.dot(range_inner.dot(remains))
        largest = np.max(scipy.linalg.eigvalsh(inners, source_dense))
        # Once the residual is numerically zero the largest eigenvalue can be
        # a tiny negative number; clamp so the root does not become NaN.
        opnorms.append(np.sqrt(max(largest, 0.0)))
        testvec_norms = np.sqrt(np.diag(testvecs.T.dot(range_inner.dot(testvecs))))
        assert len(testvec_norms.shape) == 1
        testvecnorms.append(testvec_norms)

        # basis extension
        oldbasissize = basis.shape[1]
        new_vector = top.dot(np.random.normal(size=(top.shape[1], 1)))
        basis = np.hstack((basis, new_vector))
        gram_schmidt(basis, range_inner, oldbasissize)

        # Remove the part already captured by the basis.
        remains -= basis.dot(basis.T.dot(range_inner.dot(remains)))
        testvecs -= basis.dot(basis.T.dot(range_inner.dot(testvecs)))

    return opnorms, testvecnorms


def _smallest_eigenvalue(inner):
    """Return the smallest eigenvalue of the symmetric matrix `inner` via a dense solve."""
    dense = inner.toarray() if scipy.sparse.issparse(inner) else np.asarray(inner)
    # eigvalsh returns eigenvalues in ascending order.
    return np.linalg.eigvalsh(dense)[0]


def adaptive_range_generation(t_operator, range_inner, source_inner, num_testvecs, target_error,
                              failure_tolerance=1e-15):
    """
    Build a basis from images of Gaussian random vectors under t_operator,
    extending it until every residual test vector has range_inner-norm at most
    the limit returned by testlimit().

    @param t_operator   operator to approximate image of, given as numpy matrix
    @param range_inner  inner product in range space, given as numpy matrix
    @param source_inner inner product in source space, given as numpy matrix
                        (scipy sparse matrices are accepted as well)
    @param num_testvecs number of test vectors to use
    @param target_error desired maximal norm of the tested operator, passed to testlimit()
    @param failure_tolerance maximum probability for failure of the algorithm,
                        passed to testlimit()
    @return basis, a 2-D array with range_inner.shape[0] rows and one
                        orthonormalized column per extension step

    NOTE(review): lambda_min is computed with a dense eigenvalue solve; for
    large sparse source_inner a sparse eigensolver would be cheaper.
    """
    testvecs = t_operator.dot(np.random.normal(size=(t_operator.shape[1], num_testvecs)))

    basis = np.zeros((range_inner.shape[0], 0))
    xi = testlimit(
        failure_tolerance=failure_tolerance,
        dim_S=t_operator.shape[1],
        dim_R=t_operator.shape[0],
        num_testvecs=num_testvecs,
        target_error=target_error,
        lambda_min=_smallest_eigenvalue(source_inner),
        )

    # The operator's rank cannot exceed its smaller dimension; past that point
    # new vectors are pure round-off, so stop instead of looping forever.
    max_rank = min(t_operator.shape)
    while basis.shape[1] < max_rank:
        norms_squared = np.diag(testvecs.T.dot(range_inner.dot(testvecs)))
        # Written as "not >" so that a NaN limit ends the loop.
        if not np.sqrt(np.max(norms_squared)) > xi:
            break
        new_vector = t_operator.dot(np.random.normal(size=(t_operator.shape[1], 1)))
        basis = np.hstack((basis, new_vector))
        gram_schmidt(basis, range_inner, basis.shape[1] - 1)
        testvecs -= basis.dot(basis.T.dot(range_inner.dot(testvecs)))

    return basis

