import argparse
import pathlib
import pyvista as pv
import numpy as np
import subprocess

from toy import model_layer_parameter, get_bounds
from os.path import exists
from sys import exit
from swap_coordinates import rotate


def intermediate_layers(surfaces, model_dir, xyres_str, num_sub_layers):
    """Raster the layer surfaces and create intermediate sub-layer rasters.

    Every surface mesh in ``surfaces`` is converted into an ``.asc`` raster in
    ``model_dir`` with ``Mesh2Raster``.  Consecutive rasters form
    (top, bottom) pairs; for every pair whose entry in ``num_sub_layers`` is
    greater than 1, ``createIntermediateRasters`` is asked for ``num - 1``
    additional rasters between top and bottom.

    Two text files are written to ``model_dir``:

    * ``rastered_layers.txt`` -- all rasters in stacking order, one per line.
    * ``repo_layers.txt`` -- the first raster whose path contains
      ``repo.asc`` plus the raster directly below it (empty if none).

    :param surfaces: surface meshes (``pathlib.Path``), ordered top to bottom.
    :param model_dir: output directory (``pathlib.Path``).
    :param xyres_str: raster cell size, passed to ``Mesh2Raster -c``.
    :param num_sub_layers: number of sub-layers for each (top, bottom) pair.
    :returns: ``(all_layers, rastered_layers_txt)`` -- the ordered raster
        paths and the path of ``rastered_layers.txt``.  Fewer than two
        surfaces yield an empty layer list.
    """
    rastered_surfaces = [model_dir / (surface.stem + '.asc') for surface in surfaces]
    for surface, rastered in zip(surfaces, rastered_surfaces):
        subprocess.call(['Mesh2Raster', '-i', surface,
                         '-o', rastered, '-c', str(xyres_str)])

    # (top raster, bottom raster, base name of intermediate rasters, #sub-layers)
    layers = [
        (top, bottom,
         bottom.with_name(top.stem + '-' + bottom.stem + '_layer.asc'),
         num_sublayer)
        for top, bottom, num_sublayer in
        zip(rastered_surfaces, rastered_surfaces[1:], num_sub_layers)
    ]

    # n sub-layers between two surfaces need n - 1 additional rasters.
    for top, bottom, layer, num_sublayer in layers:
        if num_sublayer > 1:
            subprocess.call(
                ['createIntermediateRasters', '--file1', top, '--file2', bottom,
                 '-o', layer, '-n', str(num_sublayer - 1)])

    all_layers = []
    for top, _bottom, layer, num in layers:
        all_layers.append(top)
        # NOTE(review): assumes createIntermediateRasters names its outputs
        # <stem>0.asc, <stem>1.asc, ... -- confirm against the OGS tool.
        all_layers.extend(layer.with_name(layer.stem + str(i) + '.asc')
                          for i in range(num - 1))
    # The bottom raster of the last pair closes the stack.  It is taken from
    # ``layers`` so that fewer than two surfaces produce an empty list
    # instead of an unbound-variable error.
    if layers:
        all_layers.append(layers[-1][1])

    rastered_layers_txt = model_dir / "rastered_layers.txt"
    with open(rastered_layers_txt, "w") as file:
        file.write("\n".join(str(item) for item in all_layers))

    # First raster containing "repo.asc" plus the raster directly below it
    # (presumably the repository host layer -- confirm with the data owner).
    repo_index = next(
        (idx for idx, item in enumerate(all_layers) if "repo.asc" in str(item)),
        None)
    repo_layers = [] if repo_index is None else all_layers[repo_index:repo_index + 2]

    repo_layers_txt = model_dir / "repo_layers.txt"
    with open(repo_layers_txt, "w") as file:
        file.write("\n".join(str(item) for item in repo_layers))

    return all_layers, rastered_layers_txt


