import numpy as np
import sys
import time
import pickle
import matplotlib.pyplot as plt
import scipy
import scipy.special
from algorithms import testlimit

def basis_sizes(testvecnorms, tolerance, num_testvecs):
    """Count, per row, how many axis-1 entries still exceed ``tolerance``.

    For every position along axes 0 and 1 the maximum over the first
    ``num_testvecs`` entries of axis 2 is taken. The value returned for row
    ``i`` is the number of axis-1 positions whose maximum is strictly greater
    than ``tolerance``.

    Args:
        testvecnorms: 3-D NumPy array. Presumably indexed as
            (run, basis size, test vector) -- TODO confirm against the code
            that produces it.
        tolerance: threshold the per-position maxima are compared against.
        num_testvecs: number of leading axis-2 entries to take into account;
            must not exceed ``testvecnorms.shape[2]``.

    Returns:
        1-D integer array with one count per index along axis 0.

    Raises:
        AssertionError: if ``num_testvecs`` exceeds ``testvecnorms.shape[2]``
            (not checked when Python runs with ``-O``).
        RuntimeError: if, for at least one row, every axis-1 entry exceeds
            ``tolerance``, so the count equals ``testvecnorms.shape[1]``.
    """
    assert num_testvecs <= testvecnorms.shape[2]
    maxtestvecnorms = np.max(testvecnorms[:, :, :num_testvecs], axis=2)
    result = np.sum(maxtestvecnorms > tolerance, axis=1)
    if np.any(result == testvecnorms.shape[1]):
        # A count equal to the axis-1 length means no entry of that row is
        # at or below the tolerance. Raise an explicit, descriptive error
        # instead of a bare ``raise`` that has no active exception.
        raise RuntimeError(
            "tolerance {} is exceeded by all {} entries of at least one row".format(
                tolerance, testvecnorms.shape[1]
            )
        )
    return result

# Load the precomputed experiment results and report how long loading took.
# NOTE(security): pickle can execute arbitrary code while loading; only use
# this script with result files from a trusted source.
# NOTE(review): a pickle written by Python 2 that contains NumPy arrays may
# need ``encoding="latin1"`` when loaded under Python 3 -- confirm how the
# result file was produced.
last = time.time()
# open() in binary mode works on Python 2 and 3 alike; the file() builtin is
# gone in Python 3 and pickle data has to be read as bytes. The ``with``
# block closes the handle as soon as loading is done.
with open("experiment_helmholtz_algo_convergence_result.pickle", "rb") as infile:
    result = pickle.load(infile)
print("loaded input file in {}".format(time.time() - last))


# Both axes are logarithmic: tolerances and errors span many decades.
plt.yscale("log")
plt.xscale("log")

# Numbers of test vectors to compare; one curve is plotted per entry.
num_testvecs_array = [3, 5, 10, 20, 40, 80]

# projectionerrors is indexed below as [row, basis size]; testvecnorms is the
# 3-D array handed to basis_sizes().
projectionerrors = result["operatornorms"]
testvecnorms = result["testvecnorms"]
# 50 logarithmically spaced target errors, from 1e4 down to 1e-10.
tolerances = np.logspace(4, -10, 50)

# Loop invariants: number of rows and the matching row index for fancy indexing.
num_runs = len(projectionerrors)
run_indices = np.arange(num_runs)

all_max_errors = []
for num_testvecs in num_testvecs_array:
    print("num testvecs ", num_testvecs)
    all_errors = []
    for tolerance in tolerances:
        # One progress dot per tolerance.
        sys.stdout.write(".")
        sys.stdout.flush()
        # testlimit() comes from the project's ``algorithms`` module;
        # presumably it turns the target error into a threshold on the
        # test-vector norms -- TODO confirm against its definition.
        limit = testlimit(
            failure_tolerance=1e-15,
            dim_S=result["dim_source"],
            dim_R=result["dim_range"],
            num_testvecs=num_testvecs,
            target_error=tolerance,
            lambda_min=result["lambdamin"]
            )

        try:
            bs = basis_sizes(testvecnorms, limit, num_testvecs=num_testvecs)
            # Pick, for every row, the error at that row's basis size.
            errors = projectionerrors[run_indices, bs]
        except Exception:
            # Best effort: if no valid basis size exists for this tolerance,
            # record NaNs so the curve simply has a gap there.
            # ``except Exception`` (rather than a bare ``except``) lets
            # Ctrl-C / SystemExit through, and np.nan replaces np.float,
            # which was removed from NumPy.
            errors = np.full(num_runs, np.nan)

        all_errors.append(errors)

    # Rows correspond to tolerances; take the worst error over all columns.
    all_errors = np.array(all_errors)
    max_errors = np.max(all_errors, axis=1)
    all_max_errors.append(max_errors)
    plt.plot(tolerances[:len(max_errors)], max_errors, label=str(num_testvecs))


# Reference line error == tolerance.
plt.plot(tolerances, tolerances, label="1")
# Additional reference lines, currently disabled:
#plt.plot(tolerances, tolerances * 1e-2, label="1e-2")
#plt.plot(tolerances, tolerances * 1e-4, label="1e-4")
# Show decreasing tolerance from left to right.
plt.gca().invert_xaxis()

plt.legend()
plt.show()


# Write the curves as a space-separated text table: each line holds one
# tolerance followed by the maximum error for every entry of
# num_testvecs_array, in that order.
alldata = np.vstack([tolerances] + all_max_errors)
# The ``with`` block guarantees the file is flushed and closed.
with open("helmholtz_adaptive_num_testvecs.dat", "w") as outfile:
    outfile.writelines(" ".join(map(str, v)) + "\n" for v in alldata.T)
