import numpy as np
from itertools import izip
import copy

from partitioner import build_subspaces, partition_any_grid
from localizer import NumpyLocalizer
from pymor.grids.subgrid import SubGrid
from pymor.grids.boundaryinfos import SubGridBoundaryInfo
from pymor.domaindescriptions.boundarytypes import BoundaryType
from my_discretize_elliptic_cg import discretize_elliptic_cg
from pymor.vectorarrays.numpy import NumpyVectorArray
from pymor.operators.numpy import NumpyMatrixOperator
from pymor.discretizations.basic import StationaryDiscretization
from pymor.operators.numpy import NumpyGenericOperator
from pymor.operators.constructions import induced_norm, LincombOperator
from pymor.operators.cg import L2ProductP1, DiffusionOperatorP1

from discrete_pou import localized_pou
from algorithms import orthogonal_part

def getsubgrid(grid, xpos, ypos, coarse_grid_resolution, xsize=2, ysize=2):
    """Return the SubGrid of `grid` covering a rectangular window of coarse cells.

    The bounding box ``grid.domain`` (``domain[0]`` = lower corner,
    ``domain[1]`` = upper corner) is split into ``coarse_grid_resolution``
    coarse cells per direction.  The window starts at coarse cell
    ``(xpos, ypos)`` and spans ``xsize`` x ``ysize`` coarse cells.  Every
    codim-0 element of `grid` whose center lies in the closed window is
    selected.

    Parameters
    ----------
    grid
        Grid providing ``domain`` and ``centers(0)``.
    xpos, ypos
        Coarse-cell index of the window's lower corner; must lie in
        ``[0, coarse_grid_resolution - 2]``.
    coarse_grid_resolution
        Number of coarse cells per direction.
    xsize, ysize
        Window extent in coarse cells (default 2).

    Returns
    -------
    ``SubGrid(grid, indices)`` with the indices of the selected elements.
    """
    assert 0 <= xpos <= coarse_grid_resolution - 2
    assert 0 <= ypos <= coarse_grid_resolution - 2

    xstep = float(grid.domain[1][0] - grid.domain[0][0]) / coarse_grid_resolution
    ystep = float(grid.domain[1][1] - grid.domain[0][1]) / coarse_grid_resolution

    xmin = grid.domain[0][0] + xpos * xstep
    xmax = xmin + xsize * xstep

    ymin = grid.domain[0][1] + ypos * ystep
    ymax = ymin + ysize * ystep

    # Vectorized closed-interval test on all element centers at once.  A
    # per-element `map` would yield an iterator on Python 3, which
    # `np.nonzero` cannot use, and the old helper name shadowed `filter`.
    centers = np.asarray(grid.centers(0))
    cx = centers[:, 0]
    cy = centers[:, 1]
    mask = (xmin <= cx) & (cx <= xmax) & (ymin <= cy) & (cy <= ymax)
    indices = np.nonzero(mask)[0]
    return SubGrid(grid, indices)

def generate_k_l2_product(p, grid, boundary_info):
    """Assemble the diffusion-weighted L2 product of `p` on `grid`.

    The result is ``LincombOperator(...).assemble()`` of:

    * a ``DiffusionOperatorP1`` with ``diffusion_constant=0`` (coefficient 1),
      presumably contributing only the Dirichlet rows, as in pyMOR's
      ``discretize_elliptic_cg`` boundary part -- TODO confirm;
    * one ``L2ProductP1`` per diffusion function of `p`, using that function
      as ``coefficient_function`` and built with ``dirichlet_clear_diag=True``
      and ``dirichlet_clear_columns=True``.

    If ``p.diffusion_functionals`` is given, the i-th L2 part is weighted by
    the i-th functional.  Otherwise `p` must have exactly one diffusion
    function, which is weighted by 1.

    Parameters
    ----------
    p
        Problem providing ``diffusion_functions`` and, optionally,
        ``diffusion_functionals``.
    grid
        Grid on which the operators are built.
    boundary_info
        Boundary info belonging to `grid`.

    Returns
    -------
    The assembled operator (named ``'k_l2_0'`` before assembly).

    Raises
    ------
    ValueError
        If `p` has no diffusion functions, if the numbers of functionals and
        functions differ, or if there are no functionals and not exactly one
        diffusion function.
    """
    functions = p.diffusion_functions
    functionals = p.diffusion_functionals

    # Validate with explicit exceptions before assembling anything: `assert`
    # is stripped under `python -O`, which would silently return a product
    # missing its diffusion part.
    if functions is None:
        raise ValueError('problem has no diffusion_functions')
    if functionals is not None:
        if len(functionals) != len(functions):
            raise ValueError('got {} diffusion_functionals for {} diffusion_functions'
                             .format(len(functionals), len(functions)))
    elif len(functions) != 1:
        raise ValueError('without diffusion_functionals exactly one diffusion function '
                         'is supported, got {}'.format(len(functions)))

    Li = [DiffusionOperatorP1(grid, boundary_info, diffusion_constant=0, name='boundary_part',
                              dirichlet_clear_columns=True)]
    coefficients = [1.]

    # diffusion part
    if functionals is not None:
        Li += [L2ProductP1(grid, boundary_info, coefficient_function=df, dirichlet_clear_diag=True,
                           dirichlet_clear_columns=True,
                           name='diffusion_{}'.format(i))
               for i, df in enumerate(functions)]
        coefficients += list(functionals)
    else:
        Li += [L2ProductP1(grid, boundary_info, coefficient_function=functions[0],
                           dirichlet_clear_diag=True,
                           dirichlet_clear_columns=True,
                           name='diffusion')]
        coefficients.append(1.)

    k_l2_product = LincombOperator(operators=Li, coefficients=coefficients, name='k_l2_0').assemble()
    return k_l2_product

