import os

import numpy as np
import pandas as pd
import meshio  # pyvista.save_meshio needs meshio
from sys import argv, exit

from os.path import exists

from vtk import vtkXMLUnstructuredGridReader
from vtk.util import numpy_support
import pyvista as pv
import pathlib

# Axis labels kept as module-level constants. They are not referenced by the
# code in this module as visible here; presumably 'gravitation' names the axis
# gravity acts along and 'glaciation' the ice-advance direction (TODO confirm).
gravitation, glaciation = 'z', 'y'


def get_bounds(input_file):
    """Return the coordinate extent of the points in a VTK XML unstructured grid.

    Args:
        input_file: path of the ``.vtu`` file to read.

    Returns:
        A list of three ``(min, max)`` tuples, one per point-coordinate
        column 0, 1 and 2.
    """
    reader = vtkXMLUnstructuredGridReader()
    reader.SetFileName(input_file)
    reader.Update()
    point_data = reader.GetOutput().GetPoints().GetData()
    coords = numpy_support.vtk_to_numpy(point_data)[:, :3]
    lows = coords.min(axis=0)
    highs = coords.max(axis=0)
    return list(zip(lows, highs))


def height_in_folder(folder):
    """Build a row callback that returns the mean point height of a layer file.

    The returned function is intended for ``DataFrame.apply(..., axis=1)``.
    For a row with ``layer_id`` and ``model_unit`` entries it reads
    ``{folder}{layer_id:02d}_{model_unit}.vtu`` and returns the mean of the
    third point-coordinate column. If that file does not exist, it prints a
    message and terminates the process with exit status 1.

    Args:
        folder: prefix prepended verbatim to the file name, so it must end
            with a path separator.

    Returns:
        A function ``height(row) -> float``.
    """
    template = "{folder}{layer_id:02d}_{model_unit}.vtu"

    def height(row):
        # int() keeps the ``:02d`` spec working when layer_id arrives as a
        # float (pandas may upcast a row's values); it is a no-op for ints.
        input_file = template.format(layer_id=int(row['layer_id']),
                                     model_unit=row['model_unit'],
                                     folder=folder)
        if not exists(input_file):
            print(input_file, 'does not exist.')
            exit(1)

        reader = vtkXMLUnstructuredGridReader()
        reader.SetFileName(input_file)
        reader.Update()
        points = numpy_support.vtk_to_numpy(
            reader.GetOutput().GetPoints().GetData())
        return points[:, 2].mean()

    return height


def model_layer_parameter(layer_set_id, layer_sets_csv, parameters_file):
    """Return the layers of one layer set joined with their parameters.

    Args:
        layer_set_id: value selected in the ``set_id`` column of
            *layer_sets_csv*.
        layer_sets_csv: CSV path or buffer with at least ``set_id`` and
            ``layer_id`` columns.
        parameters_file: comma-separated CSV path or buffer with a
            ``layer_id`` column; its other columns are attached to each layer.

    Returns:
        The selected layer-set rows, in file order, left-joined with the
        parameter rows on ``layer_id``. Layers without parameters get NaN.

    Raises:
        ValueError: if no row has ``set_id == layer_set_id``. This is a
            subclass of ``Exception``, so existing handlers still catch it.
    """
    layer_sets = pd.read_csv(layer_sets_csv)
    layer_sets = layer_sets[layer_sets["set_id"] == layer_set_id]
    if layer_sets.empty:
        raise ValueError(f'no model defined with {layer_set_id=}')
    parameters = pd.read_csv(parameters_file, delimiter=",")
    # A left join keeps every selected layer and preserves the left key order.
    return pd.merge(layer_sets, parameters, on='layer_id', how='left')


def create_layers(z_bounds, num_layers, layer_material_ids):
    """Discretize each interval between consecutive bounds into slices.

    Interval ``i`` runs from ``z_bounds[i]`` to ``z_bounds[i + 1]`` and is
    split into ``num_layers[i]`` equal slices, i.e. ``num_layers[i] + 1``
    coordinates including both ends.

    Returns:
        A list of ``(coordinates, material_id)`` tuples, truncated to the
        shortest of the interval count, *num_layers* and
        *layer_material_ids*.
    """
    intervals = zip(z_bounds, z_bounds[1:], num_layers)
    slice_coords = [np.linspace(start, stop, num=count + 1)
                    for start, stop, count in intervals]
    return list(zip(slice_coords, layer_material_ids))


