import time
import numpy as np
import scipy
import sys
from transferoperator import transferoperator_interfaces
from algorithms import range_generation, operator_svd, adaptive_range_generation, adaptive_implicit_svd
import pickle
import math

import matplotlib.pyplot as plt
from progress.bar import Bar
def getbar():
    """Create a progress bar labelled "sampling".

    The suffix shows the current index, the maximum, the elapsed time and
    the estimated remaining time.
    """
    suffix_format = "%(index)d / %(max)d, time: %(elapsed_td)s - %(eta_td)s"
    return Bar("sampling", suffix=suffix_format)

from pymor.functions.basic import GenericFunction, ConstantFunction
from pymor.domaindescriptions.boundarytypes import BoundaryType
from pymor.domaindescriptions.basic import RectDomain
from pymor.analyticalproblems.elliptic import EllipticProblem
from pymor.domaindiscretizers.default import discretize_domain_default
from pymor.grids.rect import RectGrid
from pymor.discretizers.elliptic import discretize_elliptic_cg
from pymor.parameters.functionals import ExpressionParameterFunctional

from assembly import *

from myarpack import svds as mysvds
# Monkey-patch SciPy: every later scipy.sparse.linalg.svds call in this script
# is routed through the project's own svds implementation.
scipy.sparse.linalg.svds = mysvds

import os
# Request a single OpenMP thread.
# NOTE(review): numpy/scipy are already imported at this point; depending on
# the BLAS/OpenMP runtime the variable may be read too late to take effect --
# confirm, or set it before the first numpy import.
os.environ["OMP_NUM_THREADS"] = "1"

# SuperLU has faster forward/backward substitution than UMFPACK, so UMFPACK is
# switched off for SciPy's sparse direct solves and factorizations.
scipy.sparse.linalg.use_solver(useUmfpack=False)

# Counters for forward and adjoint operator applications. They are incremented
# by the closures created in get_implicit_top and reset by the script before
# each measured run.
num_calls = 0
num_adjcalls = 0

