import argparse
from glaciationBCs import constants_AREHS as ac
from ogs_tools import plot_tools as tools
from ogs_tools.OGS_param import OGS_param
import pyvista as pv
import matplotlib.pyplot as plt
import numpy as np
import vtuIO
from standard_plots import *
# Shrink legend text for every figure created after this module is imported.
plt.rcParams.update({'legend.fontsize': 'x-small'})


def get_mat_ids(pvd_path: str, ids: list = None) -> list:
    """Return the distinct material IDs of the first mesh in a PVD series.

    The IDs come from the ``MaterialIDs`` cell array of the first timestep
    and are returned in ascending order. Callers select entries by position
    (e.g. ``-1`` for the largest ID), so a deterministic order is required.

    :param pvd_path: path to the ``.pvd`` collection file.
    :param ids: optional positions into the sorted list of unique IDs;
        if ``None`` all unique IDs are returned.
    :return: list of material IDs as Python ``int``.
    """
    mesh = pv.PVDReader(pvd_path).read()[0]
    # np.unique sorts; list(set(...)) had no guaranteed iteration order.
    mat_ids = np.unique(mesh.cell_data["MaterialIDs"]).tolist()
    if ids is None:
        return mat_ids
    return [mat_ids[idx] for idx in ids]


def get_obs_pts(pvd_path: str, v_id: int, ids: list, v_pos: str = "mid",
                tol: float = 1.0) -> list:
    """Locate one observation point inside each material group.

    The first mesh of the PVD series is loaded. The mesh point closest to the
    mesh's bounding-box centre is used as a reference point. For every
    material ID in *ids*, the cells with that ``MaterialIDs`` value are
    extracted. They are then narrowed, one axis at a time for axes
    ``0 .. v_id - 1``, to the points within ``tol`` of the reference
    coordinate (``extract_points`` is called with its default options). The
    observation point is the bounding-box centre of what remains.

    :param pvd_path: path to the ``.pvd`` collection file.
    :param v_id: index of the vertical axis (callers pass ``dim - 1``).
    :param ids: material IDs in which to place observation points.
    :param v_pos: ``"mid"`` keeps the centre height; ``"top"`` / ``"bottom"``
        moves the point to the max / min vertical bound of the subset.
    :param tol: half-width of the horizontal search window, in the mesh's
        coordinate units (default 1.0).
    :return: list of coordinate lists, one per entry of *ids*.
    :raises KeyError: if *v_pos* is none of the three values above.
    """
    mesh: pv.UnstructuredGrid = pv.PVDReader(pvd_path).read()[0]
    p_ref = mesh.points[mesh.find_closest_point(mesh.center)]
    # Resolve the bounds offset once (xmin, xmax, ymin, ymax, zmin, zmax layout).
    v_off = None if v_pos == "mid" else {"top": 1, "bottom": 0}[v_pos]
    pts = []
    for mat_id in ids:
        mesh_i = mesh.threshold((mat_id, mat_id), "MaterialIDs")
        for j in range(v_id):
            coord = mesh_i.points[:, j]
            in_window = (coord >= p_ref[j] - tol) & (coord <= p_ref[j] + tol)
            mesh_i = mesh_i.extract_points(np.flatnonzero(in_window))
        # Copy: depending on the pyvista version ``center`` may be a tuple,
        # which cannot be modified in place.
        pt = list(mesh_i.center)
        if v_off is not None:
            pt[v_id] = mesh_i.bounds[v_id * 2 + v_off]
        pts.append(pt)
    return pts


def get_timesteps(pvd_path: str, dim: int):
    """Return the timesteps of a PVD collection.

    :param pvd_path: path to the ``.pvd`` file.
    :param dim: spatial dimension of the model, forwarded to ``vtuIO.PVDIO``.
    :return: the ``timesteps`` attribute of a ``vtuIO.PVDIO`` reader
        created with the VTK interpolation backend.
    """
    reader = vtuIO.PVDIO(pvd_path, dim=dim, interpolation_backend="vtk")
    timesteps = reader.timesteps
    return timesteps