def _clear_dirichlet_dofs(matrix, dofs, diagonal_value):
    """Zero rows and columns `dofs` of sparse `matrix` in place and set their diagonal to `diagonal_value`."""
    if len(dofs):
        matrix[dofs, :] *= 0
        matrix[:, dofs] *= 0
        matrix[dofs, dofs] = diagonal_value
        # drop the explicit zeros stored by the scaling above
        matrix.eliminate_zeros()

def _restrict(matrix, dofs):
    """Wrap `matrix` with rows and columns selected/reordered by `dofs` as a NumpyMatrixOperator."""
    return NumpyMatrixOperator(matrix[:, dofs][dofs, :])

def _subgrid_dof_numbering(localizer, space, subgrid, global_dim):
    """Map the localizer ordering of `space` to vertex numbers of `subgrid`.

    Returns ``(ndofs, lvec)``: ``ndofs`` is the size of `space` and ``lvec[i]``
    is the subgrid-local vertex number of the i-th DOF of `space` in the
    localizer's ordering, as an integer index array.

    Raises ValueError if `space` contains a DOF that is not a vertex of
    `subgrid`.
    """
    ndofs = len(localizer.join_spaces(space))
    # global vector carrying each subgrid vertex's local number; the large
    # negative sentinel marks global DOFs outside the subgrid
    global_dofnrs = -100000000 * np.ones(shape=(global_dim,))
    global_dofnrs[subgrid.parent_indices(2)] = np.arange(ndofs)
    lvec = localizer.localize_vector_array(NumpyVectorArray(global_dofnrs), space).data[0]
    if np.any(lvec < 0):
        raise ValueError('space contains DOFs that are not vertices of the subgrid')
    # the values are exact integers stored as floats; sparse fancy indexing needs ints
    return ndofs, lvec.astype(int)

def _make_transfer_operator(localizer, local_op, rhsop, training_space, range_space,
                            range_l2_product, remove_constant, source_size, range_size):
    """Build the transfer operator from the source space to the range space.

    For a source vector ``v`` it solves ``local_op u = -rhsop(v)`` on the
    training space and maps ``u`` to the range space.  If `remove_constant`
    is true, the result is made `range_l2_product`-orthogonal to the constant
    one vector via ``orthogonal_part``.
    """
    def transfer(va):
        range_solution = localizer.to_space(
            local_op.apply_inverse(-rhsop.apply(NumpyVectorArray(va))),
            training_space, range_space)
        if not remove_constant:
            return range_solution.data
        constant_one = NumpyVectorArray(np.ones(range_solution.space.dim))
        constant_one_normed = constant_one * (1. / induced_norm(range_l2_product)(constant_one)[0])
        return orthogonal_part(constant_one_normed, range_l2_product, range_solution).data

    return NumpyGenericOperator(transfer, source_size, range_size, linear=True)

def _make_pou_operator(localizer, fdict, range_space, pou_range_space, range_size, pou_range_size):
    """Build the operator restricting a range-space vector to `pou_range_space` and applying ``fdict[pou_range_space]``."""
    def pou(va):
        val = localizer.to_space(NumpyVectorArray(va), range_space, pou_range_space)
        # looked up at apply time, like the partition-of-unity dict itself
        poufunction = fdict[pou_range_space]
        return poufunction(val).data

    return NumpyGenericOperator(pou, range_size, pou_range_size, linear=True)

