#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
@date: 2018-Today
@author: jasper.bathmann@ufz.de
"""
import vtk
import numpy as np
import argparse


# Input boundary mesh. Alternative input used with this script:
# "Ton-Nord_top_quad.vtu".
meshName = "Ton-Nord_bottom_quad.vtu"

# Element order of the boundary mesh: "linear" (4-node quads) or
# "quadratic" (8-node quads).
mesh_type = "quadratic"

# Absolute distance tolerance used to decide whether a point lies on one of
# the lateral (x/y) bounding planes of the mesh.
eps = .1


# Read the input mesh
meshReader = vtk.vtkXMLUnstructuredGridReader()
meshReader.SetFileName(meshName)
meshReader.Update()
grid = meshReader.GetOutput()
# vtkDataSet.GetBounds() returns (xmin, xmax, ymin, ymax, zmin, zmax).
# NOTE(review): the compass names assume x grows eastward and y grows
# northward; only the set of four lateral planes matters below.
(WEST_X, EAST_X,
 SOUTH_Y, NORTH_Y,
 BOTTOM_Z, TOP_Z) = grid.GetBounds()
cells = grid.GetCells()
points = grid.GetPoints()
# Reusable id list filled by the cell traversals further down
pts = vtk.vtkIdList()
# Rewind the cell iterator so the first GetNextCell() yields cell 0
cells.InitTraversal()
point_ids_of_removed_cells = []
point_ids_of_kept_cells = []

# Attribute arrays of the input mesh that are carried over
bulk_node_ids = grid.GetPointData().GetArray("bulk_node_ids")
mat_ids = grid.GetCellData().GetArray("MaterialIDs")
bulk_face_ids = grid.GetCellData().GetArray("bulk_face_ids")
bulk_element_ids = grid.GetCellData().GetArray("bulk_element_ids")

# Nodes per cell and VTK cell type for the chosen element order.
# Use VTK's named constants: VTK_QUAD == 9 and VTK_QUADRATIC_QUAD == 23
# (10 would be VTK_TETRA, which is not a quadrilateral).
if mesh_type == "linear":
    n_nodes = 4
    cell_type_int = vtk.VTK_QUAD
elif mesh_type == "quadratic":
    n_nodes = 8
    cell_type_int = vtk.VTK_QUADRATIC_QUAD
else:
    raise ValueError(
        "unsupported mesh_type {!r}; expected 'linear' or 'quadratic'"
        .format(mesh_type))
    
def createCell(mesh_type):
    """Return a new, empty VTK quadrilateral cell of the requested order.

    Parameters
    ----------
    mesh_type : str
        "linear" for a 4-node ``vtk.vtkQuad`` or "quadratic" for an 8-node
        ``vtk.vtkQuadraticQuad``.

    Returns
    -------
    vtkCell
        A freshly constructed cell whose point ids still have to be set.

    Raises
    ------
    ValueError
        If ``mesh_type`` is neither "linear" nor "quadratic" (previously this
        surfaced as an UnboundLocalError).
    """
    if mesh_type == "linear":
        return vtk.vtkQuad()
    if mesh_type == "quadratic":
        return vtk.vtkQuadraticQuad()
    raise ValueError(
        "unsupported mesh_type {!r}; expected 'linear' or 'quadratic'"
        .format(mesh_type))

# Collect the ids of all points lying (within eps) on one of the four
# lateral bounding planes x = xmin/xmax or y = ymin/ymax.
# A set is used because it is only queried for membership afterwards.
remove_indices = set()
for point_id in range(points.GetNumberOfPoints()):
    x, y, _ = points.GetPoint(point_id)
    if (abs(x - EAST_X) < eps or abs(x - WEST_X) < eps
            or abs(y - NORTH_Y) < eps or abs(y - SOUTH_Y) < eps):
        remove_indices.add(point_id)

# Prepare the containers of the resampled mesh
new_cells = vtk.vtkCellArray()
new_points = vtk.vtkPoints()

new_bulk_node_ids = vtk.vtkDataArray.CreateDataArray(vtk.VTK_UNSIGNED_LONG)
new_bulk_node_ids.SetName("bulk_node_ids")
new_mat_ids = vtk.vtkDataArray.CreateDataArray(vtk.VTK_INT)
new_mat_ids.SetName("MaterialIDs")
new_bulk_face_ids = vtk.vtkDataArray.CreateDataArray(vtk.VTK_UNSIGNED_LONG)
new_bulk_face_ids.SetName("bulk_face_ids")
new_bulk_element_ids = vtk.vtkDataArray.CreateDataArray(vtk.VTK_UNSIGNED_LONG)
new_bulk_element_ids.SetName("bulk_element_ids")

# Set views for O(1) membership tests. The lists remain the record of
# point ids (first-seen order, no duplicates).
remove_index_set = set(remove_indices)
removed_point_set = set(point_ids_of_removed_cells)
kept_point_set = set(point_ids_of_kept_cells)

# Iterate over all cells. `cells` was rewound with InitTraversal() before
# this loop, so the i-th GetNextCell() call yields cell i and `i` indexes the
# cell-data arrays.
for i in range(cells.GetNumberOfCells()):
    cells.GetNextCell(pts)
    ids_o = [pts.GetId(j) for j in range(n_nodes)]

    # Drop the cell only if all n_nodes of its points lie on a lateral plane
    if all(id_o in remove_index_set for id_o in ids_o):
        for id_i in ids_o:
            if id_i not in removed_point_set:
                removed_point_set.add(id_i)
                point_ids_of_removed_cells.append(id_i)
    else:
        # Keep the cell; it still references the old point ids, which are
        # renumbered once the set of kept points is known.
        cell = createCell(mesh_type)
        for j, id_o in enumerate(ids_o):
            cell.GetPointIds().SetId(j, id_o)
        new_cells.InsertNextCell(cell)
        for id_i in ids_o:
            if id_i not in kept_point_set:
                kept_point_set.add(id_i)
                point_ids_of_kept_cells.append(id_i)
        # Copy the per-cell attributes of the kept cell
        new_mat_ids.InsertNextTuple1(int(mat_ids.GetTuple(i)[0]))
        new_bulk_face_ids.InsertNextTuple1(int(bulk_face_ids.GetTuple(i)[0]))
        new_bulk_element_ids.InsertNextTuple1(int(bulk_element_ids.GetTuple(i)[0]))

# Compact the point set: keep only points referenced by a kept cell, in
# their original order, and record the old -> new id mapping.
# A set and a dict replace list lookups that were O(n) per query.
kept_point_ids = set(point_ids_of_kept_cells)
old_point_index, new_point_index = [], []
old_to_new_id = {}
j = 0
for i in range(grid.GetNumberOfPoints()):
    if i in kept_point_ids:
        # Insert the point and its bulk node id into the new mesh
        new_bulk_node_ids.InsertNextTuple1(int(bulk_node_ids.GetTuple(i)[0]))
        new_points.InsertNextPoint(points.GetPoint(i))
        old_point_index.append(i)
        new_point_index.append(j)
        old_to_new_id[i] = j
        j += 1

# Rewrite the connectivity of every kept cell in terms of the new point ids
tmp_cells = vtk.vtkCellArray()
new_cells.InitTraversal()
while new_cells.GetNextCell(pts):
    cell = createCell(mesh_type)
    for j in range(n_nodes):
        cell.GetPointIds().SetId(j, old_to_new_id[pts.GetId(j)])
    tmp_cells.InsertNextCell(cell)
new_cells = tmp_cells


# Swap the rebuilt points, attribute arrays and connectivity into the grid
grid.SetPoints(new_points)

point_data = grid.GetPointData()
point_data.RemoveArray("bulk_node_ids")
point_data.AddArray(new_bulk_node_ids)

cell_data = grid.GetCellData()
for array_name, replacement in (("MaterialIDs", new_mat_ids),
                                ("bulk_face_ids", new_bulk_face_ids),
                                ("bulk_element_ids", new_bulk_element_ids)):
    # Drop the stale full-size array before attaching the filtered one
    cell_data.RemoveArray(array_name)
    cell_data.AddArray(replacement)

grid.SetCells(cell_type_int, new_cells)

# Write the filtered mesh to "resampled_" + meshName
output_name = "resampled_" + meshName
writer = vtk.vtkXMLUnstructuredGridWriter()
writer.SetFileName(output_name)
writer.SetInputData(grid)
writer.Write()