def get_obs_pts_timeseries(pvd_path: str, dim: int, param_funcs: list[str],
                           obs_pts: list) -> dict:
    """Sample transformed parameter time series at observation points.

    Each entry of *param_funcs* is parsed with ``OGS_param``. The point-data
    array ``param.name`` is read at every observation point for all
    timesteps. It is passed through ``tools.apply_function`` (with the
    parameter's ``func``, the mesh's maximum vertical coordinate and the
    point's vertical coordinate), then through ``param.transform``.

    :param pvd_path: path to the ``.pvd`` collection.
    :param dim: spatial dimension; axis ``dim - 1`` is treated as vertical.
    :param param_funcs: parameter/function strings understood by ``OGS_param``.
    :param obs_pts: observation point coordinates (e.g. from ``get_obs_pts``).
    :return: ``{"pt<j>": {param_func: values}}``, or ``None`` if a requested
        array is missing from the first timestep's point data. A message is
        printed for each missing name.
    """
    mesh = pv.PVDReader(pvd_path).read()[0]
    params = [OGS_param(p_f) for p_f in param_funcs]
    # Report every missing array (deduplicated) so all problems show in one run.
    missing = list(dict.fromkeys(
        p.name for p in params if p.name not in mesh.point_data.keys()))
    for name in missing:
        print(f"no data named {name} available")
    if missing:
        return None

    obs_pts_dict = {f'pt{j}': point for j, point in enumerate(obs_pts)}
    if not params:
        # Nothing requested; avoids indexing params[0] below.
        return {pt_key: {} for pt_key in obs_pts_dict}

    v_max = np.max(mesh.points[:, dim - 1])
    pvdio = vtuIO.PVDIO(pvd_path, dim=dim, interpolation_backend="vtk")
    # dict.fromkeys: unique names in a deterministic order.
    var_names = list(dict.fromkeys(p.name for p in params))
    raw_data = pvdio.read_time_series(var_names, obs_pts_dict)
    data = {k: {} for k in raw_data.keys()}

    for pt_key, point in obs_pts_dict.items():
        n_steps = len(raw_data[pt_key][params[0].name])
        # The point's vertical coordinate, repeated once per timestep.
        pts_v = np.repeat(point[dim - 1], n_steps)
        for p, param_func in zip(params, param_funcs):
            data[pt_key][param_func] = p.transform(
                tools.apply_function(raw_data[pt_key][p.name], p.func, v_max, pts_v))

    return data


def plot_glacial_phases(ax: plt.Axes, t_scale: float = 1.):
    """Mark the glacial phase times ``ac.t_`` as dashed vertical lines on *ax*.

    Only times whose scaled value lies left of the current right x-limit are
    drawn. Each line spans the y-limits that were current before drawing.

    :param ax: axes to draw on.
    :param t_scale: factor converting ``ac.t_`` values to the x-axis unit.
    """
    y_lims = ax.get_ylim()
    t_max = ax.get_xlim()[1] / t_scale
    for t_p in ac.t_:
        if t_p < t_max:
            ax.vlines(t_p * t_scale, *y_lims, colors="k", linestyles="--",
                      alpha=0.2)
    # Apply to *ax* itself; plt.margins would act on pyplot's current axes,
    # which need not be *ax*. The slightly negative y-margin presumably keeps
    # the lines flush with the frame.
    ax.margins(None, -0.001)


def arehs_obs_pts(pvd: str, dim: int, obs_layer_ids: list) -> list:
    """Return the AREHS observation points, surface point first.

    The first point sits at the top of the material at position 0 of
    ``get_mat_ids``. It is followed by one point at the bottom of each layer
    in *obs_layer_ids*. Axis ``dim - 1`` is the vertical axis.
    """
    vertical_axis = dim - 1
    layer_pts = get_obs_pts(pvd, vertical_axis, obs_layer_ids, "bottom")
    surface_pts = get_obs_pts(pvd, vertical_axis, get_mat_ids(pvd, [0]), "top")
    return [*surface_pts, *layer_pts]


def arehs_obs_layer_ids(pvd: str, layers: dict, ids=(0, 1, 2, -1)) -> list:
    """Select the material IDs at which AREHS observation points are placed.

    Starts from the material IDs at positions *ids* of ``get_mat_ids`` (by
    default the first three and the last). It then adds every layer whose
    name in *layers* is ``"repo"`` or ``"krb"``.

    :param pvd: path to the ``.pvd`` collection.
    :param layers: mapping of material ID -> layer name.
    :param ids: positions into the material-ID list. The default is an
        immutable tuple rather than a shared mutable list.
    :return: unique material IDs in ascending order. Duplicates (e.g. when
        a mesh has fewer than four materials) are removed.
    """
    selected = set(get_mat_ids(pvd, ids))
    selected.update(layer_id for layer_id, name in layers.items()
                    if name in ("repo", "krb"))
    # Sorted so point and label order are deterministic and follow material ID.
    return sorted(selected)


def arehs_obs_pts_timeseries(pvd: str, model: str, csv: str, dim: int,
                             param_funcs: list[str]) -> dict:
    """Read parameter time series at the AREHS observation points.

    The layer names for *model* are read from the parameter table *csv*.
    They are used to choose the observation layers, and the resulting points
    are sampled via ``get_obs_pts_timeseries``.
    """
    layer_ids = arehs_obs_layer_ids(pvd, tools.layer_names(model, csv))
    return get_obs_pts_timeseries(
        pvd, dim, param_funcs, arehs_obs_pts(pvd, dim, layer_ids))


