"""
@author: florian.zill@ufz.de
"""

import pyvista as pv
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.tri as mpl_tri
import matplotlib.colors  # necessary for pyvista plot
import matplotlib.patheffects as pe
from scipy.interpolate import griddata
from ogs_tools import scaling
from ogs_tools import plot_tools as tools
from ogs_tools.OGS_param import OGS_param
from matplotlib.transforms import blended_transform_factory as btf
import matplotlib.ticker
from PIL import Image
from standard_plots import *


class SlicePlotter():
    """Plot cross sections of an OGS PVD time series with matplotlib.

    For 2D meshes a single view is drawn (``ax_normals == "z"``). For 3D meshes
    three orthogonal slices (normals "y", "x", "z") plus an isometric pyvista
    rendering ("i") are arranged in one figure.
    """

    # Axis labels (used in the x/y axis titles) per mesh dimension.
    labels = {2: {'x': 'Süd - Nord', 'y': 'Höhe über NN', 'z': ''},
              3: {'x': 'West - Ost', 'y': 'Süd - Nord', 'z': 'Höhe über NN'}}

    def __init__(self, pvd_path: str, layers: dict = None):
        """Open the PVD collection and derive the plot layout from its first mesh.

        :param pvd_path: path to the ``.pvd`` file.
        :param layers:   optional mapping MaterialID -> label; the labels are
                         written next to the material boundaries.
        """
        self.pvd = pv.PVDReader(pvd_path)
        self.pvd_ts = np.array(self.pvd.time_values)
        self.fig = None
        self.total_min, self.total_max = (np.inf, -np.inf)
        # Use the first stored time step: the series does not necessarily
        # contain t = 0, and requesting a missing time yields no dataset.
        self.dim = self.get_mesh(self.pvd_ts[0]).cell[0].dimension
        # One character per subplot: slice normal axis, or "i" for isometric.
        self.ax_normals = "z" if self.dim == 2 else "yxzi"
        ax_dict = {'x': 0, 'y': 1, 'z': 2, 'i': 3}
        self.ax_normal_id = [ax_dict[ax_n] for ax_n in self.ax_normals]
        scale_dict = {"t": "tight", "s": "scaled", "i": "image"}
        # matplotlib ``ax.axis`` mode per subplot index
        self.scale_type = [scale_dict[s] for s in "ttsi"]
        self.bilinear_cmap = False
        self.layers = layers
        pv.set_plot_theme("document")
        self.p = pv.Plotter()

    def get_value_range(self, time, param: OGS_param, cellData=False):
        """Return ``np.array([min, max])`` of ``param`` at ``time``.

        If the mesh has an ``<name>_active`` cell array ("pressure_active" for
        velocity), only the region where that array equals 1 is evaluated.
        With ``bilinear_cmap`` set, the range is made symmetric around zero.

        :param time:     a time value contained in the PVD collection.
        :param param:    parameter description (``name``, ``func``).
        :param cellData: convert cell data to point data before evaluating.
        """
        mesh = self.get_mesh(time)
        mask_param = param.name + "_active"
        if param.name == "velocity":
            mask_param = "pressure_active"
        if mask_param in mesh.cell_data.keys():
            mesh = mesh.cell_data_to_point_data()
            mesh = mesh.threshold(value=(1, 1), preference="point",
                                  scalars=mask_param, all_scalars=True)
        if cellData:
            mesh = mesh.cell_data_to_point_data()

        p_field = tools.apply_function_to_vtu(
            mesh, param.name, param.func, self.dim)
        p_min, p_max = (np.min(p_field), np.max(p_field))
        if self.bilinear_cmap:
            absmax = np.max(np.abs([p_min, p_max]))
            p_min, p_max = (-absmax, absmax)
        return np.array([p_min, p_max])

    def param_exists(self, param_name: str):
        """Return True if ``param_name`` is a point or cell array of the first mesh."""
        mesh = self.get_mesh(self.pvd_ts[0])
        if param_name in mesh.point_data.keys() or \
                param_name in mesh.cell_data.keys():
            return True
        print(f"no data named {param_name} available")
        return False

    def store_total_value_range(self, plot_cfg: OGS_param, t_arr=None, cellData=False) -> None:
        """Scan time steps and store the overall extrema in ``pt_min``/``pt_max``.

        Both are reset to ``(inf, -inf)`` first and stay so if the parameter
        does not exist or there are no time steps.

        :param plot_cfg: parameter description passed to ``get_value_range``.
        :param t_arr:    time values to scan; defaults to all stored ones.
        :param cellData: forwarded to ``get_value_range``.
        """
        self.pt_min, self.pt_max = (np.inf, -np.inf)
        # NOTE(review): existence is checked via ``plot_cfg.param`` while
        # ``get_value_range`` reads ``plot_cfg.name`` — confirm both match.
        if not self.param_exists(plot_cfg.param):
            return
        print("Analyzing total min and max value for " + plot_cfg.param)
        if t_arr is None:
            t_arr = self.pvd_ts
        n_steps = len(t_arr)
        if n_steps == 0:
            print("no time steps to analyze")
            return
        p_minmax = np.zeros((n_steps, 2))
        # avoid a division by zero in the progress output for a single step
        denom = max(n_steps - 1, 1)
        for i, t in enumerate(t_arr):
            print(f"{100. * i / denom:.1f} %", end='\r')
            p_minmax[i] = self.get_value_range(t, plot_cfg, cellData)
        min_id = np.argmin(p_minmax[:, 0])
        max_id = np.argmax(p_minmax[:, 1])
        self.pt_min = p_minmax[min_id, 0]
        self.pt_max = p_minmax[max_id, 1]
        print(f"total min at t = {t_arr[min_id]} of {self.pt_min}")
        print(f"total max at t = {t_arr[max_id]} of {self.pt_max}")
        print("\ndone!")

    def get_cmap(self, param: OGS_param, levels: np.ndarray):
        """Return the colormap for ``param``, resampled to ``len(levels) + 2`` colors."""
        cmaps = {"displacement": "Greens", "temperature": "plasma",
                 "pressure": "Blues", "displacement_rate": "PRGn"}
        if self.bilinear_cmap:
            cmaps = {"displacement": "PRGn"}
        cm = cmaps.get(param.name, "coolwarm")
        # pyplot.get_cmap is available in all matplotlib versions, unlike
        # matplotlib.cm.get_cmap, which was removed in 3.9.
        return plt.get_cmap(cm, len(levels) + 2)

    def get_mesh(self, t: float):
        """Return the first dataset of the PVD collection at time ``t``."""
        self.pvd.set_active_time_value(t)
        return self.pvd.read()[0]

    def mask_deactivated_subdomains(self, mesh: pv.UnstructuredGrid, param_name: str):
        """For pressure/velocity, keep only cells whose "pressure_active" equals 1.

        Other parameters, or meshes without (or with an empty) mask array, are
        returned unchanged.
        """
        mask_param = "pressure_active"
        if param_name in ["pressure", "velocity"] and \
                mask_param in mesh.cell_data.keys() and \
                len(mesh.cell_data[mask_param]):
            return mesh.ctp(True).threshold(value=[1, 1], scalars=mask_param)
        return mesh

    def plot_isometric(self, param: OGS_param, t: float, p_min: float, p_max: float):
        """Render a clipped isometric view of ``param`` at ``t`` off screen.

        A quarter box (x >= center, y <= center) is cut away and the
        MaterialID outlines are drawn in black. The z direction is scaled by
        a factor of 15.

        :returns: the rendered image as a PIL image, post-processed by
                  ``tools.trim`` (presumably cropping the empty border).
        """
        levels = scaling.get_levels(p_min, p_max, 11)
        cmap = self.get_cmap(param, levels)

        mesh = self.get_mesh(t)
        mesh = self.mask_deactivated_subdomains(mesh, param.name)
        mesh.point_data.active_scalars_name = param.name
        vtu_vals = tools.apply_function_to_vtu(mesh, param.name, param.func, self.dim)
        mesh.point_data[param.name] = param.transform(vtu_vals)
        mesh = mesh.scale([1.0, 1.0, 15.0], inplace=False)
        # add arg show_edges=True if you want to see the cell edges
        pv.start_xvfb()
        self.p = pv.Plotter(off_screen=True, notebook=False)
        _, x_max, y_min, y_max, z_min, z_max = mesh.bounds
        cx, cy, _ = mesh.center
        mesh_clipped = mesh.clip_box([cx, x_max, y_min, cy, z_min, z_max])
        self.p.add_mesh(mesh_clipped, cmap=cmap, clim=[levels[0], levels[-1]],
                        lighting=False, component=param.component)
        mesh_surf = mesh.extract_surface()
        for mat_id in np.unique(mesh.cell_data["MaterialIDs"]):
            # select exactly this material (a single value would keep all
            # ids >= mat_id), consistent with plot_layer_boundaries
            mesh_id = mesh_surf.threshold((mat_id, mat_id), "MaterialIDs")
            self.p.add_mesh(mesh_id.extract_feature_edges(), color="k")
        self.p.camera.azimuth += 270
        self.p.remove_scalar_bar()
        self.p.show()
        pv_im = self.p.image
        im = Image.fromarray(pv_im)
        im = tools.trim(im)
        return im

    def plot_contour(self, ax_id: int, vtu_path: str, origin: np.ndarray, slice: bool, style: str, lw: int):
        """Overlay the outline of the mesh in ``vtu_path`` on axis ``ax_id``.

        In 3D, ``slice=True`` cuts the mesh with the plane of that axis through
        ``origin``; otherwise its feature edges are used. Coordinates are
        converted from m to km.
        """
        vtu = pv.XMLUnstructuredGridReader(vtu_path).read()
        if self.dim == 2:
            contour_vtu = vtu.extract_surface().strip(join=True)
        else:
            if slice:
                vtu = vtu.slice(normal=self.ax_normals[ax_id], origin=origin)
            else:
                vtu = vtu.extract_feature_edges()
            contour_vtu = vtu.strip(join=True)

        x_id, y_id = np.delete([0, 1, 2], self.ax_normal_id[ax_id])
        # NOTE(review): lines[1:] skips only the first cell-size entry, so
        # this assumes strip(join=True) produced a single polyline.
        x, y = 1e-3*contour_vtu.points[contour_vtu.lines[1:]].T[[x_id, y_id]]
        self.fig.axes[ax_id].plot(x, y, style, lw=lw)

    def plot_layer_boundaries(self, ax_id: int, mesh: pv.UnstructuredGrid, origin: np.ndarray):
        """Draw the outline of every MaterialID region on axis ``ax_id``.

        In 3D the mesh is first sliced with the plane of that axis through
        ``origin``. If ``self.layers`` is set, each region whose id has an
        entry there is labelled at the left edge of the axis.
        """
        if self.dim == 3:
            cut = mesh.slice(self.ax_normals[ax_id], origin)
        else:
            cut = mesh.extract_surface()

        x_id, y_id = np.delete([0, 1, 2], self.ax_normal_id[ax_id])
        ax = self.fig.axes[ax_id]
        for mat_id in np.unique(cut.cell_data["MaterialIDs"]):
            m_i = cut.threshold((mat_id, mat_id), "MaterialIDs")
            edges = m_i.extract_feature_edges().strip(join=True, max_length=10000)
            # NOTE(review): the modulo also maps the cell-size entries of the
            # VTK lines array onto point ids — confirm this is intended.
            x_b, y_b = 1e-3*edges.points[edges.lines %
                                         edges.n_points].T[[x_id, y_id]]
            ax.plot(x_b, y_b, "-k", lw=0.5)

            if not self.layers:
                continue
            label = self.layers.get(mat_id)
            if label is None or len(x_b) == 0:
                continue
            # label height: mean y of the leftmost boundary points
            y_pos = np.mean(y_b[x_b == x_b.min()])
            outline = [pe.withStroke(linewidth=1, foreground='k')]
            ax.text(0.01, y_pos, label,
                    fontsize=plt.rcParams['font.size']*0.75,
                    transform=btf(ax.transAxes, ax.transData),
                    color="w", weight="bold", ha="left", va="center",
                    path_effects=outline)

    def plot_arrows(self, ax: plt.Axes, mesh: pv.UnstructuredGrid,
                    param: OGS_param, x_id: int, y_id: int, z_id: int):
        """Overlay streamlines of a vector-valued ``param`` on ``ax``.

        Only drawn when ``param.func`` is None or "log" and the field has
        ``self.dim`` components. Line width scales with the field magnitude;
        regions where the interpolated "pressure_active" mask is < 0.4 are
        left blank.
        """
        # NOTE(review): pv.StructuredGrid is built from three 1-D arrays —
        # confirm this yields the intended sampling points.
        x = np.linspace(mesh.bounds[x_id*2], mesh.bounds[x_id*2+1], 50)
        y = np.linspace(mesh.bounds[y_id*2], mesh.bounds[y_id*2+1], 50)
        z = np.linspace(mesh.bounds[z_id*2], mesh.bounds[z_id*2+1], 50)
        interp = mesh.interpolate(pv.StructuredGrid(x, y, z))
        p_field = interp.point_data[param.name]
        if "pressure_active" in interp.point_data.keys():
            mask = interp.point_data["pressure_active"]
        else:
            mask = np.ones(len(interp.points))
        x_grid, y_grid = np.meshgrid(x, y)
        if (param.func is None or param.func == "log") and \
                len(p_field.shape) == 2 and p_field.shape[1] == self.dim:
            pts = (interp.points[:, x_id], interp.points[:, y_id])
            val = param.transform(p_field.T[[x_id, y_id]]).T
            val_norm = np.linalg.norm(val, axis=1)
            u_grid = griddata(pts, val[:, 0], (x_grid, y_grid), method='linear')
            v_grid = griddata(pts, val[:, 1], (x_grid, y_grid), method='linear')
            val_grid = griddata(pts, val_norm, (x_grid, y_grid), method='linear')
            mask_grid = griddata(pts, mask, (x_grid, y_grid), method='cubic')
            # interpolation of mask cell_data to points is a bit tricky
            # cubic interpolation and a threshold of 0.4 gives good results
            u_grid[mask_grid < 0.4] = np.nan
            v_grid[mask_grid < 0.4] = np.nan
            lw = 2.5 * val_grid / np.max(val_norm)
            ax.streamplot(1e-3*x_grid, 1e-3*y_grid, u_grid, v_grid, color="k", linewidth=lw)
            # negative margins: otherwise streamplot shrinks the plot content
            ax.margins(-0.01, -0.01)

    def plot_slice(self, param: OGS_param, ax_id: int, t: float, origin, ax: plt.Axes, levels: np.ndarray, p_min: float, p_max: float):
        """Draw a filled contour plot of ``param`` at ``t`` on ``ax``.

        In 3D the mesh is sliced with the plane normal to ``ax_normals[ax_id]``
        through ``origin``; in 2D the whole surface is used. Coordinates are
        shown in km.

        :returns: the ``tricontourf`` mappable, or None if the slice has no
                  point data.
        """
        x_id, y_id = np.delete([0, 1, 2], self.ax_normal_id[ax_id])
        mesh_t = self.get_mesh(t)
        mesh_t = self.mask_deactivated_subdomains(mesh_t, param.name)

        if self.dim == 3:
            slice_tri = mesh_t.slice(self.ax_normals[ax_id], origin, True)
        else:
            slice_tri = mesh_t.triangulate().extract_surface()
        if not slice_tri.point_data:
            return None
        x, y = 1e-3*slice_tri.points.T[[x_id, y_id]]
        # triangle faces are stored as [3, i0, i1, i2]; drop the size entry
        tri = slice_tri.faces.reshape((-1, 4))[:, 1:]

        vtu_vals = tools.apply_function_to_vtu(slice_tri, param.name, param.func, self.dim)
        values = param.transform(vtu_vals)
        cm = ax.tricontourf(x, y, tri, values, levels=levels,
                            cmap=self.get_cmap(param, levels), extend="both")
        if param.name == "temperature" and np.min(values) < 0.:
            ax.tricontour(x, y, tri, values, levels=[0], colors="cyan")
        if self.bilinear_cmap and p_min < 0. and p_max > 0.:
            ax.tricontour(x, y, tri, values, levels=[0], colors="w")

        if y_id == 2 or self.dim == 2:
            self.plot_layer_boundaries(ax_id, mesh_t, origin)
        if self.dim == 3:
            # dashed lines mark where the other slices cut this one
            ax.hlines(1e-3*origin[y_id], np.min(x), np.max(x), "k", "--", alpha=0.5)
            ax.vlines(1e-3*origin[x_id], np.min(y), np.max(y), "k", "--", alpha=0.5)

        self.plot_arrows(ax, slice_tri, param, x_id, y_id, self.ax_normal_id[ax_id])

        ax.set_xlabel(f'{"xyz"[x_id]} ({self.labels[self.dim]["xyz"[x_id]]}) / km')
        ax.set_ylabel(f'{"xyz"[y_id]} ({self.labels[self.dim]["xyz"[y_id]]}) / km')

        return cm

    def create_ax(self, i: int, n_plots: int, scale_type: str):
        """Add subplot ``i`` of ``n_plots`` to ``self.fig`` and return it.

        Up to three plots share one row. With four plots the first two use a
        1x3 grid (positions 1, 2) and the last two a 2x3 grid (positions 3
        and 6), i.e. they are stacked in the right column.
        """
        rows = 1 if n_plots < 4 else (1, 1, 2, 2)[i]
        cols = min(n_plots, 3)
        idx = i + 1 if i < 3 else 6
        ax = self.fig.add_subplot(rows, cols, idx)
        ax.axis(scale_type)
        ax.autoscale()
        if self.dim == 3:
            ax.yaxis.set_visible(i != 1)
            if i == 2:
                ax.yaxis.tick_right()
                ax.yaxis.set_label_position("right")
        return ax

    def add_colorbar(self, c_map, levels: np.ndarray, param: OGS_param):
        """Add a shared colorbar on the left of all axes of ``self.fig``.

        If the order of magnitude of the last level is at most 3, only the
        inner levels get ticks and the outer limits are written as text.
        """
        # NOTE(review): levels[-1] == 0 gives log10(0) = -inf (with a
        # RuntimeWarning) and therefore takes the "base > 3" path.
        base = np.floor(abs(np.log10(abs(levels[-1]))))
        ticklevels = levels if base > 3 else levels[1:-1]
        cb = plt.colorbar(c_map, ax=self.fig.axes, ticks=ticklevels,
                          drawedges=True, location='left',
                          spacing='uniform', extendfrac=0)
        cb.ax.tick_params()
        if base <= 3:
            # "0. +" turns a -0.0 into 0.0 before formatting
            cb.ax.text(0.5, -0.01, f"{0.+levels[0]:.3g}",
                       transform=cb.ax.transAxes, va='top', ha='center')
            cb.ax.text(0.5, 1.005, f"{0.+levels[-1]:.3g}",
                       transform=cb.ax.transAxes, va='bottom', ha='center')
        cb.set_label(f'{param.get_str()} / {param.unit}')
        cb.ax.yaxis.set_major_formatter(
            matplotlib.ticker.ScalarFormatter(useMathText=True))

    def plot_slices(self, origin: np.ndarray, param: OGS_param, t: float,
                    scale_t=False, fig_scale: float = 1.):
        """Create ``self.fig`` with all views of ``param`` at time ``t``.

        :param origin:    point through which the 3D slices are cut.
        :param param:     parameter description; a purely numeric ``func``
                          switches to a zero-centered (bilinear) colormap.
        :param t:         time value to plot.
        :param scale_t:   use the range stored by ``store_total_value_range``
                          instead of the range at ``t``.
        :param fig_scale: scales font size, dpi and figure size.
        """
        plt.rcParams['font.size'] = 24 * fig_scale

        self.bilinear_cmap = param.func is not None and param.func.isdigit()
        if scale_t:
            p_minmax = np.array([self.pt_min, self.pt_max])
        else:
            p_minmax = self.get_value_range(t, param)
        p_min, p_max = param.transform(p_minmax)

        levels = scaling.get_levels(p_min, p_max, 11)

        n_plots = len(self.ax_normals)
        aspect = np.array([{2: 24, 3: 24}[self.dim], 11]) * fig_scale
        self.fig = plt.figure(dpi=200 * fig_scale, figsize=aspect)
        self.fig.patch.set_alpha(1)
        c_maps = []

        for i in range(n_plots):
            ax = self.create_ax(i, n_plots, self.scale_type[i])
            if self.ax_normals[i] == 'i':
                ax.imshow(self.plot_isometric(param, t, p_min, p_max))
                ax.axis('off')
            else:
                c_maps += [self.plot_slice(param, i, t, origin, ax,
                                           levels, p_min, p_max)]

        # share the y axis of the first two views; Grouper.join is gone from
        # matplotlib >= 3.8 (read-only GrouperView), so fall back to sharey.
        shared_y = self.fig.axes[0].get_shared_y_axes()
        if hasattr(shared_y, "join"):
            shared_y.join(*self.fig.axes[0:2])
        elif len(self.fig.axes) > 1:
            self.fig.axes[1].sharey(self.fig.axes[0])

        def get_aspect(ax):
            """Ratio of display aspect to data aspect of ``ax``."""
            figW, figH = ax.get_figure().get_size_inches()
            _, _, w, h = ax.get_position().bounds
            disp_ratio = (figH * h) / (figW * w)
            data_ratio = (ax.get_ylim()[1] - ax.get_ylim()[0]) / \
                (ax.get_xlim()[1] - ax.get_xlim()[0])
            return disp_ratio / data_ratio
        scale_ratio = get_aspect(self.fig.axes[0])
        # compare at the printed precision so "1.0:1" is never shown
        if round(scale_ratio, 1) != 1.:
            ax.text(1, -0.1, "Aspektverhältnis {:.1f}:1".format(scale_ratio),
                    transform=btf(ax.transAxes, ax.transAxes),
                    fontsize=plt.rcParams['font.size'],
                    ha="right", va="center")

        # a slice without point data returns None; use the first real mappable
        mappables = [c for c in c_maps if c is not None]
        if mappables:
            self.add_colorbar(mappables[0], levels, param)