def get_implicit_top(h_inv, ysize, k):
    """Build a matrix-free transfer operator on the rectangle [-1, 1] x [0, ysize].

    An elliptic problem with diffusion coefficient 1 and reaction coefficient
    ``-k**2`` is discretized with pyMOR's CG discretizer on a rectangular
    grid.  The top and bottom boundaries are Neumann boundaries; the vertices
    on the left and right edges (x = -1 and x = 1) are the Dirichlet dofs.
    The returned operator maps values on those Dirichlet dofs to the discrete
    solution restricted to the vertices on the line x = 0 (the target dofs).
    It is applied through a sparse factorization of the inner-dof block of
    the system matrix, so the operator itself is never assembled.

    Every forward application adds the number of processed vectors to the
    module-level counter ``num_calls``; every adjoint application does the
    same for ``num_adjcalls``.

    Parameters
    ----------
    h_inv
        Number of grid cells per unit length (the grid has ``2*h_inv`` cells
        in x and ``ysize*h_inv`` cells in y).
    ysize
        Height of the domain.
    k
        Parameter value handed to ``d.operator.assemble``; it enters the
        reaction coefficient as ``-k**2``.

    Returns
    -------
    Lop
        ``scipy.sparse.linalg.LinearOperator`` of shape
        ``(len(targetdofs), len(dirichletdofs))`` with ``matvec``, ``matmat``
        and ``rmatvec``.
    source_inner
        Block-diagonal matrix made of two copies of the matrix returned by
        ``assembly_l2_1d(ny, 1./ny)`` (one per Dirichlet edge); used as the
        inner product matrix on the source space.
    range_inner
        The matrix returned by ``assembly_l2_1d(ny, 1./ny)``; used as the
        inner product matrix on the range space.
    """
    neumann_data = ConstantFunction(0., dim_domain=2)
    dirichlet_data = ConstantFunction(0., dim_domain=2)

    nx = 2*h_inv
    ny = ysize*h_inv

    diameter = 1./h_inv*np.sqrt(2.)
    domain = RectDomain(
        domain=([-1., 0.], [1., ysize]),
        top=BoundaryType('neumann'), bottom=BoundaryType('neumann'))

    grid, bi = discretize_domain_default(domain, diameter=diameter, grid_type=RectGrid)

    problem = EllipticProblem(
        domain=domain,
        dirichlet_data=dirichlet_data,
        neumann_data=neumann_data,
        diffusion_functions=(ConstantFunction(1., dim_domain=2),),
        reaction_functions=(ConstantFunction(1., dim_domain=2),),
        reaction_functionals=(ExpressionParameterFunctional("-k**2", {"k": (1,)}),),
        rhs=ConstantFunction(0, dim_domain=2),
        )

    d, data = discretize_elliptic_cg(analytical_problem=problem, grid=grid, boundary_info=bi)

    # x-coordinate of every codim-2 entity (vertex).  The masks below are
    # evaluated on the whole array at once instead of via map(), which only
    # yields a list (as np.nonzero needs) on Python 2.
    vertex_x = np.asarray(grid.centers(2))[:, 0]

    # Dirichlet dofs sit on the left and right edges.  The closed-form index
    # sets (first and last vertex of each grid row) fix the dof ordering
    # "left edge, then right edge" that source_inner relies on; the assert
    # checks that this matches the geometry.
    dirichletdofs_frommesh = np.nonzero((vertex_x <= -1. + 1e-12) | (vertex_x >= 1. - 1e-12))[0]
    dirichletdofs = np.concatenate(([i*(nx+1) for i in range(ny+1)], [(i+1)*(nx+1)-1 for i in range(ny+1)]))
    assert set(dirichletdofs) == set(dirichletdofs_frommesh)

    slice_inner = assembly_l2_1d(ny, 1./ny)
    range_inner = slice_inner
    source_inner = scipy.sparse.bmat([[slice_inner, None], [None, slice_inner]], format="csc", dtype=np.float64)

    # Target dofs: vertices on the vertical line x = 0.
    targetdofs = np.nonzero(np.abs(vertex_x) < 1e-12)[0]
    # All other dofs, in ascending order.  setdiff1d replaces a per-element
    # membership scan over the index arrays.
    remainingdofs = np.setdiff1d(
        np.arange(d.solution_space.dim),
        np.concatenate((dirichletdofs, targetdofs)))
    assert len(dirichletdofs) + len(targetdofs) + len(remainingdofs) == d.solution_space.dim

    # Target dofs come first so that restricting a solution to them is a
    # plain leading slice.
    innerdofs = np.concatenate((targetdofs, remainingdofs))
    num_target = len(targetdofs)

    A = d.operator.assemble(k)._matrix
    Aii = A[innerdofs, :][:, innerdofs]
    Aid = A[innerdofs, :][:, dirichletdofs]
    starttime = time.time()
    Aiifac = scipy.sparse.linalg.factorized(Aii)
    source_inner_fac = scipy.sparse.linalg.factorized(source_inner)
    print("factorization time was {}, unknowns {}".format(time.time() - starttime, Aii.shape[0]))

    def apply_top(array):
        """Apply the operator to one vector (1-D) or to the columns of a 2-D array."""
        global num_calls
        num_calls += array.shape[1] if array.ndim > 1 else 1
        # Solve the inner problem for the given Dirichlet values and keep
        # only the target dofs; slicing the first axis covers 1-D and 2-D.
        return -Aiifac(Aid.dot(array))[:num_target]

    def apply_top_adj(array):
        """Apply the adjoint with respect to range_inner / source_inner.

        NOTE(review): the forward factorization of Aii is reused here, which
        equals the transposed solve only if Aii is symmetric -- confirm.
        """
        global num_adjcalls
        num_adjcalls += array.shape[1] if array.ndim > 1 else 1
        # Extend the range-weighted input by zeros on the remaining inner dofs.
        rhs = np.zeros((Aii.shape[0],) + array.shape[1:])
        rhs[:num_target] = range_inner.dot(array)
        return source_inner_fac(Aid.T.dot(-Aiifac(rhs)))

    Lop = scipy.sparse.linalg.LinearOperator(
        dtype=np.float64,  # np.float was only an alias and no longer exists in NumPy
        shape=(len(targetdofs), len(dirichletdofs)),
        matmat=apply_top,
        matvec=apply_top,
        rmatvec=apply_top_adj
        )

    return Lop, source_inner, range_inner

# Experiment parameters.
# h_inv: grid cells per unit length; y_size: height of the domain.
h_inv = 200
y_size = 8
# Number of repetitions in each microbenchmark loop below.
num_samples = 100
# Tolerance handed to adaptive_range_generation.
tolerance = 1e-4

# When True, the matrix-free operator is cross-checked against the explicit
# operator returned by transferoperator_interfaces.
ex_comparison = False

