import io
from typing import Optional

import numpy as np
import pandas as pd
import pyvista as pv
from PIL import Image


def closest_values(actual, vals, factor):
    """Snap each scaled target value to its nearest entry in *actual*.

    For every ``target`` in *vals*, the entry of *actual* with the smallest
    absolute difference to ``target * factor`` is selected (the first one on
    ties). The selected entries are returned de-duplicated and sorted, as
    produced by ``np.unique``.

    NOTE(review): *actual* is presumably a 1-D NumPy array; it must support
    element-wise subtraction of a scalar and integer indexing.
    """
    picked = []
    for target in vals:
        distance = np.abs(actual - target * factor)
        picked.append(actual[distance.argmin()])
    return np.unique(picked)


def apply_function(p: np.ndarray, func: Optional[str], v_max: float,
                   pts_v: np.ndarray) -> np.ndarray:
    """Derive a per-point field from point data *p* according to *func*.

    Parameters
    ----------
    p : np.ndarray
        Point data, either 1-D (one value per point) or 2-D (one row per
        point, one column per component).
    func : str or None
        Transformation to apply:

        * ``None`` -- row-wise Euclidean norm for 2-D data, *p* unchanged
          otherwise.
        * a string of digits such as ``"2"`` -- that column of *p*.
        * ``"von-Mises-stress"`` -- von Mises equivalent stress from columns
          0-2 (normal components) plus column 3 when *p* has 4 columns, or
          columns 3-5 otherwise (shear components). NOTE(review): the exact
          component ordering is assumed, not checked here.
        * ``"effective-pressure"`` -- ``-(p0 + p1 + p2) / 3``.
        * ``"qp-ratio"`` -- von Mises stress divided by effective pressure.
        * ``"dynamic"`` -- *p* minus a hydrostatic part and a constant 1e5.
        * ``"hydraulic-head"`` -- ``(p - 1e5) / (1000 * 9.81) + (pts_v - v_max)``.
        * ``"trace"`` -- ``p0 + p1 + p2``.
        * ``"log"`` -- base-10 log of the row norm (2-D) or of *p* (1-D);
          non-positive 1-D entries yield NaN.

        For ``"dynamic"`` and ``"hydraulic-head"``, an all-zero *p* is
        returned as zeros.
    v_max : float
        Reference (top) vertical coordinate used by ``"dynamic"`` and
        ``"hydraulic-head"``.
    pts_v : np.ndarray
        Vertical coordinate of every point, aligned with the rows of *p*.

    Returns
    -------
    np.ndarray
        The derived per-point values.

    Raises
    ------
    ValueError
        If *func* is not one of the recognized names.
    """
    if func is None:
        if p.ndim == 2:
            return np.linalg.norm(p, axis=1)
        return p
    if func.isdigit():
        return p[:, int(func)]
    if func == "von-Mises-stress":
        normal = 0.5 * (np.square(p[:, 0] - p[:, 1])
                        + np.square(p[:, 1] - p[:, 2])
                        + np.square(p[:, 2] - p[:, 0]))
        # 4 columns carry a single shear component; otherwise columns 3-5
        # are all shear terms (an IndexError is raised if they are missing).
        shear = np.square(p[:, 3])
        if p.shape[1] != 4:
            shear = shear + np.square(p[:, 4]) + np.square(p[:, 5])
        return np.sqrt(normal + 3 * shear)
    if func == "effective-pressure":
        return -(1. / 3.) * (p[:, 0] + p[:, 1] + p[:, 2])
    if func == "qp-ratio":
        return (apply_function(p, "von-Mises-stress", v_max, pts_v)
                / apply_function(p, "effective-pressure", v_max, pts_v))
    if func == "dynamic":
        if np.all(p == 0):
            return p * 0.
        p_0 = 1e5  # NOTE(review): presumably atmospheric pressure in Pa
        beta_pw = 5e-10  # NOTE(review): presumably water compressibility, 1/Pa
        # Hydrostatic pressure of a compressible column of depth
        # (v_max - pts_v); 1000 and 9.81 presumably density and gravity (SI).
        p_hs = (np.exp(1000. * 9.81 * beta_pw * (v_max - pts_v)) - 1.) / beta_pw
        return p - p_hs - p_0
    if func == "hydraulic-head":
        if np.all(p == 0):
            return p * 0.
        p0 = 1e5
        return (p - p0) / (1000. * 9.81) + (pts_v - v_max)
    if func == "trace":
        return p[:, 0] + p[:, 1] + p[:, 2]
    if func == "log":
        if p.ndim == 2:
            return np.log10(np.linalg.norm(p, axis=1))
        # ``where=`` without ``out=`` leaves masked-out entries uninitialized
        # (arbitrary memory contents); pre-fill them with NaN instead.
        return np.log10(p, out=np.full(p.shape, np.nan), where=p > 0)
    raise ValueError(f"Invalid function type: {func}")