def localize_problem(p, coarse_grid_resolution, fine_grid_resolution):
    """Discretize `p` globally and build localized quantities for every coarse patch.

    Parameters
    ----------
    p
        Elliptic problem passed to ``discretize_elliptic_cg``.
    coarse_grid_resolution
        Number of coarse cells per direction (N).
    fine_grid_resolution
        The fine discretization uses ``diameter = 1 / fine_grid_resolution``.

    Returns
    -------
    global_quantities : dict
        Keys ``"coarse_grid_resolution"``, ``"d"``, ``"data"``,
        ``"localizer"`` and ``"energy_0_product"``.
    local_quantities : ndarray of object, shape (N-1, N-1)
        ``local_quantities[xpos, ypos]`` is a dict of operators and spaces for
        the patch of 2x2 coarse cells starting at coarse cell ``(xpos, ypos)``.
    """
    global_quantities = {}
    global_quantities["coarse_grid_resolution"] = coarse_grid_resolution

    local_quantities = np.empty(shape=(coarse_grid_resolution - 1, coarse_grid_resolution - 1),
                                dtype=object)

    # discretize problem
    diameter = 1. / fine_grid_resolution
    d, data = discretize_elliptic_cg(p, diameter=diameter)

    grid = data["grid"]
    boundary_info = data["boundary_info"]

    # Scale Dirichlet diagonal entries of each operator component and the
    # matching rhs entries by 1e5, in place on the assembled matrices.
    # NOTE(review): this only affects `d` if `assemble()` returns the stored
    # operator rather than a copy -- confirm.
    dirichlet_dofs = boundary_info.dirichlet_boundaries(2)
    for op in d.operator.operators:
        op.assemble()._matrix[dirichlet_dofs, dirichlet_dofs] *= 1e5
    d.rhs.assemble()._matrix[:, dirichlet_dofs] *= 1e5

    global_quantities["d"] = d
    global_quantities["data"] = data

    subspaces, subspaces_per_codim = build_subspaces(
        *partition_any_grid(grid, num_intervals=(coarse_grid_resolution, coarse_grid_resolution)))
    localizer = NumpyLocalizer(d.solution_space, subspaces['dofs'])
    global_quantities["localizer"] = localizer

    # some products
    full_h1_product = d.products["h1"].assemble()
    full_l2_product = d.products["l2"].assemble()
    full_energy_product = copy.deepcopy(d.operator.assemble())
    _clear_dirichlet_dofs(full_energy_product._matrix,
                          np.nonzero(boundary_info.mask(BoundaryType("dirichlet"), 2))[0], 1)

    full_operator = d.operator.assemble()
    full_rhs = d.rhs.assemble()
    global_quantities["energy_0_product"] = full_energy_product

    fdict = localized_pou(subspaces, subspaces_per_codim, localizer, coarse_grid_resolution, grid)

    for xpos in range(coarse_grid_resolution - 1):
        for ypos in range(coarse_grid_resolution - 1):
            ldict = {}
            local_quantities[xpos, ypos] = ldict

            # Purely positional flags: omega covers coarse cells xpos..xpos+2,
            # omega-star one more layer around it.  boundary_info is not
            # consulted, so these assume the whole outer boundary is
            # Dirichlet -- NOTE(review): confirm for non-Dirichlet problems.
            ldict["omega_star_has_dirichlet"] = not (2 <= xpos <= coarse_grid_resolution - 4
                                                     and 2 <= ypos <= coarse_grid_resolution - 4)
            ldict["omega_has_dirichlet"] = not (1 <= xpos <= coarse_grid_resolution - 3
                                                and 1 <= ypos <= coarse_grid_resolution - 3)

            # discretize omega (2x2 coarse cells)
            mysubgrid = getsubgrid(grid, xpos, ypos, coarse_grid_resolution)
            mysubbi = SubGridBoundaryInfo(mysubgrid, grid, boundary_info, BoundaryType('neumann'))
            ld, ldata = discretize_elliptic_cg(p, grid=mysubgrid, boundary_info=mysubbi)

            # discretize omega-star: omega grown by one coarse cell per side,
            # clipped to the coarse index range [0, coarse_grid_resolution]
            xminext = max(0, xpos - 1)
            xsize = min(xpos + 3, coarse_grid_resolution) - xminext
            yminext = max(0, ypos - 1)
            ysize = min(ypos + 3, coarse_grid_resolution) - yminext
            mysubgridext = getsubgrid(grid, xminext, yminext, coarse_grid_resolution,
                                      xsize=xsize, ysize=ysize)
            mysubbiext = SubGridBoundaryInfo(mysubgridext, grid, boundary_info, BoundaryType('neumann'))
            ldext, ldataext = discretize_elliptic_cg(p, grid=mysubgridext, boundary_info=mysubbiext)

            # corresponding spaces of this patch
            csid = subspaces_per_codim[2][ypos + xpos * (coarse_grid_resolution - 1)]
            patch = subspaces[csid]
            range_space = tuple(sorted(set(patch['env']) | set(patch['cenv'])))
            ldict["range_space"] = range_space
            omega_star_space = tuple(sorted(set(patch['xenv']) | set(patch['cxenv'])))
            ldict["omega_star_space"] = omega_star_space
            training_space = patch['xenv']
            ldict["training_space"] = training_space
            source_space = patch['cxenv']
            ldict["source_space"] = source_space

            # omega energy product: Dirichlet DOFs cleared (diagonal 100),
            # renumbered to the localizer's range_space ordering
            bilifo = copy.deepcopy(ld.operator.assemble())
            k_l2_product = generate_k_l2_product(p, mysubgrid, mysubbi)
            _clear_dirichlet_dofs(bilifo._matrix,
                                  np.nonzero(mysubbi.mask(BoundaryType("dirichlet"), 2))[0], 100)

            ndofs, lvec = _subgrid_dof_numbering(localizer, range_space, ldata['grid'],
                                                 d.solution_space.dim)
            bilifo = _restrict(bilifo._matrix, lvec)
            k_l2_product = _restrict(k_l2_product._matrix, lvec)
            ldict["range_energy_0_product"] = bilifo
            ldict["k_l2_product"] = k_l2_product

            # energy product plus the term 1e6 * v^T v, v = k_l2_product(1)
            tempvec = k_l2_product.apply(NumpyVectorArray(np.ones(ndofs)))
            ldict["range_fixed_energy_0_product"] = NumpyMatrixOperator(
                1000000 * tempvec.data.T.dot(tempvec.data) + bilifo._matrix)

            # omega-star energy product: Dirichlet DOFs cleared (diagonal 1)
            bilifoext = copy.deepcopy(ldext.operator.assemble())
            _clear_dirichlet_dofs(bilifoext._matrix,
                                  np.nonzero(mysubbiext.mask(BoundaryType("dirichlet"), 2))[0], 1)

            ndofsext, lvecext = _subgrid_dof_numbering(localizer, omega_star_space, ldataext['grid'],
                                                       d.solution_space.dim)
            bilifoext = _restrict(bilifoext._matrix, lvecext)
            ldict["omega_star_energy_0_product"] = bilifoext

            # fancy indexing in _restrict copies, so the assembled product is not mutated
            omega_star_l2_product = _restrict(ldext.products["l2"].assemble()._matrix, lvecext)
            # energy product plus the term 1000 * v^T v, v = omega_star_l2_product(1)
            tempvec = omega_star_l2_product.apply(NumpyVectorArray(np.ones(ndofsext)))
            ldict["omega_star_fixed_energy_product"] = NumpyMatrixOperator(
                1000 * tempvec.data.T.dot(tempvec.data) + bilifoext._matrix)
            ldict["omega_star_l2_product"] = omega_star_l2_product

            # range l2 product
            ldict["range_l2_product"] = _restrict(ld.products["l2"].assemble()._matrix, lvec)

            # source product: sparsity pattern of the localized global L2
            # product, values 4/6*diameter on the diagonal, diameter/6 elsewhere
            lproduct = localizer.localize_operator(full_l2_product, source_space, source_space)
            lmat = lproduct._matrix.tocoo()
            lmat.data = np.where(lmat.row == lmat.col, 4. / 6. * diameter, diameter / 6.)
            ldict["source_l2_product"] = NumpyMatrixOperator(lmat.tocsc())

            # local solution of the full problem on the training space
            local_op = localizer.localize_operator(full_operator, training_space, training_space)
            local_rhs = localizer.localize_operator(full_rhs, None, training_space)
            local_d = StationaryDiscretization(local_op, local_rhs, cache_region=None)
            ldict["local_solution"] = localizer.to_space(local_d.solve(), training_space, range_space)

            # transfer operator
            cavesize = len(localizer.join_spaces(source_space))
            rhsop = localizer.localize_operator(full_operator, training_space, source_space)
            ldict["transfer_operator"] = _make_transfer_operator(
                localizer, local_op, rhsop, training_space, range_space,
                range_l2_product=ldict["range_l2_product"],
                remove_constant=not ldict["omega_has_dirichlet"],
                source_size=cavesize, range_size=ndofs)

            # pou stuff
            pou_range_space = patch['env']
            ldict["pou_range_space"] = pou_range_space
            pou_rangesize = len(localizer.join_spaces(pou_range_space))
            ldict["pou_operator"] = _make_pou_operator(localizer, fdict, range_space, pou_range_space,
                                                       ndofs, pou_rangesize)

            # pou range h1 product
            ldict["pou_range_h1_product"] = localizer.localize_operator(
                full_h1_product, pou_range_space, pou_range_space)

    return global_quantities, local_quantities