# Matrix-free operator and inner product matrices; the third argument is k = 0.
implicit_top, source_inner, range_inner = get_implicit_top(h_inv,y_size,0)
if ex_comparison:
    # Note: this rebinds source_inner and range_inner to the matrices returned
    # by transferoperator_interfaces.
    explicit_top, source_inner, range_inner = transferoperator_interfaces(h_inv,2*h_inv,y_size*h_inv)
    # Explicit adjoint: solve with source_inner applied to explicit_top^T * range_inner.
    explicit_adj = scipy.sparse.linalg.spsolve(source_inner, explicit_top.T.dot(range_inner.todense()))


    # Forward and adjoint applications must agree on random vectors.
    somevec = np.random.random((2*(y_size*h_inv + 1),1))
    assert np.linalg.norm(implicit_top.matvec(somevec) - explicit_top.dot(somevec)) < 1e-10
    somevec = np.random.random((y_size*h_inv + 1,1))
    assert np.linalg.norm(implicit_top.rmatvec(somevec) - explicit_adj.dot(somevec)) < 1e-10

# Average wall-clock time of one forward application over num_samples calls.
print("doing microbenchmarks")
somevec = np.random.random((2*(y_size*h_inv + 1),1))
starttime = time.time()
for _ in range(num_samples):
    implicit_top.matvec(somevec)
endtime = time.time()
print("matvec time is {}".format((endtime - starttime)/num_samples))

# Same measurement for the adjoint application.
somevec = np.random.random((1*(y_size*h_inv + 1),1))
starttime = time.time()
for _ in range(num_samples):
    implicit_top.rmatvec(somevec)
endtime = time.time()
print("rmatvec time is {}".format((endtime - starttime)/num_samples))

print("Starting adaptive basis generataion")
# Reset the counters so that only calls made during basis generation are
# reported (the microbenchmarks above incremented them as well).
num_calls = 0
num_adjcalls = 0
starttime = time.time()
adaptivebasis = adaptive_range_generation(implicit_top, range_inner, source_inner, 20, tolerance)
endtime = time.time()
# The number of columns of the result is reported as the basis size.
print("generated basis has size {}".format(adaptivebasis.shape[1]))
print("num calls is {}, num_adjcalls is {}".format(num_calls, num_adjcalls))
print("it took {} seconds".format(endtime - starttime))

def adaptive_arpack(top, tolerance):
    """Compute a truncated SVD of ``top`` that resolves singular values below ``tolerance``.

    Starting with 10 singular values, the number of requested singular
    values is doubled until the smallest computed one drops below
    ``tolerance`` or the largest admissible number has been requested.

    Parameters
    ----------
    top
        Operator (anything ``scipy.sparse.linalg.svds`` accepts that has a
        ``shape``) whose singular values are computed.
    tolerance
        Stop as soon as the smallest computed singular value is below this
        value.

    Returns
    -------
    The result of the last ``scipy.sparse.linalg.svds`` call with
    ``return_singular_vectors="u"``; the singular values are element 1.
    """
    # ARPACK-based svds accepts at most min(shape) - 1 singular values.
    # NOTE(review): limit taken from SciPy's svds documentation; confirm it
    # also holds for the svds replacement patched in by this file.
    max_k = min(top.shape) - 1
    k = min(10, max_k)
    adresult = scipy.sparse.linalg.svds(top, k=k, return_singular_vectors="u")

    # Stop doubling once max_k is reached: a larger request cannot be served,
    # so the loop would otherwise never end in a valid result.
    while np.min(adresult[1]) >= tolerance and k < max_k:
        k = min(2 * k, max_k)
        adresult = scipy.sparse.linalg.svds(top, k=k, return_singular_vectors="u")

    return adresult
        
# Reference run: a truncated SVD via scipy.sparse.linalg.svds with as many
# singular values as the adaptive basis has columns. Despite the message, the
# size is fixed here; the adaptive variants are the commented-out lines below.
print("calculating adaptive basis using arpack")
# Reset the counters so that only operator applications of this run are counted.
num_calls = 0
num_adjcalls = 0
adaptivesize = adaptivebasis.shape[1]
starttime = time.time()
results = scipy.sparse.linalg.svds(implicit_top, k=adaptivesize, return_singular_vectors="u")
#results = adaptive_arpack(implicit_top, tolerance)
#results = adaptive_implicit_svd(implicit_top, tolerance)
endtime = time.time()
# Number of computed singular values (results[1]) that are >= tolerance.
arpacksize=np.count_nonzero(results[1] >= tolerance)
print("generated basis has size {}".format(arpacksize))
print("num calls is {}, num_adjcalls is {}".format(num_calls, num_adjcalls))
print("it took {} seconds".format(endtime - starttime))
print("sparse calculated svals: {}".format(results[1]))
if ex_comparison:
    # Print the same number of singular values from operator_svd on the
    # explicit operator for comparison.
    resulte = operator_svd(explicit_top, source_inner, range_inner)
    print("dense calculated svals: {}".format(resulte[1][:len(results[1])]))