def apply_function_to_vtu(vtu: pv.UnstructuredGrid, param: str, func: str, dim: int) -> np.ndarray:
    """Evaluate ``apply_function`` on the point-data array *param* of *vtu*.

    Coordinate axis ``dim - 1`` of the mesh points is treated as the vertical
    axis. Its per-point values are passed as ``pts_v``, and its maximum over
    all points is passed as ``v_max``.
    """
    values = vtu.point_data[param]
    vertical = vtu.points[:, dim - 1]
    top = np.max(vertical)
    return apply_function(values, func, top, vertical)


def layer_names(model, csv_path):
    """Build a ``{layer_id: unit_name}`` mapping from a comma-separated table.

    The table must provide ``model_unit`` and ``layer_id`` columns. Each unit
    name is truncated at its first ``"-"``. For the ``"Ton-Nord"`` model, the
    first unit named ``"repo"`` is renamed to ``"krb"``. Duplicate layer ids
    keep the last name encountered.
    """
    table = pd.read_csv(csv_path, delimiter=',')
    raw_units = table["model_unit"].to_numpy()
    layer_ids = table["layer_id"].to_numpy()
    units = [unit.split("-")[0] for unit in raw_units]
    if model == "Ton-Nord" and "repo" in units:
        units[units.index("repo")] = "krb"
    return dict(zip(layer_ids, units))


def trim(im, margin=10):
    """Crop *im* to its non-white content, padded by *margin* pixels.

    The image is first converted to RGB, so any alpha channel is ignored and
    palette or greyscale images are handled. A pixel counts as content unless
    it is exactly (255, 255, 255). The padded bounding box is clamped to the
    image bounds, so no artificial border is added when content touches an
    edge. If the image is entirely white, it is returned unchanged.

    Parameters
    ----------
    im : PIL.Image.Image
        Image to crop.
    margin : int, optional
        Number of pixels kept around the content bounding box (default 10).

    Returns
    -------
    PIL.Image.Image
        The cropped image, or *im* itself when there is no content.
    """
    rgb = np.array(im.convert("RGB"))
    content = np.any(rgb != [255, 255, 255], axis=2)
    coords = np.argwhere(content)
    if coords.size == 0:
        # No non-white pixel: there is no bounding box to crop to.
        return im
    (y0, x0), (y1, x1) = coords.min(axis=0), coords.max(axis=0)
    height, width = content.shape
    # Clamp to the image: Image.crop pads out-of-range areas with zeros
    # (black / transparent) instead of failing.
    bbox = (
        int(max(x0 - margin, 0)),
        int(max(y0 - margin, 0)),
        int(min(x1 + 1 + margin, width)),
        int(min(y1 + 1 + margin, height)),
    )
    return im.crop(bbox)


def save_plot(fig, out_name):
    """Save *fig* to *out_name* with surrounding white space trimmed.

    The figure is rendered to an in-memory PNG and cropped with ``trim``;
    the result is then written to *out_name*. Pillow chooses the output
    format from the file extension.
    """
    with io.BytesIO() as buf:
        # Force a raster format: otherwise matplotlib uses
        # rcParams["savefig.format"], which may be one Pillow cannot open
        # (e.g. pdf or svg).
        fig.savefig(buf, format="png")
        buf.seek(0)
        # Save inside the context so the lazily loaded source image and its
        # buffer are still open, and both are closed afterwards.
        with Image.open(buf) as im:
            trim(im).save(out_name)