import pickle
import sys
import time

import numpy as np
import scipy
import scipy.sparse
import scipy.sparse.linalg
from algorithms import range_generation, operator_svd, operator_svd2
import matplotlib.pyplot as plt
from progress.bar import Bar

def getbar():
    """Build a new "sampling" progress bar.

    The suffix shows the current index out of the maximum, the elapsed time
    and the estimated remaining time.
    """
    progress_suffix = "%(index)d / %(max)d, time: %(elapsed_td)s - %(eta_td)s"
    return Bar("sampling", suffix=progress_suffix)

from pymor.functions.basic import GenericFunction, ConstantFunction
from pymor.domaindescriptions.boundarytypes import BoundaryType
from pymor.domaindescriptions.basic import RectDomain
from pymor.analyticalproblems.elliptic import EllipticProblem
from pymor.domaindiscretizers.default import discretize_domain_default
from pymor.grids.rect import RectGrid
from pymor.discretizers.elliptic import discretize_elliptic_cg
from pymor.parameters.functionals import ExpressionParameterFunctional

from assembly import *

# Homogeneous boundary data: zero Neumann flux and zero Dirichlet values.
neumann_data = ConstantFunction(0., dim_domain=2)
dirichlet_data = ConstantFunction(0., dim_domain=2)

# Mesh resolution. The DOF bookkeeping further down assumes ny+1 rows of
# nx+1 vertices, i.e. nx = 2*ny intervals across the width-2 x0 direction and
# ny intervals across the height-1 x1 direction.
ny = 160
nx = 2*ny
# Rectangle [-1, 1] x [0, 1] with Neumann conditions on top and bottom; the
# left/right sides keep RectDomain's default boundary type (presumably
# Dirichlet -- TODO confirm against pymor's RectDomain defaults).
domain = RectDomain(
    domain=([-1., 0.], [1., 1.]),
    top=BoundaryType('neumann'), bottom=BoundaryType('neumann'))

# sqrt(2)/ny is the diagonal of a 1/ny x 1/ny square, so this requests
# square cells of width 1/ny; the dirichletdofs assert below cross-checks
# the resulting vertex layout.
grid, bi = discretize_domain_default(domain, diameter=1./ny*np.sqrt(2.), grid_type=RectGrid)


# Parametrized problem in pymor's elliptic form -div(d grad u) + c u = f with
# d = 1, c = 1 * (-k**2) and f = 0, i.e. the Helmholtz-type equation
# -Laplace(u) - k**2 u = 0 with parameter k.
problem = EllipticProblem(
    domain=domain,
    dirichlet_data=dirichlet_data,
    neumann_data=neumann_data,
    diffusion_functions=(ConstantFunction(1., dim_domain=2),), 
    reaction_functions=(ConstantFunction(1., dim_domain=2),), 
    reaction_functionals=(ExpressionParameterFunctional("-k**2", {"k": (1,)}),),
    rhs=ConstantFunction(0, dim_domain=2),
    )


# Continuous Galerkin discretization on the grid above; `d.operator` is
# assembled per wavenumber k further down. `data` is not used in the visible
# part of this script.
d, data = discretize_elliptic_cg(analytical_problem=problem, grid=grid, boundary_info=bi)

def dirichletdof_f(x):
    """Tell whether point ``x`` lies on the left (x0 = -1) or right (x0 = 1) edge.

    A tolerance of 1e-12 is applied to the first coordinate; the second
    coordinate is ignored.
    """
    tol = 1e-12
    first_coord = x[0]
    on_left_edge = first_coord <= -1. + tol
    return on_left_edge or (first_coord >= 1. - tol)

