import numpy as np
import scipy.special
from pymor.vectorarrays.numpy import NumpyVectorArray
from pymor.algorithms.basisextension import gram_schmidt_basis_extension
from pymor.operators.constructions import induced_norm

def testlimit(failure_tolerance, dim_S, dim_R, num_testvecs, target_error, lambda_min):
    """
    failure_tolerance:  maximum probability for failure of algorithm
    dim_S: dimension of source space
    dim_R: dimension of range space
    num_testvecs: number of test vectors used
    target_error: desired maximal norm of tested operator
    lambda_min: smallest eigenvalue of matrix of inner product in source space
    """
    return scipy.special.erfinv( (failure_tolerance / min(dim_S, dim_R))**(1./num_testvecs)) * target_error * np.sqrt(2. * lambda_min)

def basis_extension(basis, product, transfer):
    """Extend ``basis`` by the image of a single random source vector.

    One vector with standard-normal entries is drawn in a space of dimension
    ``transfer.source.dim`` and mapped through ``transfer``. The image is then
    handed to ``gram_schmidt_basis_extension`` together with ``product``. With
    ``copy_basis=False`` the result is discarded here, so ``basis`` is
    presumably extended in place.
    """
    gaussian_sample = NumpyVectorArray(
        np.random.normal(size=(1, transfer.source.dim)))
    image = transfer.apply(gaussian_sample)
    gram_schmidt_basis_extension(
        basis, image, product=product, copy_basis=False)

# Recorded results for basis_extension (presumably the achieved error):
# 0.0026
# 0.00238

def orthogonal_part(basis, product, U, passes=2):
    """Remove from ``U`` its components along ``basis`` w.r.t. ``product``.

    Each pass computes the coefficients ``product.apply2(basis, residual)`` and
    subtracts ``basis.lincomb(coefficients.T)`` from the residual. This is an
    orthogonal projection only if ``basis`` is orthonormal w.r.t. ``product``
    (assumed here; TODO confirm at call sites). Repeating the pass
    (re-orthogonalization) removes components that survive the first pass
    because of round-off.

    Parameters
    ----------
    basis
        Vector array whose span is projected out.
    product
        Inner-product operator providing ``apply2``.
    U
        Vector array to be orthogonalized. It is only read; the result is
        built with ``-``.
    passes
        Number of projection passes. The default of 2 reproduces the previous
        fixed behavior.

    Returns
    -------
    The residual of ``U`` after ``passes`` projection passes.

    Raises
    ------
    ValueError
        If ``passes`` is smaller than 1.
    """
    if passes < 1:
        raise ValueError('passes must be at least 1, got {}'.format(passes))

    # Named `residual` to avoid shadowing this function's own name.
    residual = U
    for _ in range(passes):
        coefficients = product.apply2(basis, residual)
        residual = residual - basis.lincomb(coefficients.T)
    return residual

def better_basis_extension(basis, product, transfer, num_samples=30):
    """Extend ``basis`` by the best of several random range samples.

    ``num_samples`` source vectors with standard-normal entries are mapped
    through ``transfer``. Their parts orthogonal to ``basis`` are computed
    with ``orthogonal_part``. The orthogonal part with the largest
    ``product``-induced norm is passed to ``gram_schmidt_basis_extension``
    (``copy_basis=False``, so ``basis`` is presumably extended in place).

    Parameters
    ----------
    basis
        Vector array to extend.
    product
        Inner-product operator used for projection, norms and Gram-Schmidt.
    transfer
        Operator whose range is sampled. Its ``source.dim`` sets the sample size.
    num_samples
        Number of random candidates to draw. The default of 30 matches the
        previous hard-coded value.

    Raises
    ------
    ValueError
        If ``num_samples`` is smaller than 1.
    """
    if num_samples < 1:
        raise ValueError(
            'num_samples must be at least 1, got {}'.format(num_samples))

    n_source = transfer.source.dim
    samples = NumpyVectorArray(
        np.random.normal(size=(num_samples, n_source)))
    range_samples = transfer.apply(samples)

    # Only the component not yet captured by the basis is of interest;
    # keep the candidate where that component is largest.
    residuals = orthogonal_part(basis, product, range_samples)
    norms = induced_norm(product)(residuals)
    best_index = int(np.argmax(norms))
    best_residual = residuals.copy(ind=best_index)

    gram_schmidt_basis_extension(
        basis, best_residual, product=product, copy_basis=False)

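# Recorded results for better_basis_extension (presumably the achieved error)
# with 10 samples:
# 0.00123
# 0.00126
# with 10 samples and the better inner product:
# 0.001182

# with 30 samples:
# 0.00107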

def even_better_basis_extension(basis, product, transfer, num_samples=1000):
    """Extend ``basis`` by the dominant POD mode of random range samples.

    ``num_samples`` source vectors with standard-normal entries are mapped
    through ``transfer``. Their parts orthogonal to ``basis`` are compressed
    to a single POD mode (w.r.t. ``product``). That mode is passed to
    ``gram_schmidt_basis_extension`` (``copy_basis=False``, so ``basis`` is
    presumably extended in place).

    Parameters
    ----------
    basis
        Vector array to extend.
    product
        Inner-product operator used for projection, POD and Gram-Schmidt.
    transfer
        Operator whose range is sampled. Its ``source.dim`` sets the sample size.
    num_samples
        Number of random samples fed into the POD. The default of 1000 matches
        the previous hard-coded value.

    Raises
    ------
    ValueError
        If ``num_samples`` is smaller than 1.
    RuntimeError
        If the POD does not return exactly one mode. This replaces an
        ``assert``, which would be stripped under ``python -O`` and silently
        skip the extension.
    """
    if num_samples < 1:
        raise ValueError(
            'num_samples must be at least 1, got {}'.format(num_samples))

    n_source = transfer.source.dim
    samples = NumpyVectorArray(
        np.random.normal(size=(num_samples, n_source)))
    range_samples = transfer.apply(samples)

    residuals = orthogonal_part(basis, product, range_samples)

    from pymor.algorithms.pod import pod
    modes, _ = pod(residuals, product=product, modes=1)
    if len(modes) != 1:
        raise RuntimeError(
            'pod returned {} modes instead of 1 (possibly the sampled range '
            'is already captured by the basis)'.format(len(modes)))

    gram_schmidt_basis_extension(
        basis, modes, product=product, copy_basis=False)

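# Recorded results for even_better_basis_extension (presumably the achieved error)
# with 30 samples:
# 0.000676
# 0.00069
# with 30 samples and the better inner product:
# 0.000541

# with 100 samples:
# 0.00034

# with 1000 samples:
# 0.00024
# with 1000 samples and the better inner product:
# 0.000209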

# optimal basis
# basis size 10: 0.00020452520
# basis size  4: 0.00379465854932

# optimal basis with the dual-space inner product
# basis size 10: 0.000223847245805
# basis size  4: 0.00377941049424

# new problem
# optimal basis 0.000179189107462
# better basis  0.000179420489188