def arehs_obs_pt_labels(model: str, csv: str, pvd: str, dim: int) -> list:
    """Build one legend label per AREHS observation point.

    Each label has the form ``"<height>m ü. NN (<layer>)"``. The height is
    the point's coordinate on axis ``dim - 1``, rounded to whole units. The
    first label uses the layer name ``"top"``; the rest use the names of the
    observation layers.
    """
    layer_names = tools.layer_names(model, csv)
    layer_ids = arehs_obs_layer_ids(pvd, layer_names)
    points = arehs_obs_pts(pvd, dim, layer_ids)
    heights = np.around([pt[dim - 1] for pt in points], 0)
    names = ["top", *(layer_names[lid] for lid in layer_ids)]
    return [f"{int(h)}m ü. NN ({name})" for h, name in zip(heights, names)]


def plot(ts: np.ndarray, pts_timeseries: dict, param_func: str,
         obs_pt_keys: list[str], labels: list[str] = None, **kwargs):
    """Plot one parameter's time series for several observation points.

    :param ts: time values; truncated to each series' length.
    :param pts_timeseries: nested mapping ``{pt_key: {param_func: values}}``.
    :param param_func: key of the series to plot; also used for the y label.
    :param obs_pt_keys: observation point keys to plot (any sized iterable).
    :param labels: legend labels, one per key; defaults to the keys.
    :param kwargs: forwarded to ``Axes.plot``.
    :return: ``(fig, ax)``.
    :raises ValueError: if *labels* and *obs_pt_keys* differ in length.
    """
    obs_pt_keys = list(obs_pt_keys)
    if labels is None:
        # zip(keys, None) would raise TypeError; label lines by their keys.
        labels = [str(key) for key in obs_pt_keys]
    elif len(labels) != len(obs_pt_keys):
        raise ValueError("labels and obs_pt_keys have to have the same length " +
            f"but are {len(labels)} and {len(obs_pt_keys)}")
    fig, ax = plt.subplots(dpi=200, figsize=[7, 2.5], facecolor='w')
    # The y label depends only on param_func, so it is set once.
    ax.set_ylabel(OGS_param(param_func).get_str("\n", True))
    for pt_key, label in zip(obs_pt_keys, labels):
        y = pts_timeseries[pt_key][param_func]
        ax.plot(ts[:len(y)], y, label=label, **kwargs)
    ax.legend()
    return fig, ax


def plot_arehs(pvd: str, dim: int, obs_pts_timeseries: dict, param_func: str,
               obs_pt_keys: list[str], labels: list[str] = None, **kwargs):
    """Plot AREHS observation time series with time in units of 1000 years.

    Timesteps are read from *pvd*, divided by ``ac.s_a`` and by 1000, and
    passed to ``plot``. Glacial phase markers are added when the PVD path
    contains ``"glacialcycle"``.
    """
    kyr = 1000
    t_scale = 1. / ac.s_a / kyr
    fig, ax = plot(get_timesteps(pvd, dim) * t_scale, obs_pts_timeseries,
                   param_func, obs_pt_keys, labels, **kwargs)
    years_per_unit = 1. / (t_scale * ac.s_a)
    ax.set_xlabel(f'Zeit / {years_per_unit:.0f} Jahre')
    if "glacialcycle" not in pvd:
        return fig, ax
    plot_glacial_phases(ax, t_scale)
    return fig, ax



if __name__ == '__main__':
    # Generates plots for a specified parameter: samples its time series at
    # the AREHS observation points of a PVD and saves the figure.
    parser = argparse.ArgumentParser(
        description="Generates plots for a specified parameter")
    parser.add_argument("out_file", help="plot output file")
    parser.add_argument("pvd", type=str, help="full path to pvd")
    parser.add_argument("dim", type=int, help="dimension of model")
    parser.add_argument("model", type=str, help="e.g. Ton-Nord, Salz-Kissen")
    parser.add_argument("param_func", type=str, help="parameter_function")
    parser.add_argument("csv", type=str, help="full path to parameter table")
    args = parser.parse_args()

    # args.dim is already an int (type=int above).
    obs_pts_timeseries = arehs_obs_pts_timeseries(
        args.pvd, args.model, args.csv, args.dim, [args.param_func])
    if obs_pts_timeseries is None:
        # The missing array name has already been printed; exit cleanly
        # instead of crashing on None below.
        parser.exit(1, f"cannot plot {args.param_func}: required data not "
                       f"found in {args.pvd}\n")
    labels = arehs_obs_pt_labels(args.model, args.csv, args.pvd, args.dim)
    fig, ax = plot_arehs(args.pvd, args.dim, obs_pts_timeseries,
                         args.param_func, list(obs_pts_timeseries), labels)
    tools.save_plot(fig, args.out_file)