import argparse
from glaciationBCs import constants_AREHS as ac
from glaciationBCs import glacierclass_AREHS as glc
from SlicePlotter import SlicePlotter
from ogs_tools import plot_tools as tools
from ogs_tools.OGS_param import OGS_param
import pyvista as pv
import pandas as pd
import numpy as np

# Start a virtual X framebuffer (Xvfb) via PyVista so rendering works without
# a physical display. NOTE: this runs at module import time, so importing this
# module (e.g. to reuse create_plots) also triggers it.
pv.start_xvfb()

def DE_to_EN(s: str) -> str:
    """Return the German label for an English glacial-stage description.

    Despite the function name, the mapping goes from English to German.
    Names without a known translation are reported on stdout and returned
    unchanged.
    """
    translations = {
        "interglacial period":       "Interglaziale Phase",
        "permafrost development":    "Permafrost Entwicklung",
        "permafrost-only period":    "Permafrost Phase",
        "glacier advance":           "Gletschervorschub",
        "glacier dormancy":          "Gletscherstillstand",
        "glacier retreat":           "Gletscherrückgang",
    }
    try:
        return translations[s]
    except KeyError:
        print(f"German translation for {s} missing!")
        return s

def get_glacier(bounds, dim):
    """Create the glacier model spanning the horizontal extent of the mesh.

    The glacier axis is x for 2D models and y for 3D models. Its maximum
    length is ``ac.glacial_advance`` times the extent of the mesh along that
    axis. ``u_max`` (the upper bound of that axis) is passed to
    ``glc.glacier`` as well; presumably it marks where the glacier starts
    (TODO confirm against glacierclass_AREHS).

    Args:
        bounds: Mesh bounds ordered (x_min, x_max, y_min, y_max, ...).
        dim: Model dimension, 2 or 3.
    """
    # dim == 2 -> indices (0, 1) = x; dim == 3 -> indices (2, 3) = y
    lo_idx = 2 * dim - 4
    u_min, u_max = bounds[lo_idx], bounds[lo_idx + 1]
    max_length = ac.glacial_advance * (u_max - u_min)
    return glc.glacier(max_length, ac.H_max, u_max, ac.t_)


def plot_glacier(glacier: glc.glacier, i: int, Plotter: SlicePlotter,
                 origin: list, t: float, scaling = 0.2):
    """Draw the glacier at time ``t`` on top of the model domain in axis ``i``.

    The top contour of the model domain in the i-th slice is taken from the
    mesh feature edges. The glacier height from ``glacier.local_height``,
    multiplied by ``scaling``, is stacked on that contour and the band
    between them is filled in light grey. All coordinates are multiplied by
    1e-3 for plotting (presumably m -> km; TODO confirm axis units). Only
    y ticks <= 0 are kept afterwards.

    Args:
        glacier: Glacier model providing ``local_height(u, t)``.
        i: Index of the slice and of the figure axis to draw into.
        Plotter: SlicePlotter whose figure already holds the slice plots.
        origin: Point the slice plane passes through (3D only).
        t: Time passed to ``glacier.local_height``.
        scaling: Factor applied to glacier heights before stacking them on
            the domain contour; also used for the upper y limit.
    """
    ax = Plotter.fig.axes[i]
    normal_id = Plotter.ax_normal_id[i]
    if Plotter.dim == 3:
        mesh = Plotter.get_mesh(0).slice(Plotter.ax_normals[i], origin)
    else:
        mesh = Plotter.get_mesh(0).triangulate().extract_surface()
    XYZ = mesh.extract_feature_edges().points.T
    # Drop the slice-normal coordinate to get in-plane (horizontal, vertical)
    # pairs. Keep the highest vertical value per horizontal position; this is
    # the top contour of the model domain. (groupby sorts the keys, so x_vals
    # is ascending.) String/method aggregation avoids the pandas FutureWarning
    # for numpy callables passed to agg().
    df = pd.DataFrame(np.delete(XYZ, normal_id, axis=0).T)
    top = df.groupby(0)[1].max()
    x_vals = top.index.to_numpy()
    y_vals = top.to_numpy()
    # Glacier coordinate u for each profile point. In 3D the glacier runs
    # along y (axis 1). A slice normal to axis 1 therefore cuts across the
    # glacier, and every point sits at the same (mean) y of the slice.
    if normal_id == 1:
        u_vals = np.full(len(x_vals), np.mean(XYZ[1, :]))
    else:
        u_vals = x_vals
    heights = np.array([glacier.local_height(u, t) for u in u_vals])
    contour = y_vals + heights * scaling
    ax.set_ylim(top=1e-3 * (np.max(y_vals) + ac.H_max * scaling))
    ax.fill_between(1e-3 * x_vals, 1e-3 * y_vals, 1e-3 * contour,
                    facecolor="lightgrey")
    # Only label the domain part (y <= 0); the glacier band above is
    # drawn with a scaled height, so its ticks would be misleading.
    y_ticks = np.array(ax.get_yticks())
    ax.set_yticks(y_ticks[y_ticks <= 0])


