import argparse
import pyvista as pv
import numpy as np


def _order_cells(cell_points: np.ndarray) -> list:
    """Walk the line cells head-to-tail and return the order they are visited in.

    ``cell_points[c, 0]`` / ``cell_points[c, 1]`` are the start / end
    coordinates of line cell ``c``. Starting at cell 0, each step moves to the
    cell whose start point exactly equals the current cell's end point. The
    returned list has ``n_cells + 1`` entries and ends with 0 again, so
    consecutive pairs cover every junction, including the one closing the loop.

    Raises:
        ValueError: if some cell has no successor, or the cells do not form a
            single closed loop visiting every cell once.
    """
    n_cells = len(cell_points)
    # Index start points once instead of scanning every cell at each step.
    # setdefault keeps the lowest cell id for duplicated start points, i.e.
    # the same "first match" an argmax-based search would pick.
    start_to_cell = {}
    for cell_id, start in enumerate(cell_points[:, 0]):
        start_to_cell.setdefault(tuple(start), cell_id)

    ordered = [0]
    current = 0
    for _ in range(n_cells):
        try:
            current = start_to_cell[tuple(cell_points[current, 1])]
        except KeyError:
            raise ValueError(
                f"Line cell {current} has no successor: the boundary is open "
                "or its lines are not consistently oriented.") from None
        ordered.append(current)

    if ordered[-1] != 0 or len(set(ordered[:-1])) != n_cells:
        raise ValueError("Line cells do not form a single closed loop.")
    return ordered


def _corner_positions(cell_points: np.ndarray, ordered_cell_ids: list,
                      threshold_angle: float) -> np.ndarray:
    """Return positions ``k`` of the walk where the boundary turns sharply.

    Position ``k`` is the junction between cells ``ordered_cell_ids[k]`` and
    ``ordered_cell_ids[k + 1]``; it is a corner when the heading of the line
    changes by more than ``threshold_angle`` degrees there. Only the x and y
    components of each line direction are used.
    """
    ordered = cell_points[ordered_cell_ids]
    directions = ordered[:, 1] - ordered[:, 0]
    headings = np.degrees(np.arctan2(directions[:, 1], directions[:, 0]))
    turns = np.abs(np.diff(headings))
    # arctan2 headings lie in [-180, 180], so a raw difference above 180 is
    # really the short way round (e.g. 170 -> -170 is a 20 degree turn).
    turns = np.where(turns > 180.0, 360.0 - turns, turns)
    return np.flatnonzero(turns > threshold_angle)


def create_subdomains(vtu: str, threshold_angle: float, out_path: str, suffix: str, subdomains: list):
    """Split a closed boundary made of line cells at its corners and save sides.

    The line cells are chained head-to-tail into one closed loop. A corner is
    a junction where consecutive lines change direction by more than
    ``threshold_angle`` degrees (measured in the xy plane). Every stretch
    between two consecutive corners becomes one segment, and segments are
    named by their centers: "top" has the largest center y, "bottom" the
    smallest y, "left" the smallest x and "right" the largest x.

    Each requested subdomain is written to
    ``out_path + subdomain + suffix + ".vtu"`` (plain string concatenation, so
    ``out_path`` must end with a path separator if it names a directory).

    Args:
        vtu: Path of the boundary mesh to read.
        threshold_angle: Minimum direction change, in degrees, for a corner.
        out_path: Prefix prepended to every output filename.
        suffix: Text inserted between the subdomain name and ".vtu".
        subdomains: Names to write; each one of "top", "bottom", "left",
            "right".

    Raises:
        ValueError: for an unknown subdomain name, a mesh with no cells or with
            cells that are not lines, a boundary that is not one consistently
            oriented closed loop, or when no corner exceeds the threshold.
    """
    boundary_names = ("top", "bottom", "left", "right")
    # Validate before any I/O so a typo cannot leave a partial set of files.
    unknown = [name for name in subdomains if name not in boundary_names]
    if unknown:
        raise ValueError(f"Unknown subdomains {unknown}; expected any of "
                         f"{list(boundary_names)}.")

    mesh = pv.read(vtu)

    celltypes = np.asarray(mesh.celltypes)
    if not np.all(celltypes == 3):
        raise ValueError("Mesh needs to only contain lines (id=3), but "
                         + f"contains ids {sorted(set(celltypes.tolist()))} instead.")

    n_cells = mesh.n_cells
    if n_cells == 0:
        raise ValueError("Mesh contains no cells.")

    # All cells are 2-point lines, so the legacy connectivity array reads
    # [2, p0, p1, 2, p0, p1, ...]; drop each cell's leading size entry.
    point_ids = np.asarray(mesh.cells).reshape(n_cells, 3)[:, 1:]
    cell_points = np.asarray(mesh.points)[point_ids]

    ordered_cell_ids = _order_cells(cell_points)
    corners = _corner_positions(cell_points, ordered_cell_ids, threshold_angle)
    if corners.size == 0:
        raise ValueError(
            f"No corners with an angle above {threshold_angle} degrees found.")
    # Close the cycle: the last segment runs from the last corner round to
    # the first one.
    corners = np.append(corners, corners[0] + n_cells)

    ordered = np.asarray(ordered_cell_ids, dtype=int)
    segments = []
    for start, stop in zip(corners[:-1], corners[1:]):
        # Corner k lies between ordered[k] and ordered[k + 1], so a segment
        # begins with the cell after its corner; % wraps around the loop.
        cell_ids = ordered[(np.arange(start, stop) + 1) % n_cells]
        segments.append(mesh.extract_cells(cell_ids))

    centers = np.array([segment.center for segment in segments])
    boundary_dict = {"top":     np.argmax(centers[:, 1]),
                     "bottom":  np.argmin(centers[:, 1]),
                     "left":    np.argmin(centers[:, 0]),
                     "right":   np.argmax(centers[:, 0])}
    for subdomain in subdomains:
        filename = out_path + subdomain + suffix + ".vtu"
        segments[boundary_dict[subdomain]].save(filename)


if __name__ == '__main__':
    # Command-line entry point: parse arguments and forward them unchanged to
    # create_subdomains.
    parser = argparse.ArgumentParser(
        description="Split a boundary mesh consisting of lines into the "
                    "subdomains top, bottom, left and right by detecting "
                    "corners in the mesh. The mesh has to lie in the xy plane.")
    parser.add_argument("vtu", help="boundary vtu consisting of line elements")
    parser.add_argument("ang", type=float,
                        help="minimum angle for corners, in degrees")
    # argparse only applies `default` to a positional argument when it is
    # optional, which requires nargs="?" (or "*").
    parser.add_argument("out_path", nargs="?", default="./",
                        help="output path prefix (default: ./)")
    parser.add_argument("suffix", nargs="?", default="",
                        help="output files suffix (default: empty)")
    parser.add_argument("subdomains", nargs="*",
                        default=["top", "bottom", "left", "right"],
                        help="subdomains to create, any of top, bottom, left, "
                             "right (default: all)")
    args = parser.parse_args()

    create_subdomains(args.vtu, args.ang, args.out_path, args.suffix, args.subdomains)