import argparse
import pyvista as pv
import numpy as np

# https://www.grs.de/sites/default/files/publications/grs-568.pdf
# Resus GRS 568 (flach lagernd Salz) S. 131:
# Flächenbedarf von 2,7 km2, maximale Temperatur von 97 °C.
# https://www.grs.de/sites/default/files/publications/grs-569.pdf
# Resus GRS 569 (steil lagernd Salz) S. 126:
# Flächenbedarf von 1,2 km2, maximale Temperatur von 92 °C.
# https://www.grs.de/sites/default/files/publications/grs-570.pdf
# Resus GRS 570 (flach lagernd Salz, höhere Auslegungstemperatur) S. 130:
# Flächenbedarf von 1,8 km2, maximale Temperatur von 146 °C.

# https://www.grs.de/sites/default/files/publications/grs-571.pdf
# Resus GRS 571 (geringe Mächtigkeit Ton) S. 124:
# Flächenbedarf von 8,7 km2, maximale Temperatur von 87 °C.
# https://www.grs.de/sites/default/files/publications/grs-572.pdf
# Resus GRS 572 (große Mächtigkeit Ton) S. 125:
# Flächenbedarf von 9,1 km2, maximale Temperatur von 73,5 °C.
# https://www.grs.de/sites/default/files/publications/grs-573.pdf
# Resus GRS 573 (geringe Mächtigkeit Ton, hohe Auslegungstemperatur) S. 122:
# Flächenbedarf von 6 km2, maximale Temperatur von 100,5 °C.

def _cell_point_ids(grid):
    """Return the point ids of every cell of *grid* as a 2-D integer array.

    Reads the legacy VTK connectivity layout exposed by ``grid.cells``
    (``[n0, id0_0, ..., id0_n, n1, id1_0, ...]``) in one vectorised step
    instead of creating one Cell object per cell, which is expensive.
    NOTE(review): indexing ``grid.cell[i]`` also depends on the PyVista
    version (newer releases appear to return a generator) -- confirm.

    All cells must have the same number of points; otherwise numpy raises
    a ValueError on the reshape.
    """
    if grid.n_cells == 0:
        return np.empty((0, 0), dtype=np.int64)
    return np.asarray(grid.cells).reshape(grid.n_cells, -1)[:, 1:]


def _crop_axis(grid, axis, bounds):
    """Return the cells of *grid* whose points all lie in ``bounds``.

    A point is inside when ``bounds[0] <= coordinate <= bounds[1]`` along
    *axis* (0 = x, 1 = y, 2 = z). A cell is kept only if all its points are.
    """
    coords = grid.points[:, axis]
    inside = (coords >= bounds[0]) & (coords <= bounds[1])
    return grid.extract_cells(inside[_cell_point_ids(grid)].all(axis=1))


def _near_horizontal(deviation_deg, tol_deg=1.0):
    """Return a boolean mask of entries deviating at most *tol_deg* degrees.

    If no entry is within the tolerance, every entry is selected instead.
    """
    mask = np.asarray(deviation_deg) <= tol_deg
    if not mask.any():
        mask = np.ones_like(mask)
    return mask