def get_prefered_center(pvd_path: str, pref_vtu_path=None):
    """Return the centre point used to position the slices.

    If a (truthy) ``pref_vtu_path`` is given, the centre of that .vtu mesh
    is used. Otherwise the centre of the first block returned by reading
    ``pvd_path`` is used.
    """
    if not pref_vtu_path:
        first_block = pv.PVDReader(pvd_path).read()[0]
        return first_block.center
    vtu_mesh = pv.XMLUnstructuredGridReader(pref_vtu_path).read()
    return vtu_mesh.center


def create_plots(pvd_path: str, dim: int, param_func: str, 
                 ts: int, layers: dict = None, repo_path: str = None, 
                 out_png: str = None, fig_scale: float = 1.):
    """Plot slices of one OGS parameter at the output time closest to ``ts``.

    Slices are centred on the repository mesh if ``repo_path`` is given,
    otherwise on the model mesh. The glacier is overlaid on the first
    ``dim - 1`` axes. The repository outline is drawn with style "k--" on
    every axis.

    Args:
        pvd_path: Path to the OGS .pvd result file. If it contains
            "glacialcycle", the current glacial stage is shown as a title.
        dim: Model dimension, 2 or 3.
        param_func: Parameter specification handed to ``OGS_param``.
        ts: Requested time. It is matched against the PVD time steps using
            ``ac.s_a``, and the title shows it with unit "a" (presumably
            years; TODO confirm).
        layers: Layer names passed on to ``SlicePlotter``.
        repo_path: Optional repository .vtu. If given, it defines the slice
            centre and its contour is drawn.
        out_png: Optional output path. If given, the figure is saved there.
        fig_scale: Scaling factor passed to ``Plotter.plot_slices``.

    Returns:
        ``Plotter.fig``, or ``None`` if the parameter is not in the results.
    """
    Plotter = SlicePlotter(pvd_path, layers)
    param = OGS_param(param_func)
    # Check first so no further files are read when there is nothing to plot.
    if not Plotter.param_exists(param.name):
        print(f"Parameter {param.name} not available in {pvd_path}; "
              "no plots created.")
        return None
    ts_pvd = tools.closest_values(Plotter.pvd_ts, [ts], ac.s_a)[0]
    ts_str = f"{ts_pvd/ac.s_a:.0f}".zfill(6)
    center = get_prefered_center(pvd_path, repo_path)
    glacier = get_glacier(Plotter.get_mesh(0).bounds, dim)
    Plotter.plot_slices(center, param, ts_pvd, fig_scale=fig_scale)
    # The glacier is drawn on one axis in 2D and on the first two in 3D.
    for i in range(dim - 1):
        plot_glacier(glacier, i, Plotter, center, ts_pvd)
    if repo_path:
        for i in range(1 if dim == 2 else 3):
            # The third axis (3D only) uses slice_repo=False.
            Plotter.plot_contour(i, repo_path, center, i != 2, "k--", 3)

    axes = Plotter.fig.axes
    axes[0].set_title(f"t = {ts_str} a", loc='left', y=1.02)
    # In 3D the stage label goes on the second axis, in 2D on the only one.
    stage_ax = axes[1 if dim == 3 else 0]
    if "glacialcycle" in pvd_path:
        gp_str = DE_to_EN(glacier.tcr_h.stage_control(ts_pvd))
        stage_ax.set_title(gp_str, loc='right', y=1.02)
    else:
        stage_ax.set_title("init stage", loc='right', y=1.02)
    if out_png:
        tools.save_plot(Plotter.fig, out_png)
        print(f"Saved plots for {param.name} to {out_png}.")
    return Plotter.fig


if __name__ == '__main__':
    # Command-line entry point: generates plots for a specified parameter.
    parser = argparse.ArgumentParser(
        description="Generates plots for a specified parameter")
    parser.add_argument("out_png", help="full file path for resulting png")
    parser.add_argument("pvd", type=str, help="full path to pvd")
    # get_glacier's bounds-index arithmetic only supports 2D and 3D models.
    parser.add_argument("dim", type=int, choices=(2, 3),
                        help="dimension of model")
    parser.add_argument("model", type=str, help="e.g. Ton-Nord, Salz-Kissen")
    parser.add_argument("param_func", type=str, help="parameter_function")
    parser.add_argument("csv", type=str, help="full path to parameter table")
    parser.add_argument("ts", type=int, help="timestep to plot")
    parser.add_argument("repo", type=str, nargs='?', default=None, 
                        help="full path to repo vtu")
    args = parser.parse_args()

    layers = tools.layer_names(args.model, args.csv)

    create_plots(args.pvd, args.dim, args.param_func,
                 args.ts, layers, args.repo, args.out_png)