def raster(dir_str, gml_str, xres):
    """Mesh the GML geometry into the 2D base mesh ``Raster.vtu``.

    Runs ``geometryToGmshGeo`` -> ``gmsh`` -> ``GMSH2OGS``.  The intermediate
    ``Rect.geo`` and ``Rect.msh`` are written next to the result.

    :param dir_str: output directory.
    :param gml_str: GML geometry file to mesh.
    :param xres: value for ``--mesh_density_scaling_at_points``.
    :returns: ``pathlib.Path`` of the generated ``Raster.vtu``.
    """
    out_dir = pathlib.Path(dir_str)
    gml = pathlib.Path(gml_str)
    geo = out_dir / "Rect.geo"
    msh = out_dir / "Rect.msh"
    vtu = out_dir / "Raster.vtu"

    subprocess.call(['geometryToGmshGeo', '-i', gml, '-o', geo,
                     '--mesh_density_scaling_at_points', str(xres)])
    # The algorithm option is '-algo'; gmsh treats non-option tokens as input
    # files, so without the dash the algorithm would never be selected.
    subprocess.call(['gmsh', geo, '-2', '-algo', 'meshadapt',
                     '-format', 'msh22', '-o', msh])
    subprocess.call(['GMSH2OGS', '-i', msh, '-o', vtu])
    return vtu


def layer_names(folder):
    """Build a row mapper that resolves a layer-set row to its surface file.

    The returned callable is intended for ``DataFrame.apply(..., axis=1)``.
    It formats ``<folder><layer_id:02d>_<model_unit>.vtu`` from the row and
    returns it as a ``pathlib.Path``.  ``folder`` is concatenated directly,
    so it should already end with a path separator.  If the file is missing,
    a message is printed and the program exits with status 1.
    """

    def resolve(row):
        candidate = f"{folder}{row['layer_id']:02d}_{row['model_unit']}.vtu"
        if not exists(candidate):
            print(candidate, 'does not exist.')
            exit(1)
        return pathlib.Path(candidate)

    return resolve

def normalize_resolution(xy_res, layers_dir_str):
    """Return ``xy_res`` divided by the model's extent in x.

    The extent comes from the bounds of ``00_KB.vtu``.  ``layers_dir_str`` is
    prepended verbatim, so it should end with a path separator.  The bounds
    and the extent are printed.  A warning is printed when the normalised
    value exceeds 0.16.
    """
    reference_bounds = get_bounds(f"{layers_dir_str}00_KB.vtu")
    print(reference_bounds)
    # assumes get_bounds returns per-axis (min, max) pairs, x first -- confirm
    x_bounds = reference_bounds[0]
    width = x_bounds[1] - x_bounds[0]
    print(width)
    ratio = xy_res / width
    if ratio > 0.16:
        print("Warning: normalized xy resolution should be lower than 0.16 but is: ", ratio)
    return ratio


def _renumber_material_ids(vtu_path, material_ids, z_num_slices):
    """Rewrite MaterialIDs in *vtu_path* with the layer set's material ids."""
    # One material entry per sub-layer, ordered top to bottom.
    materials_in_domain = [mat_id
                           for mat_id, num in zip(material_ids, z_num_slices)
                           for _ in range(num)]

    pv_mesh = pv.XMLUnstructuredGridReader(vtu_path).read()
    cell_ids = np.asarray(pv_mesh.cell_data["MaterialIDs"])
    # np.unique returns the distinct ids sorted, so the mapping does not
    # depend on set iteration order; `inverse` indexes each cell's id.
    unique_ids, inverse = np.unique(cell_ids, return_inverse=True)
    # Reversed because createLayeredMeshFromRasters starts numbering from the
    # bottom up, but we number the layers from top to bottom.
    targets = np.asarray(materials_in_domain[::-1])
    if len(targets) < len(unique_ids):
        raise ValueError(
            f"{vtu_path} has {len(unique_ids)} distinct MaterialIDs but the "
            f"layer set defines only {len(targets)} sub-layers")
    if len(targets) > len(unique_ids):
        print("Warning:", vtu_path, "has fewer distinct MaterialIDs than the",
              "layer set defines sub-layers; surplus entries are ignored")
    new_ids = targets[inverse]
    pv_mesh.cell_data["MaterialIDs"].setfield(new_ids, np.uint32)
    pv_mesh.save(vtu_path)