def crop_repo(dim, in_vtu, model, out_vtu):
    """Crop the repository cells out of a mesh, print their size and save them.

    The crop window is a square of side ``sqrt(A)`` (``A`` = the model's
    footprint area in km^2, converted to m). It is centred on the mesh's
    x extent and, in 3D, also on its y extent. Only cells whose points all
    lie inside the window are kept. Of those, near-horizontal cells
    (within 1 degree) are extracted:

    * ``dim == 2``: cells whose first-to-second point direction is within
      1 degree of the x-axis and whose first point lies below the mean y
      of all cropped cells' first points. The printed size is the squared
      total "Length" of these cells (presumably a square footprint is
      assumed).
    * ``dim == 3``: surface cells whose normal is within 1 degree of the
      z-axis. The printed size is their total "Area".

    If no cell is within the 1-degree tolerance, the angle criterion is
    dropped and all candidate cells are used.

    Parameters
    ----------
    dim : int
        Model dimension, 2 or 3.
    in_vtu : str
        Path of the input ``.vtu`` file.
    model : str
        One of ``"Salz-flach"``, ``"Salz-Kissen"``, ``"Ton-Nord"``,
        ``"Ton-Sued"``, ``"Ton-Kristallin"``.
    out_vtu : str
        Path the extracted cells are saved to.

    Raises
    ------
    ValueError
        If ``dim`` is not 2 or 3, or no cell lies inside the crop window.
    KeyError
        If ``model`` is unknown.
    """
    # Validate before doing any (possibly slow) file I/O.
    if dim not in (2, 3):
        raise ValueError(f"dim must be 2 or 3, got {dim!r}")

    # repo edge length
    SK_len = (1.2**0.5)*1000  # 1095.4m x 1095.4m -> 1088 = 2^6 x 17
    Sf_len = (1.8**0.5)*1000  # 1341.6m x 1341.6m -> 1328 = 2^6 x 20.75
    TS_len = (8.7**0.5)*1000  # 2949.6m x 2949.6m -> 2944 = 2^7 x 23
    TN_len = (9.1**0.5)*1000  # 3016.6m x 3016.6m -> 3008 = 2^6 x 47
    TK_len = (4.0**0.5)*1000  # 2000m x 2000m -> 2000 = 2^4 x 5^3

    repo_edge_length = {"Salz-flach": Sf_len, "Salz-Kissen": SK_len,
                        "Ton-Nord": TN_len, "Ton-Sued": TS_len,
                        "Ton-Kristallin": TK_len}
    if model not in repo_edge_length:
        raise KeyError(f"unknown model {model!r}; expected one of "
                       f"{sorted(repo_edge_length)}")

    vtu: pv.UnstructuredGrid = pv.XMLUnstructuredGridReader(in_vtu).read()
    model_bounds = vtu.bounds

    # Square window of the repository edge length, centred on the model.
    drepo_crop = np.array([-1., 1.]) * repo_edge_length[model] / 2.
    repo_x_bounds = np.mean(model_bounds[0:2]) + drepo_crop
    repo_y_bounds = np.mean(model_bounds[2:4]) + drepo_crop

    vtu = _crop_axis(vtu, 0, repo_x_bounds)
    if dim == 3:
        vtu = _crop_axis(vtu, 1, repo_y_bounds)
    if vtu.n_cells == 0:
        raise ValueError(f"no cells of {in_vtu!r} lie inside the "
                         f"repository window of model {model!r}")

    if dim == 2:
        # Point coordinates per cell, shape (n_cells, n_pts_per_cell, 3).
        cell_points = vtu.points[_cell_point_ids(vtu)]
        # Each cell's direction depends only on its own points, so the
        # angles stay aligned with the cell indices used for extraction.
        vectors = cell_points[:, 1] - cell_points[:, 0]
        direction = np.abs(np.degrees(np.arctan2(vectors[:, 1],
                                                 vectors[:, 0])))
        # |angle| is in [0, 180]; fold to the deviation from the x-axis.
        deviation = np.minimum(direction, 180. - direction)
        first_y = cell_points[:, 0, 1]
        keep = _near_horizontal(deviation) & (first_y < np.mean(first_y))
        vtu = vtu.extract_cells(keep)
        size = np.sum(vtu.compute_cell_sizes().cell_data["Length"])**2

    else:
        surface = vtu.extract_surface()
        normals = np.asarray(surface.cell_normals)
        # Clip: rounding can push |n_z| slightly above 1 -> arccos NaN.
        tilt = np.rad2deg(np.arccos(np.clip(normals[:, 2], -1., 1.)))
        deviation = np.minimum(tilt, 180. - tilt)
        keep = _near_horizontal(deviation)
        # Map surface cells back to the grid's cells; fall back to the
        # identity when the filter did not record the original ids.
        if "vtkOriginalCellIds" in surface.cell_data.keys():
            cell_ids = np.asarray(surface.cell_data["vtkOriginalCellIds"])
        else:
            cell_ids = np.arange(surface.n_cells)
        vtu = vtu.extract_cells(np.unique(cell_ids[keep]))
        size = np.sum(vtu.compute_cell_sizes().cell_data["Area"])

    print("repo size = " + str(size))

    vtu.save(out_vtu)


if __name__ == '__main__':
    # Command-line entry: crop the repository out of in_vtu and save it.
    parser = argparse.ArgumentParser()
    # Only 2-D and 3-D models are supported by crop_repo; reject anything
    # else here instead of failing after the mesh has been read.
    parser.add_argument("dim", help="dimension", type=int, choices=[2, 3])
    parser.add_argument("in_vtu", help="input_vtu")
    parser.add_argument("model", help="model name")
    parser.add_argument("out_vtu", help="output_vtu")
    args = parser.parse_args()

    crop_repo(args.dim, args.in_vtu, args.model, args.out_vtu)