def centered_range(bounds, step, max_res=500.):
    """Return coordinates spaced by *step*, symmetric about the centre of *bounds*.

    To keep model sizes equal across a convergence study, the extent is
    snapped, outward from the centre, to whole multiples of
    ``max_step = step * 2**int(log2(max_res / step))``. For
    ``step <= max_res`` this is the largest ``2**n * step`` not exceeding
    *max_res*, which is assumed to be the coarsest resolution in the study.
    If half the extent is smaller than ``max_step``, only the centre is
    returned.

    Args:
        bounds: ``(lower, upper)`` extent along one axis.
        step: spacing between consecutive coordinates.
        max_res: upper resolution limit used to snap the extent. The default
            of 500 is the value previously hard-coded here.

    Returns:
        A sorted 1-D array of unique coordinates between the snapped bounds.
    """
    center = np.mean(bounds)
    max_step = step * (2 ** int(np.log2(max_res / step)))
    # With that max step size we set new bounds.
    l_bound = center - max_step * ((center - bounds[0]) // max_step)
    r_bound = center + max_step * ((bounds[1] - center) // max_step)

    # Count the steps explicitly rather than using np.arange with a float
    # step: arange derives its length from ceil((stop - start) / step), so
    # rounding can add a point one step beyond the snapped bound. The small
    # epsilon absorbs rounding error in an ideally integral step count.
    eps = 1e-9
    n_left = int(np.floor((center - l_bound) / step + eps))
    n_right = int(np.floor((r_bound - center) / step + eps))
    left = center - step * np.arange(n_left + 1)
    right = center + step * np.arange(n_right + 1)
    return np.unique(np.concatenate((left, [center], right)))


def create_structured_grid(axis_bounds, axis_res, axis2_range, m, rank,
                           orientation='xy'):
    """Create a structured grid whose cells all carry material id *m*.

    Args:
        axis_bounds: per-axis ``(min, max)`` pairs. Entry 1 is always used;
            entry 0 is used for ``rank == 3`` only.
        axis_res: spacing along axes 0 and 1 (see ``centered_range``).
        axis2_range: explicit coordinates along axis 2.
        m: value stored, cast to int32, in every cell of ``MaterialIDs``.
        rank: 2 for a planar grid, 3 for a volumetric grid.
        orientation: for ``rank == 2`` only.
            ``'xy'`` places the axis-1 coordinates on X and the axis-2
            coordinates on Y, with Z = 0. ``'yz'`` places them on Y and Z,
            with X = 0. Defaults to ``'xy'``, the previously hard-coded value.

    Returns:
        A ``pv.StructuredGrid`` with a ``MaterialIDs`` cell array.

    Raises:
        ValueError: for an unsupported *rank* or *orientation* (a subclass
            of ``Exception``, so existing handlers still catch it).
    """
    axis1_range = centered_range(axis_bounds[1], axis_res)

    if rank == 2:
        AXIS1, AXIS2 = np.meshgrid(axis1_range, axis2_range)
        AXIS0 = np.zeros(AXIS1.shape)
        if orientation == 'xy':
            X, Y, Z = (AXIS1, AXIS2, AXIS0)
        elif orientation == 'yz':
            X, Y, Z = (AXIS0, AXIS1, AXIS2)
        else:
            raise ValueError(
                f"{orientation=}, but must be 'xy' or 'yz'")
    elif rank == 3:
        axis0_range = centered_range(axis_bounds[0], axis_res)
        X, Y, Z = np.meshgrid(axis0_range, axis1_range, axis2_range)
    else:
        raise ValueError(f"{rank=}, but must be 2 or 3")

    grid = pv.StructuredGrid(X, Y, Z)
    grid.cell_data['MaterialIDs'] = np.full(grid.n_cells, m, dtype=np.int32)
    return grid


def create_layered_mesh(z_bounds, z_num_slices, z_material_ids, bounds, res, rank):
    """Build one structured grid per layer and merge them into a single mesh.

    Args:
        z_bounds: layer-boundary heights. Consecutive pairs delimit one layer.
        z_num_slices: number of slices per layer.
        z_material_ids: material id per layer.
        bounds: per-axis ``(min, max)`` extents used for the in-plane axes.
        res: in-plane grid spacing.
        rank: 2 or 3, forwarded to ``create_structured_grid``.

    Returns:
        The first layer's grid merged with all remaining layer grids.

    Raises:
        ValueError: if no layer can be formed from the inputs.
    """
    z_layer = create_layers(z_bounds, z_num_slices, z_material_ids)
    if not z_layer:
        raise ValueError("no layers to mesh: need at least two z bounds, "
                         "one slice count and one material id")
    # Named "grids" rather than "all" to avoid shadowing the builtin.
    grids = [create_structured_grid(bounds, res, zs, mat, rank)
             for zs, mat in z_layer]
    return grids[0].merge(grids[1:])


def write_mesh(file_name, mesh):
    """Save *mesh* to *file_name* with ``pv.save_meshio``.

    Creates the parent directory if needed and prints the output path on
    success.
    """
    # pyvista is not able to write to another directory, so switch into the
    # target directory temporarily and write by bare file name.
    path = pathlib.Path(file_name)
    target_dir = path.parent
    target_dir.mkdir(parents=True, exist_ok=True)
    cwd = os.getcwd()
    os.chdir(target_dir)
    try:
        pv.save_meshio(path.name, mesh)
    finally:
        # Restore the working directory even if saving fails, so the caller's
        # relative paths keep resolving as before.
        os.chdir(cwd)
    print('Saved to: ', file_name)


def create(layer_set_id, xy_res, layer_sets_csv, parameters_file, file_name,
           factor, layers_folder, rank):
    """Generate the layered structured mesh for one layer set and write it.

    Steps:
      1. Look up the layer set and merge in its parameters.
      2. Scale each layer's ``resolution`` by *factor*.
      3. Read the mean height of every layer surface from *layers_folder*.
      4. Take the extents from ``{layers_folder}00_KB.vtu``.
      5. Build the layered mesh and save it to *file_name*.

    Args:
        layer_set_id: value matched against the ``set_id`` column.
        xy_res: grid spacing along grid axes 0 and 1.
        layer_sets_csv: CSV describing the layer sets.
        parameters_file: CSV with per-layer parameters, joined on ``layer_id``.
        file_name: output mesh path.
        factor: multiplier applied to every layer's resolution.
        layers_folder: prefix, including the trailing separator, of the
            ``.vtu`` layer files.
        rank: 2 or 3, the dimension of the generated grid.
    """
    model_df = model_layer_parameter(layer_set_id, layer_sets_csv,
                                     parameters_file)

    # Row 0 contributes only its surface height. The interval between
    # surfaces i and i+1 takes its resolution and material from row i+1.
    resolutions = model_df['resolution'].to_numpy().tolist()[1:]
    slice_counts = [count * factor for count in resolutions]
    layer_materials = model_df['material_id'].to_numpy().tolist()[1:]

    surface_heights = model_df.apply(height_in_folder(layers_folder),
                                     axis=1).to_numpy()
    reference_surface = "{folder}00_KB.vtu".format(folder=layers_folder)
    extent = get_bounds(reference_surface)

    mesh = create_layered_mesh(surface_heights, slice_counts, layer_materials,
                               extent, xy_res, rank)
    write_mesh(file_name, mesh)


if __name__ == '__main__':
    print("started generation simplified structured mesh")
    # Nine positional arguments (argv[1] .. argv[9]) are read below. The
    # former "< 8" guard let 8- and 9-entry command lines through to an
    # IndexError instead of the usage message.
    if len(argv) < 10:
        print("Usage: " + argv[0] + " layer_set_id x_resolution factor rank"
              " model_name layer_sets_csv layers_folder out_file"
              " parameters_file")
        exit(1)

    layer_set_id = int(argv[1])
    x_res = float(argv[2])
    factor = int(argv[3])  # multiplier for every layer's resolution
    rank = int(argv[4])  # 2 or 3: dimension of the generated grid

    model_name = argv[5]  # accepted for CLI compatibility; not used below
    layer_sets_csv = argv[6]
    # Layer file names are built by plain string concatenation with this
    # folder, so it needs a trailing separator.
    input_layers_folder = argv[7] + "/"
    out_filenamepath = argv[8]
    parameters_file = argv[9]

    create(layer_set_id, x_res, layer_sets_csv, parameters_file,
           out_filenamepath, factor, input_layers_folder, rank)