# Dirichlet DOFs on the left/right edges, found two ways: by testing the
# coordinates of grid.centers(2) (codim-2 entities, presumably the vertices
# carrying the CG DOFs), and by the index pattern of a row-by-row numbering
# with nx+1 vertices per row (first and last vertex of each of the ny+1
# rows, left edge first). The assert cross-checks both.
# A list comprehension is used instead of map(): on Python 3 map() returns a
# lazy iterator, which np.nonzero cannot evaluate element-wise.
dirichletdofs_frommesh = np.nonzero([dirichletdof_f(x) for x in grid.centers(2)])[0]
dirichletdofs = np.concatenate(([i*(nx+1) for i in range(ny+1)], [(i+1)*(nx+1)-1 for i in range(ny+1)]))
assert set(dirichletdofs) == set(dirichletdofs_frommesh)

# Target DOFs: the vertices on the vertical line x0 = 0.
targetdofs = np.nonzero([np.abs(x[0] - 0.) < 1e-12 for x in grid.centers(2)])[0]
# All other DOFs. Membership is tested against a set built once; testing
# against the numpy arrays inside the comprehension cost O(len) per DOF.
_excluded_dofs = set(dirichletdofs.tolist()) | set(targetdofs.tolist())
remainingdofs = [x for x in range(d.solution_space.dim) if x not in _excluded_dofs]
del _excluded_dofs
assert len(dirichletdofs) + len(targetdofs) + len(remainingdofs) == d.solution_space.dim

# Inner (non-Dirichlet) DOFs with the target DOFs first, so the leading
# len(targetdofs) rows of any inner solution are the target values.
innerdofs = np.concatenate((targetdofs, remainingdofs))

# Inner-product matrices. assembly_l2_1d is presumably the 1-D L2 Gram
# matrix on a vertical slice with ny intervals of width 1/ny (TODO confirm).
# The range lives on the x0 = 0 slice; the source is the pair of edges,
# hence a block-diagonal matrix matching the left/right ordering of
# dirichletdofs.
slice_inner = assembly_l2_1d(ny, 1./ny)
range_inner = slice_inner
source_inner = scipy.sparse.bmat([[slice_inner,None],[None,slice_inner]], format="csc", dtype=np.float64)

# Sweep the wavenumber k over [0, 50] and record, for every k, the singular
# values of the operator mapping Dirichlet data on the two edges to the
# solution values on the target line.
maxs = []
ks = list(np.linspace(0, 50, 201))
svals = {}
plt.yscale("log")
for k in getbar().iter(ks):
    # System matrix for this k (the reaction term is weighted by -k**2).
    A = d.operator.assemble(k)._matrix
    # Extract the inner rows once and reuse them for both column blocks.
    A_inner_rows = A[innerdofs, :]
    Aii = A_inner_rows[:, innerdofs]
    Aid = A_inner_rows[:, dirichletdofs]
    # Dirichlet values -> inner solution of the discrete problem with zero
    # right-hand side; keep only the rows of the target DOFs, which come
    # first in innerdofs.
    Top = -scipy.sparse.linalg.spsolve(Aii, Aid)[:len(targetdofs), :].todense()
    u, s, v = operator_svd(Top, source_inner, range_inner)
    plt.plot(s)

    svals[k] = s
    maxs.append(np.max(s))

plt.show()

# log10 of the first 25 singular values against k: one line per (index, k)
# pair, with a blank line after each singular-value index.
# `with` closes the files even if a write fails midway.
with open("helmholtz_3d.dat", "w") as datout:
    for i in range(25):
        for k in ks:
            datout.write("{} {} {}\n".format(i, k, np.log10(svals[k][i])))
        datout.write("\n")

# All singular values for a few selected wavenumbers. 0, 10, ..., 50 are
# exact points of linspace(0, 50, 201) (step 0.25), so these lookups hit
# the float keys of svals.
with open("helmholtz_selected.dat", "w") as datout2:
    for i in range(len(targetdofs)):
        datout2.write(str(i) + " " + " ".join([str(svals[k][i]) for k in [0, 10, 20, 30, 40, 50]]) + "\n")

# Largest singular value as a function of the sweep index.
plt.plot(maxs)
plt.show()

# u, v, Aii and Aid used below are those of the last sweep iteration
# (k = ks[-1]).
plt.yscale("linear")
for i in range(10):
    plt.plot(u[:,i])
