from problems import *
from localize_problem import localize_problem
from basis_generation import get_random_bases
from lrb_operator_projection import LRBOperatorProjection
from pymor.discretizations.basic import StationaryDiscretization
from pymor.operators.constructions import induced_norm
from pymor.vectorarrays.numpy import NumpyVectorArray
from algorithms import orthogonal_part

import simdb.run as sdb
from calculate_cq import calculate_cq

def _local_bestapproximation_error(localizer, u, ldict, basis):
    """Return the local energy norm of the part of ``u`` not captured by ``basis``.

    ``u`` is restricted to the patch's ``"range_space"`` and the patch's
    ``"local_solution"`` (u_f) is subtracted.  On patches without Dirichlet
    boundary (``"omega_has_dirichlet"`` false), the component along the
    constant function is additionally removed via ``orthogonal_part`` w.r.t.
    the patch's L2 product.  The remainder is then passed through
    ``orthogonal_part`` with ``basis`` and the patch energy product, and the
    energy norm of the result is returned.
    """
    lenergy_product = ldict["range_energy_0_product"]
    lsol = localizer.localize_vector_array(u, ldict["range_space"])

    # remove u_f:
    lsol_minus_u_f = lsol - ldict["local_solution"]
    # remove constant part (only done where no Dirichlet boundary is present):
    if not ldict["omega_has_dirichlet"]:
        l2_product = ldict["range_l2_product"]
        constant_one = NumpyVectorArray(np.ones(lsol_minus_u_f.space.dim))
        constant_one_normed = constant_one * (1. / induced_norm(l2_product)(constant_one))
        lsol_minus_u_f = orthogonal_part(constant_one_normed, l2_product, lsol_minus_u_f)

    lsolorth = orthogonal_part(basis, lenergy_product, lsol_minus_u_f)
    return induced_norm(lenergy_product)(lsolorth)[0]


def run_experiment(p, coarse_grid_resolution=10, resolution=200,
                   max_basis_size=80, num_testvecs=20, num_samples=100):
    """Record local best-approximation errors and global GFEM errors for problem ``p``.

    The problem is localized on a ``coarse_grid_resolution`` coarse grid
    (fine ``resolution``) and solved globally.  The solution energy norm and
    per-patch ``omega_star`` energy norms are stored in a new simdb dataset.
    Then, for each of ``num_samples`` draws of random local bases
    (``get_random_bases`` with ``max_basis_size`` / ``num_testvecs``) and for
    every basis size ``i`` in ``range(max_basis_size)``, it appends:

    * per patch, the energy norm of the local solution (minus u_f and, where
      applicable, its constant part) orthogonal to the first ``i`` vectors of
      the patch's ``range_space`` basis;
    * the global energy-norm error ``||u - r||`` where ``r`` is reconstructed
      from the reduced system assembled by ``LRBOperatorProjection`` with
      ``i + 1`` (Dirichlet patches) or ``i + 2`` (other patches) vectors of
      each ``pou_range_space`` basis.

    If a reduced solve raises, the basis-size sweep of that sample stops;
    results gathered so far are still appended to the dataset.

    Args:
        p: problem object; only ``p.name`` is read here, the rest is consumed
            by ``localize_problem``.
        coarse_grid_resolution: number of coarse cells per direction; patches
            are indexed over ``(coarse_grid_resolution - 1) ** 2`` positions.
        resolution: fine grid resolution passed to ``localize_problem``.
        max_basis_size: largest basis size swept (exclusive).
        num_testvecs: forwarded to ``get_random_bases``.
        num_samples: number of random basis draws.
    """
    gq, lq = localize_problem(p, coarse_grid_resolution, resolution)
    u = gq["d"].solve()
    localizer = gq["localizer"]
    energy_norm = induced_norm(gq["energy_0_product"])

    # Patch positions in the same order as the original nested iteration:
    # ypos outer, xpos inner; stored as (xpos, ypos) keys into ``lq``.
    num_patches_1d = coarse_grid_resolution - 1
    positions = [(xpos, ypos)
                 for ypos, xpos in np.ndindex((num_patches_1d, num_patches_1d))]

    sdb.new_dataset("local_global_error_fixed_basis_size",
                    problem=p.name,
                    coarse_grid_resolution=coarse_grid_resolution,
                    resolution=resolution,
                    num_testvecs=num_testvecs,
                    max_basis_size=max_basis_size
                    )

    norm_u = energy_norm(u)[0]
    sdb.append_values(norm_u=norm_u)

    # Initial (size 1) bases only serve to construct the projection object;
    # they are replaced by set_range_basis/set_source_basis below.
    bases, _ = get_random_bases(gq, lq, 1, 1)

    op = gq["d"].operator.assemble()
    rhs = gq["d"].rhs.assemble()

    local_norms = [
        induced_norm(lq[pos]["omega_star_energy_0_product"])(
            localizer.localize_vector_array(u, lq[pos]["omega_star_space"])
        )[0]
        for pos in positions]
    sdb.append_values(local_norms=local_norms)

    spaces = [lq[pos]["pou_range_space"] for pos in positions]
    operator_reductor = LRBOperatorProjection(op, rhs, localizer, spaces, bases, spaces, bases)

    for _ in range(num_samples):
        sdb.flush()
        bases, _ = get_random_bases(gq, lq, max_basis_size, num_testvecs)
        local_bestapproximations = []
        gfem_errors = []
        for i in range(max_basis_size):
            local_bestapproximations_this_basis_size = []
            for pos in positions:
                ldict = lq[pos]

                range_space = ldict["range_space"]
                num_vecs = min(i, len(bases[range_space]))
                basis = bases[range_space].copy(ind=range(num_vecs))
                local_bestapproximations_this_basis_size.append(
                    _local_bestapproximation_error(localizer, u, ldict, basis))

                pou_space = ldict["pou_range_space"]
                extra = 1 if ldict["omega_has_dirichlet"] else 2
                num_vecs = min(i + extra, len(bases[pou_space]))
                pou_basis = bases[pou_space].copy(ind=range(num_vecs))

                operator_reductor.set_range_basis(pou_space, pou_basis)
                operator_reductor.set_source_basis(pou_space, pou_basis)

            # NOTE: appended before the solve, so on a solve failure this list
            # has one more row than ``gfem_errors`` (unchanged behavior).
            local_bestapproximations.append(local_bestapproximations_this_basis_size)

            gfem_op = operator_reductor.get_reduced_operator()
            gfem_rhs = operator_reductor.get_reduced_rhs()
            gfem_discretization = StationaryDiscretization(gfem_op, gfem_rhs, cache_region=None)

            try:
                gfem_solution = gfem_discretization.solve()
            except Exception as e:
                # Deliberate best-effort: stop this sweep, but say why and
                # do not swallow KeyboardInterrupt/SystemExit like a bare except.
                print("GFEM solve failed at basis size {}: {!r}".format(i, e))
                break
            reconstructed_solution = operator_reductor.reconstruct_source(gfem_solution)
            error = energy_norm(u - reconstructed_solution)[0]
            gfem_errors.append(error)
            print("error is {}".format(error))

        sdb.append_values(local_bestapproximations=np.array(local_bestapproximations))
        sdb.append_values(gfem_errors=gfem_errors)

# Run the experiment once per problem, in order: h_problem first, then
# poisson_problem (each problem is constructed right before its run).
for make_problem in (h_problem, poisson_problem):
    run_experiment(make_problem())
