import time
import scipy
import matplotlib.pyplot as plt

from pymor.operators.constructions import induced_norm
from pymor.operators.numpy import NumpyMatrixOperator
from pymor.vectorarrays.numpy import NumpyVectorArray
from pymor.discretizations.basic import StationaryDiscretization

from problems import *
from algorithms import orthogonal_part, testlimit
from localize_problem import localize_problem

from basis_generation import get_random_bases
from calculate_csi import calculate_csi
from calculate_lambda_min import calculate_lambda_min
from calculate_cq import calculate_cq

from lrb_operator_projection import LRBOperatorProjection

import simdb.run as sdb

def calculate_global_approximations(gq, lq, bases, maxnorms, tolerances, num_testvecs, operator_reductor,
                                    show_image=False):
    """Solve the reduced problem for each tolerance and record the errors.

    For every entry of ``tolerances`` a basis size is selected for each local
    space, the bases are handed to ``operator_reductor``, the reduced system
    is solved and the error against the full solution ``gq["d"].solve()`` is
    measured in the norm induced by ``gq["energy_0_product"]``.  The reduced
    dimensions and errors are stored through ``sdb.append_values`` under the
    keys ``basis_sizes`` and ``errors``; nothing is returned.

    If solving the reduced system raises, the loop over the tolerances stops
    and only the results obtained so far are stored.

    Parameters
    ----------
    gq
        Mapping providing ``"d"`` (object with ``solve()``),
        ``"coarse_grid_resolution"`` and ``"energy_0_product"``.
    lq
        Container indexed as ``lq[xpos, ypos]`` whose entries are mappings
        with the keys ``"c_q"``, ``"c_si"``, ``"transfer_operator"``,
        ``"lambda_min"``, ``"pou_range_space"`` and ``"omega_has_dirichlet"``.
    bases
        Mapping from space to the full basis; only a leading part is used.
    maxnorms
        Mapping from space to an array that is compared against the test
        limit to choose the number of basis vectors.
    tolerances
        Sequence of target tolerances, processed in the given order.
    num_testvecs
        Number of test vectors, forwarded to ``testlimit``.
    operator_reductor
        Object providing ``set_range_basis``, ``set_source_basis``,
        ``get_reduced_operator``, ``get_reduced_rhs`` and
        ``reconstruct_source``.
    show_image
        If true, plot the errors over the tolerances that were actually
        solved (log-log axes) and show the figure.
    """
    u = gq["d"].solve()
    coarse_grid_resolution = gq["coarse_grid_resolution"]
    # The norm only depends on the product, so build it once.
    energy_norm = induced_norm(gq["energy_0_product"])

    # the constants
    C_2_by_diam = np.sqrt(2) * coarse_grid_resolution
    C_1 = 1.
    M = 4.
    M_star = 16.

    sdb.append_values(global_energy_norm=energy_norm(u)[0])

    basis_sizes = []
    errors = []

    failure_tolerance = 1e-15
    local_grid_shape = (coarse_grid_resolution - 1, coarse_grid_resolution - 1)
    # Distribute the global failure tolerance evenly over all local spaces.
    local_failure_tolerance = failure_tolerance / ((coarse_grid_resolution - 1)**2)

    for tolerance in tolerances:
        for ypos, xpos in np.ndindex(local_grid_shape):
            ldict = lq[xpos, ypos]
            eps_i = tolerance * (2.*M*M_star)**(-0.5) * (C_2_by_diam**2 * ldict["c_q"]**2 + C_1**2)**(-0.5)
            max_op_norm = eps_i / ldict["c_si"]
            testlimit_zeta = testlimit(
                failure_tolerance=local_failure_tolerance,
                dim_S=ldict["transfer_operator"].source.dim,
                dim_R=ldict["transfer_operator"].range.dim,
                num_testvecs=num_testvecs,
                target_error=max_op_norm,
                lambda_min=ldict["lambda_min"]
            )

            space = ldict["pou_range_space"]
            num_vecs = np.count_nonzero(maxnorms[space] > testlimit_zeta)
            if ldict["omega_has_dirichlet"]:
                # we have the u_f
                num_vecs = num_vecs + 1
            else:
                # we have u_f and constant in addition
                num_vecs = num_vecs + 2

            # Never ask for more vectors than the basis actually contains.
            num_vecs = min(num_vecs, len(bases[space]))
            basis = bases[space].copy(ind=range(num_vecs))

            operator_reductor.set_range_basis(space, basis)
            operator_reductor.set_source_basis(space, basis)

        gfem_op = operator_reductor.get_reduced_operator()
        gfem_rhs = operator_reductor.get_reduced_rhs()
        gfem_discretization = StationaryDiscretization(gfem_op, gfem_rhs, cache_region=None)

        try:
            gfem_solution = gfem_discretization.solve()
        except Exception as exc:
            # Best effort: keep the results gathered so far, but do not
            # swallow KeyboardInterrupt/SystemExit and do not fail silently.
            print("solving failed for tolerance {}: {!r}".format(tolerance, exc))
            break
        reconstructed_solution = operator_reductor.reconstruct_source(gfem_solution)
        error = energy_norm(u - reconstructed_solution)[0]
        print("error is {}".format(error))

        basis_sizes.append(gfem_op.source.dim)
        errors.append(error)

    sdb.append_values(basis_sizes=basis_sizes, errors=errors)
    sdb.flush()

    if show_image:
        plt.xscale("log")
        plt.yscale("log")
        plt.gca().invert_xaxis()
        # The loop may have stopped early, so only plot the solved tolerances.
        plt.plot(tolerances[:len(errors)], errors)
        plt.show()