plt.show()

plt.yscale("linear")
for i in range(10):
    plt.plot(v[:,i])
plt.show()

# Extend the first 20 columns of v (presumably source-side singular vectors,
# i.e. Dirichlet data -- TODO confirm against operator_svd) into full
# solutions and visualize them. Aii is factorized once here instead of
# once per solve.
Aii_lu = scipy.sparse.linalg.splu(Aii.tocsc())
images = d.solution_space.empty()
for i in range(20):
    # Flatten to 1-D so the assignment into the DOF slice also works if the
    # column comes back as a 2-D np.matrix.
    dirichlet_values = np.asarray(v[:, i].real).ravel()
    U = d.solution_space.zeros()
    U.data[0, innerdofs] = -Aii_lu.solve(Aid.dot(dirichlet_values))
    U.data[0, dirichletdofs] = dirichlet_values
    images.append(U)

d.visualize(images, rescale_colorbars=True)

# Fix the wavenumber at k = 30 and rebuild the Dirichlet-to-target operator
# used by the experiments below.
k = 30
A = d.operator.assemble(k)._matrix
A_inner_rows = A[innerdofs, :]
Aid = A_inner_rows[:, dirichletdofs]
Aii = A_inner_rows[:, innerdofs]
dirichlet_to_inner = scipy.sparse.linalg.spsolve(Aii, Aid)
# Keep only the target rows (they come first in innerdofs) and negate.
Top = -dirichlet_to_inner[:len(targetdofs), :].todense()

# Settings for the sampling experiment below. NOTE: the full run with
# samples = 100000 was noted to take about 10 hours.
samples = 100000
operatornorms = []
testvecnorms = []
num_testvecs = 100


def _extreme_eigenvalues(matrix):
    """Return real parts of (min of "SM" Ritz values, max of "LM" Ritz values).

    Uses ARPACK via scipy.sparse.linalg.eigs with its default k (6) for each
    end of the spectrum; for a symmetric positive definite Gram matrix these
    approximate its smallest and largest eigenvalues.
    """
    smallest = scipy.sparse.linalg.eigs(matrix, return_eigenvectors=False, which="SM")
    largest = scipy.sparse.linalg.eigs(matrix, return_eigenvectors=False, which="LM")
    return np.min(smallest).real, np.max(largest).real


resultsvd = {}
resultsvd["source_condition"] = _extreme_eigenvalues(source_inner)
resultsvd["range_condition"] = _extreme_eigenvalues(range_inner)
u, s, v = operator_svd(Top, source_inner, range_inner)
resultsvd["svals"] = s
# pickle writes bytes: a text-mode handle fails on Python 3, and `with`
# guarantees the file is flushed and closed.
with open("svddecay.pickle", "wb") as svd_pickle_file:
    pickle.dump(resultsvd, svd_pickle_file)

# Run range_generation `samples` times on the k = 30 operator and collect
# the two per-run histories it returns.
for _ in getbar().iter(range(samples)):

    opnorms, tvnorms = range_generation(Top, range_inner, source_inner, num_testvecs=num_testvecs, max_size=25)
    operatornorms.append(opnorms)
    testvecnorms.append(tvnorms)

result = {}
result["operatornorms"] = np.array(operatornorms)
result["testvecnorms"] = np.array(testvecnorms)
# resultsvd["source_condition"] above was computed by the identical eigs
# calls on source_inner; reusing it saves two ARPACK solves and keeps both
# pickles consistent (ARPACK's random start vector can otherwise give
# slightly different values).
result["lambdamin"], result["lambdamax"] = resultsvd["source_condition"]
result["dim_source"] = source_inner.shape[0]
result["dim_range"] = range_inner.shape[0]

# Binary mode is required for pickle on Python 3; `with` closes the file.
with open("experiment_helmholtz_algo_convergence_result.pickle", "wb") as result_file:
    pickle.dump(result, result_file)