def create(layer_set_id, xy_res, layer_set_csv, parameters_file, layers_dir_str, gml_str,
           domain_mesh_file, dim):
    """Build the layered domain mesh for one layer set.

    Steps:
      1. Look up the layer set and resolve its surface ``.vtu`` files.
      2. Raster the surfaces and add intermediate sub-layer rasters.
      3. Mesh the GML geometry and run ``createLayeredMeshFromRasters``.
      4. Map the resulting ``MaterialIDs`` onto the layer set's
         ``material_id`` values.
      5. Post-process and write ``domain_mesh_file``:
         - ``dim == 2``: revise the mesh, run ``ExtractSurface``, rotate with
           ``"yzx"`` and reorder nodes.
         - Otherwise: revise the mesh only.

    :param layer_set_id: id of the layer set to build.
    :param xy_res: horizontal resolution; normalised by the model's x-extent
        before meshing the GML geometry.
    :param layer_set_csv: layer set table.
    :param parameters_file: parameter table.
    :param layers_dir_str: directory prefix of the surface ``.vtu`` files;
        used by plain string concatenation, so it must end with a separator.
    :param gml_str: GML geometry used as the base mesh.
    :param domain_mesh_file: output mesh; its parent directory also receives
        all intermediate files.
    :param dim: 2 for the 2D branch, anything else for 3D.
    :raises ValueError: if the layered mesh has more distinct MaterialIDs
        than the layer set defines sub-layers.
    """
    layer_sets_df = model_layer_parameter(layer_set_id, layer_set_csv, parameters_file)

    vtu_names = layer_sets_df.apply(layer_names(layers_dir_str), axis=1).to_numpy()
    # Row 0 only contributes the top surface.  Row k (k >= 1) supplies the
    # sub-layer count and material of the layer between surfaces k-1 and k.
    z_num_slices = layer_sets_df['resolution'].values.tolist()[1:]
    material_ids = layer_sets_df['material_id'].values.tolist()[1:]
    model_dir = pathlib.Path(domain_mesh_file).parent
    model_dir.mkdir(parents=True, exist_ok=True)

    # NOTE(review): an earlier comment recommended 300 here to match the
    # discretization of the GOCAD models, yet 30 is passed -- confirm.
    _, layers_txt = intermediate_layers(vtu_names, model_dir, 30, z_num_slices)
    xy_res_normalized = normalize_resolution(xy_res=xy_res, layers_dir_str=layers_dir_str)
    raster_vtu = raster(model_dir, gml_str, xy_res_normalized)
    domain_vtu_intermediate = model_dir / "domain_intermediate.vtu"

    subprocess.call(['createLayeredMeshFromRasters', '-i', raster_vtu,
                     '-r', layers_txt, '-o', domain_vtu_intermediate])
    _renumber_material_ids(domain_vtu_intermediate, material_ids, z_num_slices)

    subprocess.call(['checkMesh', domain_vtu_intermediate])
    if dim == 2:
        subprocess.call(['reviseMesh', '-i', domain_vtu_intermediate,
                         '-o', domain_vtu_intermediate])
        subprocess.call(['ExtractSurface', '-i', domain_vtu_intermediate,
                         '-o', domain_vtu_intermediate,
                         '-x', '-1', '-y', '0', '-z', '0', '-a', '1'])
        rotate(domain_vtu_intermediate, domain_vtu_intermediate, "yzx")
        subprocess.call(['NodeReordering', '-i', domain_vtu_intermediate,
                         '-o', domain_mesh_file])
    else:
        subprocess.call(['reviseMesh', '-i', domain_vtu_intermediate,
                         '-o', domain_mesh_file])
    

if __name__ == '__main__':
    # Command-line entry point; every argument is positional, in this order.
    print("started generation prism mesh")

    arg_parser = argparse.ArgumentParser()
    for arg_name, arg_options in (
            ("layer_set_csv", {"help": "layer set csv"}),
            ("parameters_file", {"help": "parameter table csv"}),
            ("layers_dir", {"help": "e.g. results/surface_data/Ton-Nord/"}),
            ("gml_file", {"help": "gml file for rastering"}),
            ("domain_mesh_file", {"help": "resulting vtu"}),
            ("layer_set_id", {"type": int}),
            ("x_res", {"type": float, "help": "horizontal resolution"}),
            ("dim", {"type": int}),
    ):
        arg_parser.add_argument(arg_name, **arg_options)
    cli_args = arg_parser.parse_args()

    # layers_dir is used as a plain string prefix downstream, hence the "/".
    create(layer_set_id=cli_args.layer_set_id,
           xy_res=cli_args.x_res,
           layer_set_csv=cli_args.layer_set_csv,
           parameters_file=cli_args.parameters_file,
           layers_dir_str=cli_args.layers_dir + "/",
           gml_str=cli_args.gml_file,
           domain_mesh_file=cli_args.domain_mesh_file,
           dim=cli_args.dim)