# Experiment configuration, identical for every problem.
experimentname = "global_approximations"

resolution = 200
coarse_grid_resolution = 10
num_testvecs = 20
max_basis_size = 80

tols = np.logspace(4, -4, 50)
iterations = 1000

for problem in ["h", "poisson"]:
    sdb.new_dataset(experimentname,
                    problem=problem,
                    coarse_grid_resolution=coarse_grid_resolution,
                    resolution=resolution,
                    num_testvecs=num_testvecs,
                    max_basis_size=max_basis_size)

    if problem == "h":
        p = h_problem()
    elif problem == "poisson":
        p = poisson_problem()
    else:
        # Raising a plain string is itself a TypeError; raise a real exception.
        raise ValueError("unknown problem: {!r}".format(problem))

    gq, lq = localize_problem(p, coarse_grid_resolution, resolution)

    # These calls are made for their effect on gq/lq; their return values are unused here.
    calculate_csi(gq, lq)
    calculate_lambda_min(gq, lq)
    calculate_cq(gq, lq)

    sdb.add_values(tolerances=tols)

    # Minimal bases (arguments 1, 1) are only needed to construct the operator
    # projection; the bases used for the experiment are drawn in the loop below.
    bases, maxnorms = get_random_bases(gq, lq, 1, 1)

    op = gq["d"].operator.assemble()
    rhs = gq["d"].rhs.assemble()
    localizer = gq["localizer"]
    spaces = [l["pou_range_space"] for temp in lq for l in temp]
    operator_reductor = LRBOperatorProjection(op, rhs, localizer, spaces, bases, spaces, bases)

    lasttime = time.time()
    for _ in range(iterations):
        # Reports the time elapsed since the previous iteration started
        # (close to zero on the very first pass).
        thistime = time.time()
        print("duration for last iteration {}".format(thistime - lasttime))
        lasttime = thistime

        bases, maxnorms = get_random_bases(gq, lq, max_basis_size, num_testvecs)
        calculate_global_approximations(gq, lq, bases, maxnorms, tols,
                                        num_testvecs=num_testvecs,
                                        operator_reductor=operator_reductor)


