#!/usr/bin/env python3

import gc
import pathlib
from string import ascii_lowercase as abcd
from textwrap import dedent

# developed with StagPy v0.15.0
from stagpy.stagyydata import StagyyData
import matplotlib as mpl
import matplotlib.pyplot as plt
import matplotlib.lines as mlines
import matplotlib.colors as colors
import matplotlib.cm as cm
from matplotlib.patches import Rectangle
import matplotlib.patches as mpatches
import matplotlib.ticker as ticker
from mpl_toolkits.axes_grid1 import make_axes_locatable
import numpy as np
import pandas as pd
from scipy.optimize import curve_fit

# Global Matplotlib style: 7 pt sans-serif text, with Arial as first choice.
mpl.rc('font', family='sans-serif', size=7, **{'sans-serif': ['Arial']})
# Embed fonts in PDF output as Type 42 (TrueType).
mpl.rcParams['pdf.fonttype'] = 42

# Directory holding all cases (each case is a sub-directory named as in CASES)
ROOT = pathlib.Path("/Volumes/LaCie/SPexplo")

# csv files with processed data, relative to the current working directory
DIAG_FILE = pathlib.Path("diagnostics.csv")
POLYG_FILE = pathlib.Path("polyg_reg.csv")

# Figure widths in inches (Matplotlib's unit) for 8.9 cm and 18.3 cm layouts.
# Other Nature figure sizes are 12 and 13.6 cm
INCHES_PER_CM = 1 / 2.54
ONE_COL = 8.9 * INCHES_PER_CM
TWO_COL = 18.3 * INCHES_PER_CM

# cases in the order they are numbered in the paper,
# ordered by nominal Ra and viscosity contrast
# Each entry is the name of a run directory under ROOT; the position in this
# list is used as the case number in tables and figure labels.
CASES = [
    "sp_ra1e+05__c_eta10.0",
    "sp_ra1e+05__c_eta30.0",
    "sp_ra1e+05__c_eta100.0",
    "sp_ra1e+05__c_eta300.0",
    "sp_ra1e+06__c_eta10.0",
    "sp_ra1e+06__c_eta30.0",
    "sp_ra1e+06__c_eta100.0",
    "sp_ra1e+06__c_eta300.0",
    "sp_ra1e+06__c_eta1000.0",
    "sp_ra1e+06__c_eta3000.0",
    "sp_ra1e+07__c_eta10.0",
    "sp_ra1e+07__c_eta30.0",
    "sp_ra1e+07__c_eta100.0",
    "sp_ra1e+07__c_eta300.0",
    "sp_ra1e+07__c_eta1000.0",
    "sp_ra1e+07__c_eta3000.0",
    "sp_ra1e+07__c_eta10000.0",
]

# Subset of CASES.
# NOTE(review): presumably the runs in the polygonal regime -- confirm
# against the code that consumes this list.
POLYG_CASES = [
    "sp_ra1e+06__c_eta100.0",
    "sp_ra1e+06__c_eta300.0",
    "sp_ra1e+07__c_eta300.0",
    "sp_ra1e+07__c_eta1000.0",
    "sp_ra1e+07__c_eta3000.0"
]

# dimensional parameters, in SI units
DEPTH = 5e3  # length scale used to dimensionalize coordinates
DELTA_TEMP = 5  # temperature scale
TSURF = 40 - DELTA_TEMP
ALPHA = 2e-3  # presumably the thermal expansion coefficient -- TODO confirm
KAPPA = 1.44e-7  # thermal diffusivity
K = 0.26  # thermal conductivity
G = 0.642  # acceleration of gravity
DENSITY = 990  # density
YEAR = 365 * 24 * 3600  # seconds in a 365-day year
KYEAR = YEAR * 1e3  # seconds in a thousand years
YEARND = YEAR / (DEPTH * DEPTH / KAPPA)  # one year in diffusive time units

# Factors turning dimensionless quantities into the units named below.
dtime = DEPTH * DEPTH / KAPPA / KYEAR  # diffusive time DEPTH**2 / KAPPA, kyr
dvel = KAPPA / DEPTH * YEAR * 1e2  # velocity scale KAPPA / DEPTH, cm / yr
dtemp = DELTA_TEMP
Tinf = TSURF
dq = K * DELTA_TEMP / DEPTH * 1e3  # heat flux scale, mW / m^2
# Unit suffixes appended to axis labels.
unit_d = ' (km)'
unit_t = r' (kyr)'
unit_v = r' (cm / yr)'
unit_T = r' (K)'
unit_q = r' ($\mathrm{mW/m}^2$)'

# Colormaps for topography and temperature maps.
CMAP_TOPO = "PuOr_r"
CMAP_TEMP = "RdBu_r"


class MidpointNormalize(colors.Normalize):
    """Piecewise-linear normalization mapping ``midpoint`` to 0.5.

    Values are mapped linearly from ``[vmin, midpoint]`` to ``[0, 0.5]`` and
    from ``[midpoint, vmax]`` to ``[0.5, 1]``, so that a diverging colormap
    is centered on ``midpoint``.

    Args:
        vmin: lower limit of the data range (None: taken from the data).
        vmax: upper limit of the data range (None: taken from the data).
        midpoint: value mapped to the center of the colormap (None: the
            middle of ``[vmin, vmax]``).
        clip: forwarded to `matplotlib.colors.Normalize`.
    """

    def __init__(self, vmin=None, vmax=None, midpoint=None, clip=False):
        self.midpoint = midpoint
        super().__init__(vmin, vmax, clip)

    def __call__(self, value, clip=None):
        """Return ``value`` normalized to ``[0, 1]`` as a masked array."""
        # Resolve limits left to None from the data itself.
        if self.vmin is None or self.vmax is None:
            self.autoscale_None(value)
        vmin, vmax = self.vmin, self.vmax
        mid = (vmin + vmax) / 2 if self.midpoint is None else self.midpoint
        # np.interp silently returns nonsense when its abscissae are not
        # increasing, which happens when the midpoint lies outside
        # [vmin, vmax] (e.g. a strictly positive field with midpoint=0).
        # Only the half of the mapping that overlaps the data is kept then.
        if mid <= vmin:
            xp, fp = [mid, vmax], [0.5, 1]
        elif mid >= vmax:
            xp, fp = [vmin, mid], [0, 0.5]
        else:
            xp, fp = [vmin, mid, vmax], [0, 0.5, 1]
        # np.interp drops the mask of masked input; restore it so that masked
        # cells are drawn with the colormap "bad" color.
        return np.ma.masked_array(np.interp(value, xp, fp),
                                  mask=np.ma.getmask(value))


def sci_format(x):
    """Format a number as a LaTeX math fragment (without the dollars).

    Numbers whose decimal exponent is 0 or 1 are written with one decimal;
    any other number is written as ``m \\cdot 10^{e}`` with a two-decimal
    mantissa.
    """
    mantissa, _, exponent = f'{x:.2e}'.partition('e')
    power = int(exponent)
    if 0 <= power < 2:
        return f'{x:.1f}'
    return rf'{mantissa} \cdot 10^{{{power}}}'


def sci_format_dollars(x):
    """Return sci_format(x) wrapped in ``$`` math delimiters."""
    return f"${sci_format(x)}$"


def compute_topo(snap):
    """Dimensionless topography.

    The pressure is linearly extrapolated to z = 1 from the two uppermost
    cell centers and combined with a viscous term built from the vertical
    velocity under the surface and a surface viscosity that depends on the
    top thermal boundary condition. The horizontal mean of the resulting
    stress is removed before scaling by ALPHA * DELTA_TEMP / Raeff.

    Raises:
        ValueError: if 'topT_mode' is neither 'iso' nor starts with 'sublim'.
    """
    par = snap.sdat.par
    ra_eff = snap.timeinfo['Raeff']

    # two uppermost cell centers: index -1 is just below the surface
    radius = snap.rprofs['r'].values
    z_0, z_1 = radius[-1], radius[-2]
    pressure = snap.fields['p']
    p_0 = pressure[..., -1, 0]
    p_1 = pressure[..., -2, 0]

    # full cell, wall at center between p-points
    dzt = 1 - (z_0 + z_1) / 2
    # linear extrapolation of the pressure to z = 1
    p_surf = ((z_0 - 1) * p_1 - (z_1 - 1) * p_0) / (z_0 - z_1)
    v3_under_surf = snap.fields['v3'][:-1, :-1, -1, 0]

    e_eta = par['viscosity']['E_eta']
    if not np.isscalar(e_eta):
        e_eta = e_eta[0]

    boundaries = par['boundaries']
    top_bc = boundaries['topT_mode']
    if top_bc == 'iso':
        # temperature at surface set to 0 (Bi >> 1)
        eta_surf = np.exp(e_eta / 2)
    elif top_bc.startswith('sublim'):
        biot = boundaries['sublim_Biot']
        if top_bc == 'sublimperiod':
            # Biot number varies with time and is clipped at zero
            period = boundaries['sublim_period']
            time = snap.timeinfo['t']
            biot = max(0, np.cos(time * 2 * np.pi / period) * biot)
        # dT/dz + Bi Ts = 0
        temp_surf = snap.fields['T'][..., -1, 0] / (1 + biot * (1 - z_0))
        eta_surf = np.exp(e_eta / (temp_surf + 1) - e_eta / 2)
    else:
        raise ValueError(f"Top BC '{top_bc}' not supported")

    s_zz = -p_surf - 2 * eta_surf * v3_under_surf / dzt
    s_zz = s_zz - np.mean(s_zz)

    return -s_zz * ALPHA * DELTA_TEMP / ra_eff


def mask_map(field, geom, name):
    """Map connex components at the surface.

    Args:
        field: 2D array of component indices, 0 marking cells that belong
            to no component. It is not modified.
        geom: geometry object providing ``x_centers`` and ``y_centers``.
        name: path of the figure file to write.
    """
    n_components = int(np.amax(field))
    # plt.get_cmap is used because cm.get_cmap was removed in Matplotlib 3.9;
    # the lookup table needs an integer size of at least one entry.
    cmap = plt.get_cmap("summer", max(n_components, 1))
    cmap.set_bad()
    # Mask the background on a float copy so the caller's array is left
    # untouched (and integer input can hold NaN).
    labels = np.array(field, dtype=float)
    labels[labels == 0] = np.nan
    # color limits at half-integers put each index in the middle of its color
    plt.pcolormesh(geom.x_centers, geom.y_centers, labels,
                   cmap=cmap, shading='auto',
                   vmin=0.5, vmax=n_components + 0.5)
    plt.axis('square')
    cbar = plt.colorbar(ticks=range(1, n_components + 1))
    cbar.set_label("index")
    plt.savefig(name, bbox_inches='tight', dpi=100)
    plt.close()


def compute_hgrad(field, dx, dy):
    """Norm of the horizontal gradient of a 2D field.

    Each component is a one-sided difference between a cell and its
    neighbour along the axis, wrapping around at the edges; dx and dy are
    the grid spacings along axes 0 and 1.
    """
    slope_x, slope_y = (
        (np.roll(field, 1, axis=axis) - field) / spacing
        for axis, spacing in ((0, dx), (1, dy))
    )
    return np.sqrt(slope_x * slope_x + slope_y * slope_y)


def map_topo_grad(case, plot_vel=False, annotation=None):
    """Map the slope of the surface topography of the last snapshot.

    The figure is written to ``gradtopo_<stem>.pdf``; nothing is done if
    that file already exists. Topography and surface velocity are cached in
    ``.npy`` / ``.npz`` files in the current directory.

    Args:
        case: name of the case directory under ROOT.
        plot_vel: overlay arrows of the surface velocity (in cm/yr).
        annotation: currently unused, the map is always annotated with the
            mean slope. NOTE(review): confirm whether a caller-provided
            annotation should be displayed as well.
    """
    stem = case.replace('sp_', '').replace('.0', '')
    pdffile = pathlib.Path(f'gradtopo_{stem}.pdf')
    if pdffile.exists():
        print('Figure file ', pdffile, ' exists. Skipping.')
        return
    sdat = StagyyData(ROOT / case)
    geometry = sdat.par['geometry']
    Lx, Ly = geometry['aspect_ratio']
    dx = Lx / geometry['nxtot']
    dy = Ly / geometry['nytot']
    snap = sdat.snaps[-1]
    isnap = snap.isnap
    topofile = pathlib.Path(f'topo_{stem}_{isnap:02d}.npy')
    if topofile.exists():
        print('reading file ', topofile)
        topo = np.load(topofile)
    else:
        print('computing topography')
        topo = compute_topo(snap)
        np.save(topofile, topo)
    grad_topo = compute_hgrad(topo, dx, dy)
    annot = 'mean =' + sci_format_dollars(np.mean(grad_topo))
    vel_x = None
    vel_y = None
    if plot_vel:
        velfile = pathlib.Path(f'surfvel_{stem}_{isnap:02d}.npz')
        step = 64  # keep one arrow every `step` cells, as surf_map expects
        if velfile.exists():
            print('reading file ', velfile)
            vel = np.load(velfile)
            vel_x = vel['vel1'][::step, ::step]
            vel_y = vel['vel2'][::step, ::step]
        else:
            # components are swapped on purpose: vel1 is read from 'v2'
            vx = snap.fields['v2'][:-1, :-1, -1, 0]
            vy = snap.fields['v1'][:-1, :-1, -1, 0]
            np.savez(velfile, vel1=vx, vel2=vy)
            vel_x = vx[::step, ::step]
            vel_y = vy[::step, ::step]
        # make it dimensional (cm/yr); not done in place since the arrays
        # may be views of the fields cached by the snapshot
        scale = KAPPA / DEPTH * 100 * YEAR
        vel_x = vel_x * scale
        vel_y = vel_y * scale
    surf_map(grad_topo, snap, pdffile, CMAP_TOPO,
             "Topography slope",
             dimensional=False, velx=vel_x, vely=vel_y, annotation=annot)


def map_topo(case, plot_vel=False, annotation=None, dim=False):
    """Map the surface topography of the last snapshot of a case.

    The figure is written to ``topo_<stem>.pdf``. Topography and surface
    velocity are cached in ``.npy`` / ``.npz`` files in the current
    directory.

    Args:
        case: name of the case directory under ROOT.
        plot_vel: overlay arrows of the surface velocity. The velocity is
            always converted to cm/yr, whatever the value of ``dim``.
        annotation: text written in the upper left corner of the map.
        dim: use dimensional horizontal coordinates (km).
    """
    stem = case.replace('sp_', '').replace('.0', '')
    sdat = StagyyData(ROOT / case)
    snap = sdat.snaps[-1]
    isnap = snap.isnap
    topofile = pathlib.Path(f'topo_{stem}_{isnap:02d}.npy')
    # Path has no `exist` method: the original call raised AttributeError
    if topofile.exists():
        print('reading file ', topofile)
        topo = np.load(topofile)
    else:
        print('computing topography')
        topo = compute_topo(snap)
        np.save(topofile, topo)
    vel_x = None
    vel_y = None
    if plot_vel:
        velfile = pathlib.Path(f'surfvel_{stem}_{isnap:02d}.npz')
        step = 64  # keep one arrow every `step` cells, as surf_map expects
        if velfile.exists():
            print('reading file ', velfile)
            vel = np.load(velfile)
            vel_x = vel['vel1'][::step, ::step]
            vel_y = vel['vel2'][::step, ::step]
        else:
            # components are swapped on purpose: vel1 is read from 'v2'
            vx = snap.fields['v2'][:-1, :-1, -1, 0]
            vy = snap.fields['v1'][:-1, :-1, -1, 0]
            np.savez(velfile, vel1=vx, vel2=vy)
            vel_x = vx[::step, ::step]
            vel_y = vy[::step, ::step]
        # make it dimensional (cm/yr); not done in place since the arrays
        # may be views of the fields cached by the snapshot
        scale = KAPPA / DEPTH * 100 * YEAR
        vel_x = vel_x * scale
        vel_y = vel_y * scale
    # surf_map takes the velocity components themselves (velx / vely); it has
    # no `plot_velocity` keyword.
    surf_map(topo, snap, "topo_"+stem+".pdf", CMAP_TOPO, "Topography",
             dimensional=dim, velx=vel_x, vely=vel_y, annotation=annotation)


def surf_map(field, snap, name, cmap, cblabel, dimensional=False,
             velx=None, vely=None, annotation=None):
    """Map surface field.

    Args:
        field: 2D array to draw, colored with a norm centered on zero.
        snap: snapshot whose ``geom`` gives the horizontal coordinates.
        name: path of the figure file to write.
        cmap: colormap.
        cblabel: label of the colorbar.
        dimensional: express horizontal coordinates in km.
        velx, vely: velocity components drawn as arrows when both are
            given; they must already be sub-sampled every 64 cells.
        annotation: text written in the upper left corner of the map.
    """
    geom = snap.geom
    xcoord = geom.x_centers.copy()
    ycoord = geom.y_centers.copy()
    if dimensional:
        xcoord *= DEPTH * 1e-3
        ycoord *= DEPTH * 1e-3
        dunit = ' (km)'
    else:
        dunit = ''
    vmin = np.amin(field)
    vmax = np.amax(field)
    fig, axis = plt.subplots()

    fig.set_figwidth(0.94 * ONE_COL)
    pcm = axis.pcolormesh(xcoord, ycoord, field, cmap=cmap, shading='gouraud',
                          norm=MidpointNormalize(vmin, vmax, 0.),
                          rasterized=True)

    divider = make_axes_locatable(axis)
    cax = divider.append_axes("right", size="3%", pad=0.05)
    fig.colorbar(pcm, cax=cax, label=cblabel)

    axis.axis('square')

    axis.set_xlabel('x'+dunit)
    axis.set_ylabel('y'+dunit)

    if annotation is not None:
        # Attached to the map axes explicitly: the pyplot current axes may
        # be the colorbar axes appended above.
        axis.text(0.01, 0.99, annotation, ha='left', va='top',
                  transform=axis.transAxes, fontsize=8, weight='bold')

    if not dimensional:
        ticks = np.arange(0, 17, 2)
        axis.set_xticks(ticks)
        axis.set_yticks(ticks)

    if velx is not None and vely is not None:
        step = 64  # must match the sub-sampling applied to velx and vely
        velscale = 20
        annot = r'{}'.format(velscale)
        if dimensional:
            annot += r' cm/yr'
        quiv = axis.quiver(xcoord[::step], ycoord[::step], velx, vely)
        # reference arrow positioned in data coordinates
        axis.quiverkey(quiv, 5, 82, velscale, annot, labelpos='E',
                       coordinates='data')

    fig.savefig(name, bbox_inches='tight', dpi=300)
    plt.close(fig)


def surf_map2(fields, snap, name, cmaps, cblabels, dimensional=False,
              velx=None, vely=None, annotation=None, time=None, **kwargs):
    """Map 2 surface fields side-by-side.

    Args:
        fields: the two 2D arrays to draw.
        snap: snapshot whose ``geom`` gives the horizontal coordinates.
        name: path of the figure file to write.
        cmaps: colormaps, one per field.
        cblabels: colorbar labels, one per field.
        dimensional: express horizontal coordinates in km and label the
            time in kyr. ``time`` itself is displayed as given.
        velx, vely: accepted but not used.
        annotation: text written in the upper left corner of both maps.
        time: if given, shown in the figure title.
        **kwargs: forwarded to ``pcolormesh``.
    """
    geom = snap.geom
    xcoord = geom.x_centers.copy()
    ycoord = geom.y_centers.copy()
    if dimensional:
        xcoord *= DEPTH * 1e-3
        ycoord *= DEPTH * 1e-3
        dunit = ' (km)'
        time_unit = ' (kyr)'
    else:
        dunit = ''
        time_unit = ''
    f, axes = plt.subplots(1, 2, sharey=True)
    f.set_figwidth(TWO_COL)

    if time is not None:
        # The math expression is always closed (the dimensionless title used
        # to lack its final dollar) and the unit is kept out of math mode.
        f.suptitle(rf'$t = {sci_format(time)}${time_unit}', y=0.8)

    for ax, field, cmap, cblabel in zip(axes, fields, cmaps, cblabels):
        pcm = ax.pcolormesh(xcoord, ycoord, field, cmap=cmap, shading='auto',
                            norm=MidpointNormalize(midpoint=0.),
                            rasterized=True, **kwargs)
        ax.axis('square')
        ax.set_xlabel('x'+dunit)
        if ax is axes[0]:
            # y axis is shared: label it on the left panel only
            ax.set_ylabel('y'+dunit)
        if annotation is not None:
            ax.text(0.01, 0.99, annotation, ha='left', va='top',
                    transform=ax.transAxes, weight='bold')
        if not dimensional:
            ticks = np.arange(0, 17, 2)
            ax.set_xticks(ticks)
            ax.set_yticks(ticks)
        divider = make_axes_locatable(ax)
        cax = divider.append_axes("right", size="3%", pad=0.05)
        f.colorbar(pcm, cax=cax, label=cblabel)

    f.savefig(name, bbox_inches='tight', dpi=300)
    plt.close(f)


def plot_profile(snap, stpro, xlabel, filename):
    """Draw the vertical profile named `stpro` of a snapshot.

    The profile values go on the horizontal axis (labelled `xlabel`) and
    their positions on the vertical axis; the figure is saved to `filename`.
    """
    values, height, _ = snap.rprofs[stpro]
    fig, axis = plt.subplots(figsize=(3, 6))
    axis.plot(values, height)
    axis.set_xlabel(xlabel)
    axis.set_ylabel('z')
    fig.savefig(filename, bbox_inches='tight')
    plt.close(fig)


def connex_components(field, where, dsurf=None):
    """Label the connected components of a region of a periodic 2D grid.

    Cells touching by a side or a corner belong to the same component, and
    the grid is periodic in both directions.

    Args:
        field: 2D array, only used for its shape.
        where: boolean array of the same shape, True in the region to label.
        dsurf: area of one cell. When given, components are numbered by
            decreasing area and those whose area does not exceed 1 are
            discarded (merged into the background).

    Returns:
        ``(mask, n)`` if dsurf is None, ``(mask, n, l_med)`` otherwise.
        ``mask`` has one ghost layer on each side (shape ``(nx+2, ny+2)``),
        is 0 in the background and holds the component number (1 to ``n``)
        elsewhere. ``n`` is the number of components kept and ``l_med`` the
        square root of their median area (background excluded; NaN when no
        component is kept).
    """
    nx, ny = field.shape
    # Seed every cell with a unique, strictly positive label: 0 is reserved
    # for the background, so that cell (0, 0) can belong to a component.
    tab = np.arange(1, nx * ny + 1).reshape(nx, ny)

    def wrap_ghosts(arr):
        """Copy the opposite interior edges into the ghost layer."""
        arr[0, :] = arr[-2, :]
        arr[-1, :] = arr[1, :]
        arr[:, 0] = arr[:, -2]
        arr[:, -1] = arr[:, 1]

    mask = np.zeros((nx+2, ny+2))
    maskb = np.zeros((nx+2, ny+2))
    # mask is non-null in regions of interest
    mask[1:-1, 1:-1] = np.where(where, tab, 0)
    wrap_ghosts(mask)

    sweeps_x = (range(nx, 0, -1), range(1, nx+1))
    sweeps_y = (range(ny, 0, -1), range(1, ny+1))

    # Propagate the largest seed label through each component until a full
    # sweep changes nothing.
    ii = 0
    while (mask - maskb).any():
        maskb = np.copy(mask)
        ii += 1

        # Sweeping in alternate directions for quicker convergence.
        for i in sweeps_x[ii % 2]:
            for j in sweeps_y[ii % 2]:
                # Update values to maximum among non-zero neighbours.
                if mask[i, j] > 0:
                    mask[i, j] = np.amax(mask[i-1:i+2, j-1:j+2])
        wrap_ghosts(mask)

    labels, counts = np.unique(mask[1:-1, 1:-1], return_counts=True)
    # drop the background explicitly: it is not a component
    in_region = labels > 0
    labels = labels[in_region]
    counts = counts[in_region]

    # Final numbers are written in a fresh array: renumbering in place could
    # confuse a new number with a seed label still to be processed.
    numbered = np.zeros_like(mask)
    if dsurf is None:
        for number, label in enumerate(labels, start=1):
            numbered[mask == label] = number
        return numbered, len(labels)

    # sort in decreasing order of size
    order = np.argsort(counts)[::-1]
    labels = labels[order]
    surf = counts[order] * dsurf
    # keep components larger than 1; smaller ones stay at 0 (background).
    # The same area criterion drives the map, the count and the median.
    keep = surf > 1
    for number, label in enumerate(labels[keep], start=1):
        numbered[mask == label] = number
    surf = surf[keep]
    return numbered, len(surf), np.sqrt(np.median(surf))


def process_last_snap(folder, draw_plots=False):
    """Compute the diagnostics of the last snapshot of a run.

    Topography and mid-depth temperature are cached as ``.npy`` files in the
    current directory and reused when present.

    Args:
        folder: path of the run directory.
        draw_plots: also write the profile, mask, topography and temperature
            figures of the run.

    Returns:
        list with, in order: snapshot number, Ra, E_eta, effective Ra,
        viscosity ratio, number of cold components at mid-depth, number of
        positive-topography components, ratio of surface to maximum
        horizontal rms velocity, surface horizontal rms velocity, median
        component size, topography amplitude and topography rms.
    """
    stem = folder.name.replace('sp_', '').replace('.0', '')
    sdat = StagyyData(folder)
    par = sdat.par
    ra = par['refstate']['ra0']
    Eeta = par['viscosity']['e_eta']
    if not np.isscalar(Eeta):
        Eeta = Eeta[0]
    nz = par['geometry']['nztot']

    snap = sdat.snaps[-1]
    isnap = snap.isnap
    timeinfo = snap.timeinfo
    # effective Ra rescaled by the maximum of the mean temperature profile
    ra_eff = timeinfo['Raeff'] * np.amax(snap.rprofs['Tmean'].values)
    rvisc = np.exp(Eeta / 2) / timeinfo['etamin']
    vhrms = snap.rprofs['vhrms'].values

    # topography, from cache when available
    topofile = pathlib.Path(f'topo_{stem}_{isnap:02d}.npy')
    if not topofile.exists():
        print('computing topography')
        topo = compute_topo(snap)
        np.save(topofile, topo)
    else:
        print('reading file ', topofile)
        topo = np.load(topofile)
    topo_amp = np.amax(topo) - np.amin(topo)
    topo_rms = np.std(topo)
    # remove polygones smaller than the box size through dsurf
    cell_area = 16**2 / topo.size
    mask_hot, n_poly, l_med = connex_components(topo, topo > 0,
                                                dsurf=cell_area)

    # mid-depth temperature, from cache when available
    tempfile = pathlib.Path(f'temp_{stem}_{isnap:02d}.npy')
    if not tempfile.exists():
        print('extracting temperature')
        temp = snap.fields['T'][..., nz//2, 0]
        np.save(tempfile, temp)
    else:
        print('reading file ', tempfile)
        temp = np.load(tempfile)
    # cold regions: below the average of the minimum and mean temperature
    cold_temp = (np.amin(temp) + np.mean(temp)) / 2
    mask_cold, n_plumes = connex_components(temp, temp < cold_temp)

    if draw_plots:
        plot_profile(snap, 'vhrms', r'$RMS(v_h)$', f"vh__{stem}.pdf")
        fname = f"{{}}_{stem}.pdf"
        # masks carry a ghost layer which is stripped before plotting
        for kind, mask in (("mask_hot", mask_hot), ("mask_cold", mask_cold)):
            mask_map(mask[1:-1, 1:-1], snap.geom, fname.format(kind))
        surf_map(topo, snap, fname.format("topo"), CMAP_TOPO, "Topography")
        surf_map(temp - np.mean(temp), snap, fname.format("temp"), CMAP_TEMP,
                 "Temperature anom.")

    # collect results in single line
    return [snap.isnap, ra, Eeta, ra_eff, rvisc,
            n_plumes, n_poly, vhrms[-1] / np.amax(vhrms),
            vhrms[-1], l_med, topo_amp, topo_rms]


def regimes(data):
    """Classify cases into convection regimes.

    `data` must provide the 'R_vh' and 'n_plumes' columns. Cases with
    R_vh < 0.1 are stagnant, those with 0.1 <= R_vh < 0.9 are sluggish; the
    remaining ones are split on n_plumes > 90 (low viscosity) or not
    (polygons).

    Return boolean arrays (stagnant, sluggish, low_viscosity, polygons).
    """
    mobility = data['R_vh']
    stagnant = mobility < 0.1
    sluggish = (mobility >= 0.1) & (mobility < 0.9)
    remaining = ~stagnant & ~sluggish
    many_plumes = data['n_plumes'] > 90
    return stagnant, sluggish, remaining & many_plumes, remaining & ~many_plumes


def compute_global_diagnostics(tregime, draw_plots=False):
    """Gather the diagnostics of all CASES in a table saved to DIAG_FILE.

    Args:
        tregime: mapping from case name to the value stored in the
            'time_poly' column; cases missing from it get NaN.
        draw_plots: forwarded to process_last_snap.

    Returns:
        DataFrame indexed by case name, one row per case.
    """
    columns = ["snapshot", "Ra", "E_eta", "Ra_eff", "R_visc",
               "n_plumes", "n_hot", "R_vh", "vh_rms",
               "L_med", "topo_amp", "topo_rms", "time_poly"]
    rows = []
    for case in CASES:
        print('treating', case)
        row = process_last_snap(ROOT / case, draw_plots)
        row.append(tregime.get(case, np.nan))
        rows.append(row)

    table = pd.DataFrame(rows, index=CASES, columns=columns)
    table.to_csv(DIAG_FILE)
    return table


def latex_table(data):
    """Write the diagnostics table to table.tex, unless it already exists.

    The 'snapshot' column is replaced by the case number, and the component
    count and median size are blanked for stagnant and sluggish cases. The
    input table is not modified.
    """
    texfile = pathlib.Path("table.tex")
    if texfile.exists():
        print('Tex file ', texfile, ' exists. Skipping')
        return

    table = data.copy(deep=True)
    table["snapshot"] = range(len(table.index))  # will be case # instead
    stag, slug, _, _ = regimes(table)
    without_polygons = stag | slug
    for column in ('n_hot', 'L_med'):
        table.loc[without_polygons, column] = np.nan

    # every column title but the first one is typeset in math mode
    math_titles = ["Ra", r"E_{\eta}", r"Ra_{eff}",
                   r"\eta_{\max}/\eta_{\min}", "N_{plu}", "N_{hot}", "R_v",
                   "v_h", "L_{med}", r"\Delta h", "RMS(h)", r"\tau_p"]
    header = [r"Case \#"] + [f"${title}$" for title in math_titles]

    def power_of_ten(ra):
        return f"$10^{{{int(np.log10(ra))}}}$"

    def as_integer(n):
        return str(int(n))

    formatters = {
        "Ra": power_of_ten,
        "Ra_eff": sci_format_dollars,
        "topo_amp": sci_format_dollars,
        "topo_rms": sci_format_dollars,
        "time_poly": sci_format_dollars,
        "n_hot": as_integer,
    }
    table.to_latex(texfile, header=header, index=False, na_rep='-',
                   formatters=formatters, float_format="%.2f", escape=False)


def plot_last_snaps():
    """Draw topography, temperature and vhrms figures for every case.

    Topography, mid-depth temperature and the vhrms profile of the last
    snapshot are cached as ``.npy`` files in the current directory. Existing
    topography and temperature maps are not redrawn; the vhrms profile
    figure is always rewritten.
    """
    for case in CASES:
        print('treating', case)
        folder = ROOT / case
        stem = folder.name.replace('sp_', '').replace('.0', '')

        sdat = StagyyData(folder)
        # NOTE(review): Eeta is read but not used further in this function.
        Eeta = sdat.par['viscosity']['e_eta']
        if not np.isscalar(Eeta):
            Eeta = Eeta[0]
        nz = sdat.par['geometry']['nztot']

        snap = sdat.snaps[-1]
        isnap = snap.isnap
        fname = f"{{}}_{stem}.pdf"

        # topography map
        topofile = pathlib.Path(f'topo_{stem}_{isnap:02d}.npy')
        if not topofile.exists():
            print('computing topography')
            topo = compute_topo(snap)
            np.save(topofile, topo)
        else:
            print('reading file ', topofile)
            topo = np.load(topofile)
        topo_pdf = pathlib.Path(fname.format("topo"))
        if not topo_pdf.exists():
            surf_map(topo, snap, topo_pdf, CMAP_TOPO, "Topography")

        # mid-depth temperature anomaly map
        tempfile = pathlib.Path(f'temp_{stem}_{isnap:02d}.npy')
        if not tempfile.exists():
            print('extracting temperature')
            temp = snap.fields['T'][..., nz//2, 0]
            np.save(tempfile, temp)
        else:
            print('reading file ', tempfile)
            temp = np.load(tempfile)
        temp_pdf = pathlib.Path(fname.format("temp"))
        if not temp_pdf.exists():
            surf_map(temp - np.mean(temp), snap, temp_pdf, CMAP_TEMP,
                     "Temperature anom.")

        # the vhrms profile is only cached here; the figure below reads the
        # profile from the snapshot
        vhrmsdat = pathlib.Path(f'vhrms_{stem}.npy')
        if not vhrmsdat.exists():
            vhrms = snap.rprofs['vhrms'].values
            np.save(vhrmsdat, vhrms)
        else:
            vhrms = np.load(vhrmsdat)

        plot_profile(snap, 'vhrms', r'$RMS(v_h)$', f"vh__{stem}.pdf")


def latex_longtable():
    """Write allsnaps.tex, a longtable with the figures of every case.

    Each row holds the case number followed by the topography, temperature
    and vhrms figures of the case. Nothing is done if the file exists.
    """
    texfile = pathlib.Path("allsnaps.tex")
    if texfile.exists():
        print('Tex file ', texfile, ' exists. Skipping.')
        return
    header = dedent(r"""
    \begin{longtable}{|p{1cm}|l|l|l|}
    \caption{Diagnostics figures for all cases.}
    \label{tab:all_cases}\\
    \hline
    case\#&Topography&Mid{-}depth temperature&$RMS(v_h)$\\
    \hline
    \endhead
    \hline
    \multicolumn{4}{r}{Continued on Next Page}\\
    \hline
    \endfoot
    \hline
    \endlastfoot
    """)
    graphic = r"\includegraphics[width=0.{}\textwidth]{{{}_{}.{}}}"
    # (width in hundredths of the text width, file name prefix) of each panel
    panels = ((35, "topo"), (35, "temp"), (17, "vh_"))
    with open(texfile, "w") as out:
        out.write(header)
        for number, case in enumerate(CASES):
            stem = case.replace("sp_", "").replace(".0", "")
            cells = [str(number)]
            cells.extend(graphic.format(width, prefix, stem, "pdf")
                         for width, prefix in panels)
            out.write(" & ".join(cells))
            out.write("\\\\\n\\hline\n")
        out.write("\\end{longtable}\n")


def regime_diagram(data, annotation=None):
    """Draw regime diagram.

    Cases are placed in the (Ra_eff, R_visc) plane with log axes, using one
    marker per regime. Low viscosity contrast and polygonal cases are
    flanked by their 'n_plumes' (left) and 'n_hot' (right) counts. The
    figure is written to regime_diagram.pdf unless that file exists.

    Args:
        data: table of diagnostics with columns 'Ra_eff', 'R_visc',
            'n_plumes', 'n_hot' and those needed by `regimes`.
        annotation: text written in the upper left corner of the axes.
    """
    pdffile = pathlib.Path('regime_diagram.pdf')
    if pdffile.exists():
        print('Figure file ', pdffile, ' exists. Skipping.')
        return

    stag, slug, low_visc, polyg = regimes(data)

    fig, axis = plt.subplots(figsize=(ONE_COL, 0.7 * ONE_COL))
    msize = 8
    # (selection, marker, color, whether counts are written next to markers).
    # Markers have a fixed color, hence no colormap nor color limits: those
    # were ignored by Matplotlib and only triggered a warning.
    groups = (
        (stag, '^', 'k', False),
        (slug, 'o', 'k', False),
        (low_visc, 's', 'blue', True),
        (polyg, 'H', 'red', True),
    )
    for selection, marker, color, show_counts in groups:
        cases = data.loc[selection]
        axis.scatter(cases['Ra_eff'], cases['R_visc'], c=color, s=70,
                     marker=marker)
        if not show_counts:
            continue
        for _, row in cases.iterrows():
            # multiplicative offsets give a constant shift on the log axis
            axis.text(.9*row['Ra_eff'], row['R_visc'],
                      str(int(row['n_plumes'])),
                      ha='right', va='center', fontsize=6)
            axis.text(1.1*row['Ra_eff'], row['R_visc'],
                      str(int(row['n_hot'])),
                      ha='left', va='center', fontsize=6)
    axis.loglog()
    axis.set_xlabel(r'$\mathrm{Ra_{eff}}$')
    axis.set_ylabel(r'$\eta_{\max} / \eta_{\min}$')
    handles = [
        mlines.Line2D([], [], color='black', linestyle='', marker='^',
                      markersize=msize, label='Stagnant lid'),
        mlines.Line2D([], [], color='black', linestyle='', marker='o',
                      markersize=msize, label='Sluggish lid'),
        mlines.Line2D([], [], color='red', linestyle='', marker='H',
                      markersize=msize, label='Polygons'),
        mlines.Line2D([], [], color='blue', linestyle='', marker='s',
                      markersize=msize, label=r'Low $\eta$ contrast'),
    ]
    # legend above the axes, spanning their full width
    axis.legend(handles=handles, bbox_to_anchor=(0., 1.02, 1., .102),
                loc='lower left', ncol=2, mode="expand", borderaxespad=0.)
    if annotation is not None:
        axis.text(0.01, 0.99, annotation, ha='left', va='top',
                  transform=axis.transAxes, fontsize=8, weight='bold')
    fig.savefig(pdffile, bbox_inches='tight')
    plt.close(fig)


def summary_constraints(data):
    """Draw various plots with observational constraints.

    A 2x2 figure is written to summary_constraints.pdf (nothing is done if
    the file exists), with one curve per polygonal case of `data`:

    - a: surface viscosity against layer depth;
    - b: horizontal velocity against cell size;
    - c: topography amplitude against cell size;
    - d: time for polygonal regime against layer depth.

    `data` must hold the columns 'R_visc', 'Ra_eff', 'L_med', 'vh_rms',
    'topo_amp' and 'time_poly', plus those needed by `regimes`, and be
    indexed by case names found in CASES.
    """
    pdffile = pathlib.Path('summary_constraints.pdf')
    if pdffile.exists():
        print('Figure file ', pdffile, ' exists. Skipping.')
        return
    polyg = regimes(data)[-1]
    # plot vh vs Lpoly for Polygonal cases
    fig, ax = plt.subplots(2, 2)
    fig.set_figwidth(0.9 * TWO_COL)
    # NOTE(review): vmin and vmax are not used before being rebound by the
    # loop over DataBuhler below.
    vmin = 6.2 - 1.4
    vmax = 19.9 + 11.7
    # Cell sizes taken as square roots of areas; compared below with lengths
    # in km (range and median attributed to White et al (2017) in the
    # labels of panel b).
    Lmin = np.sqrt(60)
    Lmax = np.sqrt(10574)
    Lmed = np.sqrt(707.5)

    # range of acceptable ice thickness, m
    dmin = 1e3
    dmax = 1e4
    drange = np.arange(dmin, dmax, 100)

    # factor such that coefRa * d**3 * R_visc / Ra_eff is the viscosity
    # plotted in panel a
    coefRa = DENSITY * ALPHA * G * DELTA_TEMP / KAPPA

    Lr = np.arange(1, 200, 1) * 1e3  # range for Lpol
    for name, case in data.loc[polyg].iterrows():
        visco = coefRa * drange**3 * case['R_visc'] / case['Ra_eff']
        depth = np.array([Lmin, Lmed, Lmax]) / case['L_med']  # in km
        ddim = Lr / case['L_med']  # range of thicknesses, in m
        inrange = (ddim >= dmin) & (ddim <= dmax)
        # min, med and max surface velocity
        # (only the value for the median depth is computed; the 0.1 factor
        # combines the km -> m and m -> cm conversions)
        vhdim = KAPPA / depth[1] * case['vh_rms'] * 0.1 * YEAR
        idx = CASES.index(name)
        lab = (r'case# ' + f'{idx}' + r': $\mathrm{Ra_{eff}} = ' +
               sci_format(case['Ra_eff']) +
               r'; \eta_{\mathrm{max}} / \eta_{\mathrm{min}} = ' +
               sci_format(case['R_visc']) + '; d = ' + sci_format(depth[1]) +
               r' \mathrm{km}; v_h = ' + sci_format(vhdim) +
               r' \mathrm{cm/yr}$')
        # panel a
        p0 = ax[0, 0].semilogy(drange*1e-3, visco, label=lab)
        # velocity in cm/yr for each cell size of Lr
        vh = KAPPA * case['vh_rms'] * case['L_med'] / Lr * 100 * YEAR
        # panel b
        # the color of the panel a curve is reused in all panels
        ax[0, 1].plot(Lr[inrange]*1e-3, vh[inrange], color=p0[0].get_color())
        # define range of depth to get acceptable velocity and size
        inrange &= (Lr <= Lmax*1e3) & (Lr >= Lmin*1e3)
        # panel c
        ax[1, 0].semilogy(Lr[inrange]*1e-3,
                          ddim[inrange] * case['topo_amp'],
                          color=p0[0].get_color())
        # time_poly scaled by the diffusive time of each thickness, in yr
        y1 = ddim[inrange]**2 / KAPPA * case['time_poly'] / YEAR
        print('time in range {:.2e} to {:.2e}'.format(np.amin(y1)*1e-3,
                                                      np.amax(y1)*1e-3))
        # upper edge of the shaded band, 5e5 yr above the curve
        y2 = y1 + 5e5
        # panel d
        ax[1, 1].fill_between(ddim[inrange]*1e-3, y1, y2,
                              color=p0[0].get_color(), alpha=0.1)
        ax[1, 1].semilogy(ddim[inrange]*1e-3,
                          ddim[inrange]**2 / KAPPA * case['time_poly'] / YEAR,
                          color=p0[0].get_color())
    tsize = 8
    # panel a
    ax[0, 0].set_xlabel('Layer depth (km)')
    ax[0, 0].set_ylabel('Surface viscosity (Pa s)')
    ax[0, 0].text(0.01, 0.99, 'a', ha='left', va='top',
                  fontsize=tsize, transform=ax[0, 0].transAxes, weight='bold')
    # Add min bound from Buhler and Ingersoll (2018)
    ax[0, 0].semilogy([1, 10], [1e16, 1e16], '--', c='k')
    arrow = mpatches.FancyArrowPatch((1, 0.9e16), (1, 5e16),
                                     arrowstyle='-|>', mutation_scale=5)
    ax[0, 0].add_patch(arrow)
    ax[0, 0].text(1.1, 2e16, 'B&I, 2018', va='center')
    # ax[0, 0].tick_params(labelsize=tsize)

    # panel b
    # range of polygonal sizes from White et al (2017)
    rect = Rectangle((Lmin, 0), Lmax-Lmin, 100, color='blue', alpha=0.1,
                     lw=None)
    ax[0, 1].add_patch(rect)
    # ordinate at which the size range is drawn and labelled
    vh = 50
    arrow = mpatches.FancyArrowPatch((Lmin, vh), (Lmax, vh),
                                     arrowstyle='|-|', mutation_scale=5)
    ax[0, 1].add_patch(arrow)
    ax[0, 1].text((Lmin+Lmax)/2, vh, 'range from White et al (2017)',
                  ha='center', va='bottom')
    # median polygonal size
    ax[0, 1].plot([Lmed, Lmed], [0, 100], '--',
                  label='Median value from White et al (2017)')

    ax[0, 1].set_xlabel('Cell size (km)')
    ax[0, 1].set_ylabel('Horiz. velocity (cm/yr)')
    # ax[0, 1].tick_params(labelsize=tsize)
    ax[0, 1].text(0.99, 0.99, 'b', ha='right', va='top',
                  fontsize=tsize, transform=ax[0, 1].transAxes, weight='bold')

    # Annotate with data from Buhler and Ingersoll (2018)
    # each entry is (area, vmin, vmax): drawn as a vertical bar from vmin to
    # vmax at abscissa sqrt(area)
    DataBuhler = [
        (998, 8.7, 19.8),
        (659, 5, 11.3),
        (1184, 9.6, 18.8),
        (826, 9.5, 26.8),
        (275, 4.8, 12.8),
        (160, 5, 12.8),
        (678, 9, 17.4),
    ]
    for area, vmin, vmax in DataBuhler:
        csize = np.sqrt(area)
        arrow = mpatches.FancyArrowPatch((csize, vmin), (csize, vmax),
                                         arrowstyle='|-|', mutation_scale=5)
        ax[0, 1].add_patch(arrow)
    ax[0, 1].set_xlim((0, 110))
    ax[0, 1].set_ylim((0, 57))

    # panel c
    ax[1, 0].set_xlabel('Cell size (km)')
    ax[1, 0].set_ylabel('Topography ampl. (m)')
    ax[1, 0].text(0.01, 0.99, 'c', ha='left', va='top',
                  fontsize=tsize, transform=ax[1, 0].transAxes, weight='bold')
    # add range and median polygonal sizes
    # limits are read back after the curves are drawn, and restored below
    tomin, tomax = ax[1, 0].get_ylim()
    # NOTE(review): the rectangle height is tomax, not tomax - tomin, so it
    # extends above the limits restored below.
    rect = Rectangle((Lmin, tomin), Lmax-Lmin, tomax, color='blue', alpha=0.1,
                     lw=None)
    ax[1, 0].add_patch(rect)
    # median polygonal size
    ax[1, 0].plot([Lmed, Lmed], [tomin, tomax], '--',
                  label='Median value from White et al (2017)')
    ax[1, 0].set_ylim(tomin, tomax)

    # annotate with Schenk
    LrSchenk = 24  # width of the cell in Schenk et al, LPSC 2018, in km
    topoSchenk = 200  # amplitude of topo in Schenk et al, LPSC 2018, in m
    arrow2 = mpatches.FancyArrowPatch((LrSchenk, topoSchenk),
                                      (LrSchenk, 2*topoSchenk),
                                      arrowstyle='-|>', mutation_scale=5)
    ax[1, 0].add_patch(arrow2)
    ax[1, 0].semilogy(LrSchenk, topoSchenk, 'o', c='k')
    ax[1, 0].text(LrSchenk+4, topoSchenk, r'Schenk et al, 2018',
                  fontsize=tsize-2, va='center')

    # ax[1, 0].tick_params(labelsize=tsize)
    ax[1, 0].set_xlim((0, 110))

    # panel d

    ax[1, 1].set_ylabel('Time for polygonal regime (yr)')
    ax[1, 1].set_xlabel('Layer depth (km)')
    ax[1, 1].tick_params(labelsize=tsize)
    ax[1, 1].text(0.99, 0.99, 'd', ha='right', va='top',
                  fontsize=tsize, transform=ax[1, 1].transAxes, weight='bold')

    # general legend
    # anchored above panel a and two axes wide
    ax[0, 0].legend(bbox_to_anchor=(0., 1.02, 2, 0.2), loc='lower left',
                    mode="expand", borderaxespad=0.)
    # save
    plt.tight_layout()
    plt.savefig(pdffile)
    plt.close(fig)


def summary_constraints_decomp(data):
    """Draw the observational constraints of Fig. 3 as separate figures.

    Only the cases selected by the last item returned by ``regimes(data)``
    are drawn.  Three files are written in the working directory:
    ``fig3a.pdf`` (surface viscosity vs layer depth), ``fig3bc.pdf``
    (topography amplitude and horizontal velocity vs cell size) and
    ``fig3d.pdf`` (time for polygonal regime vs layer depth).

    Args:
        data: table of diagnostics indexed by case name; the columns
            'R_visc', 'Ra_eff', 'L_med', 'vh_rms', 'topo_amp' and
            'time_poly' are read.
    """
    polyg = regimes(data)[-1]

    fig3a, ax3a = plt.subplots()
    fig3bc, ax3bc = plt.subplots(2, 1, sharex=True, squeeze=False,
                                 figsize=(5, 8))
    fig3d, ax3d = plt.subplots()

    # bounds and median of the cell size, used in km below
    Lmin = np.sqrt(60)
    Lmax = np.sqrt(10574)
    Lmed = np.sqrt(707.5)

    # range of acceptable ice thickness, m
    dmin = 1e3
    dmax = 1e4
    drange = np.arange(dmin, dmax, 100)

    coefRa = DENSITY * ALPHA * G * DELTA_TEMP / KAPPA

    Lr = np.arange(1, 200, 1) * 1e3  # range for Lpol
    for name, case in data.loc[polyg].iterrows():
        visco = coefRa * drange**3 * case['R_visc'] / case['Ra_eff']
        depth = np.array([Lmin, Lmed, Lmax]) / case['L_med']  # in km
        ddim = Lr / case['L_med']  # range of thicknesses, in m
        inrange = (ddim >= dmin) & (ddim <= dmax)
        # surface velocity obtained with the median cell size
        vhdim = KAPPA / depth[1] * case['vh_rms'] * 0.1 * YEAR
        idx = CASES.index(name)
        lab = (r'case# ' + f'{idx}' + r': $\mathrm{Ra_{eff}} = ' +
               sci_format(case['Ra_eff']) +
               r'; \eta_{\mathrm{max}} / \eta_{\mathrm{min}} = ' +
               sci_format(case['R_visc']) + '; d = ' + sci_format(depth[1]) +
               r' \mathrm{km}; v_h = ' + sci_format(vhdim) +
               r' \mathrm{cm/yr}$')
        # panel a
        p0 = ax3a.semilogy(drange*1e-3, visco, label=lab)
        # every other panel reuses the colour picked for this case
        color = p0[0].get_color()
        vh = KAPPA * case['vh_rms'] * case['L_med'] / Lr * 100 * YEAR
        # panel b
        ax3bc[1, 0].plot(Lr[inrange]*1e-3, vh[inrange], color=color)
        # define range of depth to get acceptable velocity and size
        inrange &= (Lr <= Lmax*1e3) & (Lr >= Lmin*1e3)
        # panel c
        ax3bc[0, 0].semilogy(Lr[inrange]*1e-3,
                             ddim[inrange] * case['topo_amp'],
                             color=color)
        y1 = ddim[inrange]**2 / KAPPA * case['time_poly'] / YEAR
        # np.amin/np.amax raise on an empty selection
        if y1.size:
            print('time in range {:.2e} to {:.2e}'.format(np.amin(y1)*1e-3,
                                                          np.amax(y1)*1e-3))
        y2 = y1 + 5e5
        # panel d
        ax3d.fill_between(ddim[inrange]*1e-3, y1, y2,
                          color=color, alpha=0.1)
        ax3d.semilogy(ddim[inrange]*1e-3, y1, color=color)

    fontsize = 12
    tsize = 12
    # panel a
    ax3a.set_xlabel('Layer depth (km)', fontsize=fontsize)
    ax3a.set_ylabel('Surface viscosity (Pa s)', fontsize=fontsize)
    # Add min bound from Buhler and Ingersoll (2018)
    ax3a.semilogy([1, 10], [1e16, 1e16], '--', c='k')
    arrow = mpatches.FancyArrowPatch((1, 0.9e16), (1, 5e16),
                                     arrowstyle='-|>', mutation_scale=5)
    ax3a.add_patch(arrow)
    ax3a.text(1.1, 2e16, 'Buhler & Ingersoll, 2018', fontsize=tsize-2,
              va='center')
    ax3a.tick_params(labelsize=tsize)

    # panel b
    # range of polygonal sizes from White et al (2017)
    rect = Rectangle((Lmin, 0), Lmax-Lmin, 100, color='blue', alpha=0.1,
                     lw=None)
    ax3bc[1, 0].add_patch(rect)
    vh_arrow = 50  # ordinate at which the size range is annotated
    arrow = mpatches.FancyArrowPatch((Lmin, vh_arrow), (Lmax, vh_arrow),
                                     arrowstyle='|-|', mutation_scale=5)
    ax3bc[1, 0].add_patch(arrow)
    ax3bc[1, 0].text((Lmin+Lmax)/2, vh_arrow,
                     'range from White et al (2017)',
                     ha='center', va='bottom')
    # median polygonal size
    ax3bc[1, 0].plot([Lmed, Lmed], [0, 100], '--',
                     label='Median value from White et al (2017)')

    ax3bc[1, 0].set_xlabel('Cell size (km)', fontsize=fontsize)
    ax3bc[1, 0].set_ylabel('Horiz. velocity (cm/yr)', fontsize=fontsize)
    ax3bc[1, 0].tick_params(labelsize=tsize)

    # Annotate with data from Buhler and Ingersoll (2018)
    DataBuhler = [
        (998, 8.7, 19.8),
        (659, 5, 11.3),
        (1184, 9.6, 18.8),
        (826, 9.5, 26.8),
        (275, 4.8, 12.8),
        (160, 5, 12.8),
        (678, 9, 17.4),
    ]
    for area, v_low, v_high in DataBuhler:
        csize = np.sqrt(area)
        arrow = mpatches.FancyArrowPatch((csize, v_low), (csize, v_high),
                                         arrowstyle='|-|', mutation_scale=5)
        ax3bc[1, 0].add_patch(arrow)
    ax3bc[1, 0].set_xlim((0, 110))
    ax3bc[1, 0].set_ylim((0, 57))

    # panel c
    ax3bc[0, 0].set_ylabel('Topography ampl. (m)', fontsize=fontsize)
    # add range and median polygonal sizes
    tomin, tomax = ax3bc[0, 0].get_ylim()
    rect = Rectangle((Lmin, tomin), Lmax-Lmin, tomax, color='blue', alpha=0.1,
                     lw=None)
    ax3bc[0, 0].add_patch(rect)
    # median polygonal size
    ax3bc[0, 0].plot([Lmed, Lmed], [tomin, tomax], '--',
                     label='Median value from White et al (2017)')
    # restore the limits read before the annotations were added
    ax3bc[0, 0].set_ylim(tomin, tomax)

    # annotate with Schenk
    LrSchenk = 24  # width of the cell in Schenk et al, LPSC 2018, in km
    topoSchenk = 200  # amplitude of topo in Schenk et al, LPSC 2018, in m
    arrow2 = mpatches.FancyArrowPatch((LrSchenk, topoSchenk),
                                      (LrSchenk, 2*topoSchenk),
                                      arrowstyle='-|>', mutation_scale=5)
    ax3bc[0, 0].add_patch(arrow2)
    ax3bc[0, 0].semilogy(LrSchenk, topoSchenk, 'o', c='k')
    ax3bc[0, 0].text(LrSchenk+2, topoSchenk, r'Schenk et al, 2018',
                     fontsize=tsize-2, va='center')

    ax3bc[0, 0].tick_params(labelsize=tsize)
    ax3bc[0, 0].set_xlim((0, 110))

    # panel d

    ax3d.set_ylabel('Time for polygonal regime (yr)', fontsize=fontsize)
    ax3d.set_xlabel('Layer depth (km)', fontsize=fontsize)
    ax3d.tick_params(labelsize=tsize)

    for fig, name in zip([fig3a, fig3bc, fig3d], ['fig3a', 'fig3bc', 'fig3d']):
        # lay out the figure being saved; plt.tight_layout() would only
        # ever act on pyplot's current figure
        fig.tight_layout()
        fig.savefig(name + '.pdf', bbox_inches='tight')
        plt.close(fig)


def temps_mid_depth_dim():
    """Plot mid-depth temperature anomalies of three cases side by side.

    For each case the last snapshot is used.  The mid-depth temperature
    slice is read from a local ``.npy`` cache when present, otherwise it is
    extracted from the snapshot and cached.  Nothing is done when the
    output pdf already exists.
    """
    pdffile = pathlib.Path("temps_mid_depth_dim.pdf")
    if pdffile.exists():
        print('Figure file ', pdffile, ' exists. Skipping.')
        return

    cases = [
        "sp_ra1e+07__c_eta10.0",
        "sp_ra1e+07__c_eta1000.0",
        "sp_ra1e+07__c_eta10000.0",
    ]
    anomalies = []
    for case in cases:
        stem = case.replace('sp_', '').replace('.0', '')
        last_snap = StagyyData(ROOT / case).snaps[-1]
        # the geometry of the last case treated is the one used for the axes
        geom = last_snap.geom
        tempfile = pathlib.Path(f'temp_{stem}_{last_snap.isnap:02d}.npy')
        if not tempfile.exists():
            print('extracting temperature')
            slab = last_snap.fields['T'][..., geom.nztot // 2, 0]
            np.save(tempfile, slab)
        else:
            print('reading file ', tempfile)
            slab = np.load(tempfile)
        # anomaly with respect to the mean of the slice, made dimensional
        slab = slab - np.mean(slab)
        slab *= DELTA_TEMP
        anomalies.append(slab)

    # horizontal coordinates in km
    xcoord = geom.x_centers * DEPTH * 1e-3
    ycoord = geom.y_centers * DEPTH * 1e-3
    # colour range shared by all panels
    vmin = min(map(np.amin, anomalies))
    vmax = max(map(np.amax, anomalies))

    fig, axes = plt.subplots(ncols=len(cases), sharey=True)
    fig.set_figwidth(TWO_COL)

    for anomaly, axis, letter in zip(anomalies, axes, abcd):
        pcm = axis.pcolormesh(xcoord, ycoord, anomaly, cmap=CMAP_TEMP,
                              shading='gouraud',
                              norm=MidpointNormalize(vmin, vmax, 0.),
                              rasterized=True)
        axis.axis('square')
        axis.set_xlabel('x (km)')
        axis.text(0.01, 0.99, letter, ha='left', va='top',
                  transform=axis.transAxes, weight='bold')
    axes[0].set_ylabel('y (km)')

    fig.colorbar(pcm, ax=axes, shrink=0.6, label="Temperature anom. (K)",
                 location='top')
    plt.savefig(pdffile, bbox_inches='tight', dpi=600)
    plt.close()


def map_topo_dim(case, plot_vel=True, annotation=None, istart=-1):
    """Map the dimensional surface topography of snapshots of a case.

    One pdf per snapshot is produced with ``surf_map``; snapshots whose
    pdf already exists are skipped.  Topography and surface velocities are
    cached in ``.npy``/``.npz`` files in the working directory.

    Args:
        case: name of the case folder under ROOT.
        plot_vel: if True, a subsampled surface velocity field is passed
            to ``surf_map``.
        annotation: text passed to ``surf_map``; the special value 'time'
            is replaced, for every snapshot, by the dimensional time.
        istart: the snapshots ``sdat.snaps[istart:]`` are processed.
    """
    stem = case.replace('sp_', '').replace('.0', '')
    sdat = StagyyData(ROOT / case)
    # remember that time annotation was requested once `annotation` has
    # been overwritten by the formatted time of a previous snapshot
    annot_time = False
    for snap in sdat.snaps[istart:]:
        isnap = snap.isnap
        pdffile = pathlib.Path(f"topo_{stem}_{isnap:02d}.pdf")
        if pdffile.exists():
            print('figure file ', pdffile, 'exists, skipping')
            continue
        print('snapshot # =', isnap)
        topofile = pathlib.Path(f'topo_{stem}_{isnap:02d}.npy')
        if topofile.exists():
            print('reading file ', topofile)
            topo = np.load(topofile)
        else:
            print('computing topography')
            topo = compute_topo(snap)
            np.save(topofile, topo)
        # topo files saved dimensionless
        topo = topo * DEPTH
        if plot_vel:
            velfile = pathlib.Path(f'surfvel_{stem}_{isnap:02d}.npz')
            step = 64  # subsampling stride of the velocity field
            if velfile.exists():
                print('reading file ', velfile)
                vel = np.load(velfile)
                vel_x = vel['vel1'][::step, ::step]
                vel_y = vel['vel2'][::step, ::step]
            else:
                vx = snap.fields['v2'][:-1, :-1, -1, 0]
                vy = snap.fields['v1'][:-1, :-1, -1, 0]
                np.savez(velfile, vel1=vx, vel2=vy)
                vel_x = vx[::step, ::step]
                vel_y = vy[::step, ::step]
            # make it dimensional; a plain product is used because the
            # slices above are views and an in-place `*=` would also
            # rescale the arrays they come from
            vel_scale = KAPPA / DEPTH * 100 * YEAR
            vel_x = vel_x * vel_scale
            vel_y = vel_y * vel_scale
        else:
            vel_x = None
            vel_y = None

        if annotation == 'time' or annot_time:
            annot_time = True
            time = snap.timeinfo['t']
            time *= DEPTH * DEPTH / KAPPA / YEAR * 1e-3
            unit_t = ' (kyr)$'
            annotation = r'$t = '+sci_format(time)+unit_t
            print(annotation)
        surf_map(topo, snap, pdffile, CMAP_TOPO,
                 "Topography (m)", dimensional=True, velx=vel_x, vely=vel_y,
                 annotation=annotation)


def map_topo_T(case, dimensional=False, istart=0):
    """Map topography and mid-depth temperature anomaly of snapshots.

    One pdf per snapshot is produced with ``surf_map2``; snapshots whose
    pdf already exists are skipped.  Topography and mid-depth temperature
    are cached in ``.npy`` files in the working directory.

    Args:
        case: name of the case folder under ROOT.
        dimensional: if True, topography, temperature and time are scaled
            with DEPTH, DELTA_TEMP and the diffusion time before plotting.
        istart: the snapshots ``sdat.snaps[istart:]`` are processed.
    """
    if case.startswith('SP'):
        stem = case.replace('SP_Timevar/sp_', 'TV_').replace('.0', '')
    else:
        stem = case.replace('sp_', '').replace('.0', '')
    sdat = StagyyData(ROOT / case)
    nz = sdat.par['geometry']['nztot']
    isnap_before = None
    for snap in sdat.snaps[istart:]:
        # drop the previous snapshot to limit memory use
        if isnap_before is not None:
            del sdat.snaps[isnap_before]
        gc.collect()
        isnap = snap.isnap
        isnap_before = isnap
        pdffile = pathlib.Path(f"topo_T_{stem}_{isnap:02d}.pdf")
        if pdffile.exists():
            print('Figure file ', pdffile, ' exists. Skipping.')
            continue
        time = snap.timeinfo['t']
        topofile = pathlib.Path(f'topo_{stem}_{isnap:02d}.npy')
        if topofile.exists():
            print('reading file ', topofile)
            topo = np.load(topofile)
        else:
            print('computing topography')
            topo = compute_topo(snap)
            np.save(topofile, topo)
        tempfile = pathlib.Path(f'temp_{stem}_{isnap:02d}.npy')
        if tempfile.exists():
            print('reading file ', tempfile)
            temp = np.load(tempfile)
        else:
            print('extracting temperature')
            temp = snap.fields['T'][..., nz//2, 0]
            np.save(tempfile, temp)
        # the slice obtained above is used as is: re-reading the field
        # here would make the cache file pointless.  Plain (not in-place)
        # operations keep the field array the slice may be a view of intact
        temp = temp - np.mean(temp)
        if dimensional:
            topo = topo * DEPTH
            temp = temp * DELTA_TEMP
            time *= DEPTH * DEPTH / KAPPA / YEAR * 1e-3
            unit_topo = " (m)"
            unit_temp = " (K)"
        else:
            unit_topo = ""
            unit_temp = ""
        surf_map2((topo, temp), snap, pdffile,
                  (CMAP_TOPO, CMAP_TEMP),
                  ("Topography"+unit_topo, "Temperature anom."+unit_temp),
                  dimensional=dimensional, time=time)


def guide1(ra):
    """Eye-guide theta = 0.1 (ra / 2e6)^(-1/4)."""
    scaled_ra = ra / 2e6
    return 0.1 * scaled_ra ** (-0.25)


def guide2(ra):
    """Eye-guide theta = 0.1 (ra / 1e7)^(-1/2)."""
    scaled_ra = ra / 1e7
    return 0.1 * scaled_ra ** (-1/2)


def plot_T_RaH(case, polreg):
    """Plot mean T vs Ra_H for a case, using -dTdt as effective heating rate.

    The curve is annotated with the numbers of selected snapshots (even
    ones, snapshot 1 and the first polygonal one), coloured dark red from
    the first polygonal snapshot onwards.  The time series are cached in
    a ``.npz`` file and the figure is skipped when its pdf already exists.

    Args:
        case: name of the case, which must be listed in CASES.
        polreg: table indexed by case name with a 'firstsnap' column.
    """
    casenum = CASES.index(case)
    print('treating', case, 'case #:', casenum)
    folder = ROOT / case
    stem = folder.name.replace('sp_', '').replace('.0', '')
    pdffile = pathlib.Path(f'T_RaH_{stem}.pdf')
    if pdffile.exists():
        print('Figure file ', pdffile, ' exists. Skipping.')
        return
    datafile = pathlib.Path(f'T_RaH_{stem}.npz')
    sdat = StagyyData(folder)
    if datafile.exists():
        data = np.load(datafile)
        tmean = data['Tmean']
        time = data['time']
        dtdt = data['dTdt']
        Raeff = data['Raeff']
    else:
        # mean temperature vector
        tmean = sdat.tseries['Tmean'].values[:-1]
        # and its time variation rate
        dtdt, time, _ = sdat.tseries['dTdt']
        # Effective Rayleigh number
        Raeff = sdat.tseries['Raeff'].values[:-1]
        # save data
        np.savez(datafile, Tmean=tmean, time=time, dTdt=dtdt, Raeff=Raeff)

    # rescaled mean temperature
    theta = - tmean / dtdt
    # rescaled Ra
    Raeff = Raeff / theta

    # dedicated figure: drawing through pyplot's implicit current axes
    # would reuse any figure left open by a previous plotting function
    fig, ax = plt.subplots()
    ax.loglog(Raeff, theta)
    snap_pol = polreg.at[case, 'firstsnap']
    print('First polygonal snapshot =', snap_pol)
    for step in sdat.snaps.filter(func=lambda s: s.isnap % 2 == 0
                                  or s.isnap == snap_pol or s.isnap == 1):
        snap_time = step.timeinfo['t']
        itime = np.searchsorted(time, snap_time)
        # only annotate snapshots whose time is present in the series
        if itime >= len(time) or not np.isclose(time[itime], snap_time):
            continue
        isnap = step.isnap
        ccolor = "darkred" if isnap >= snap_pol else "black"
        ax.text(Raeff[itime], theta[itime], str(isnap),
                ha="center", va="center",
                bbox={"boxstyle": "circle", "color": ccolor},
                fontdict={'color':  'white', 'weight': 'bold', 'size': 3})
    # add eye-guides
    rax = np.array([1e6, 1e7])
    Tx = guide1(rax)
    ax.loglog(rax, Tx, '--', label=r'$\theta = 0.1 (Ra_H/(2\ 10^6))^{-1/4}$')

    rax2 = np.array([1e7, 4e7])
    Tx2 = guide2(rax2)
    ax.loglog(rax2, Tx2, '--', label=r'$\theta = 0.1 (Ra_H/10^7)^{-1/2}$')

    ax.set_xlabel(r'$Ra_H$')
    ax.set_ylabel(r'$\theta$')
    ax.legend(bbox_to_anchor=(0., 1.02, 1., .102),
              ncol=2, mode="expand", borderaxespad=0.)
    fig.savefig(pdffile, bbox_inches='tight')
    plt.close(fig)


def polyg_regime(snap):
    """Tell whether the given snapshot is in the polygonal regime.

    Two diagnostics are handed to ``regimes``: the ratio of the last value
    of the vhrms profile to its maximum, and the number of connected
    regions of the mid-depth temperature slice colder than the average of
    its minimum and its mean.  The snapshot qualifies when the fourth item
    returned by ``regimes`` is non-zero.
    """
    mid_level = snap.geom.nztot // 2
    vh_profile = snap.rprofs['vhrms'].values
    surf_to_max = vh_profile[-1] / np.amax(vh_profile)
    mid_temp = snap.fields['T'][..., mid_level, 0]
    threshold = (np.amin(mid_temp) + np.mean(mid_temp)) / 2
    components = connex_components(mid_temp, mid_temp < threshold)
    diagnostics = {"R_vh": surf_to_max, "n_plumes": components[1]}
    return regimes(diagnostics)[3] != 0


def first_snap_polyg(case):
    """Find the first snapshot of a case that is in the polygonal regime.

    Args:
        case: name of the case folder under ROOT.

    Returns:
        dict with the snapshot number ("firstsnap") and the time
        ("poltime") of the first snapshot accepted by ``polyg_regime``.

    Raises:
        ValueError: if no snapshot of the case is in the polygonal regime.
    """
    folder = ROOT / case
    sdat = StagyyData(folder)
    isnap_before = None
    for step in sdat.snaps:
        # drop the previous snapshot to limit memory use
        if isnap_before is not None:
            del sdat.snaps[isnap_before]
        gc.collect()
        isnap_before = step.isnap
        if polyg_regime(step):
            pol_time = step.timeinfo['t']
            print('First polygonal regime for snapshot #', step.isnap,
                  pol_time)
            return {"firstsnap": step.isnap, "poltime": pol_time}
    # without this, falling out of the loop ended in an UnboundLocalError
    raise ValueError(f'no snapshot in polygonal regime found for {case}')


def compute_polyg_reg():
    """Build the table of first polygonal snapshots and save it as csv.

    ``first_snap_polyg`` is run on every case of POLYG_CASES; the
    resulting table, indexed by case name, is written to POLYG_FILE and
    returned.
    """
    records = [first_snap_polyg(case) for case in POLYG_CASES]
    table = pd.DataFrame(records, index=POLYG_CASES)
    table.to_csv(POLYG_FILE)
    return table


def time_series(case, plot_qbot=False):
    """Plot the time evolution of global diagnostics of a case.

    Five stacked panels are drawn: heat flux, mean temperature, RMS
    velocity, effective Rayleigh number and viscosity contrast.  The
    dimensionless series are cached in a ``.npz`` file and the figure is
    skipped when its pdf already exists.  The scales (dtime, dvel, dq,
    Tinf) and unit labels (unit_*) are module-level names.

    Args:
        case: name of the case folder under ROOT.
        plot_qbot: if True, the bottom heat flux is drawn along with the
            surface one.
    """
    folder = ROOT / case
    stem = folder.name.replace('sp_', '').replace('.0', '')
    pdffile = pathlib.Path(f'time_{stem}.pdf')
    if pdffile.exists():
        print('Figure file ', pdffile, ' exists. Skipping.')
        return

    sdat = StagyyData(folder)

    Eeta = sdat.par['viscosity']['e_eta']
    if not np.isscalar(Eeta):
        Eeta = Eeta[0]
    eta_surf = np.exp(Eeta / 2)

    datafile = pathlib.Path(f'time_{stem}.npz')
    if datafile.exists():
        print('reading time data from ', datafile)
        timedata = np.load(datafile)
        time = timedata['time']
        vrms = timedata['vrms']
        qtop = timedata['qtop']
        Tmean = timedata['Tmean']
        raeff = timedata['raeff']
        rvisc = timedata['rvisc']
        if plot_qbot:
            qbot = timedata['qbot']
    else:
        print('extracting time data for ', case)
        time = sdat.tseries['t'].values
        vrms = sdat.tseries['vrms'].values
        qtop = sdat.tseries['ftop'].values
        Tmean = sdat.tseries['Tmean'].values
        raeff = sdat.tseries['Raeff'].values * sdat.tseries['Tmean'].values
        rvisc = eta_surf / sdat.tseries['etamin'].values
        cache = dict(time=time, vrms=vrms, qtop=qtop, Tmean=Tmean,
                     raeff=raeff, rvisc=rvisc)
        if plot_qbot:
            qbot = sdat.tseries['fbot'].values
            cache['qbot'] = qbot
        np.savez(datafile, **cache)
    # give dimensions; plain products are used so that the arrays above,
    # which may be views of data held by StagPy, are not modified
    time = time * dtime
    vrms = vrms * dvel
    qtop = qtop * dq
    Tmean = Tmean * DELTA_TEMP + Tinf
    if plot_qbot:
        qbot = qbot * dq
    # Now plot
    fig, axes = plt.subplots(5, 1, sharex=True)

    # Surface heat flux
    axes[0].plot(time, qtop, label=r'$q_{top}$')
    if plot_qbot:
        axes[0].plot(time, qbot, label=r'$q_{bot}$')
        axes[0].legend(loc='upper right')
        qlabel = r'$q$ ' + unit_q
    else:
        qlabel = r'$q_{surf}$ ' + unit_q
    axes[0].set_ylabel(qlabel)

    # Mean temperature
    axes[1].plot(time, Tmean)
    Tlabel = r'$\langle T \rangle$' + unit_T
    axes[1].set_ylabel(Tlabel)

    # RMS velocity
    axes[2].plot(time, vrms)
    vlabel = r'$V_{rms}$' + unit_v
    axes[2].set_ylabel(vlabel)

    # Effective Ra
    axes[3].semilogy(time, raeff)
    axes[3].set_ylabel(r'$\mathrm{Ra}_{eff}$')

    # viscosity contrast
    axes[4].semilogy(time, rvisc)
    axes[4].set_ylabel(r'$\eta_{max}/\eta_{min}$')

    # panel letters
    for axis, letter in zip(axes, 'abcde'):
        axis.text(0.01, 0.99, letter, ha='left', va='top',
                  transform=axis.transAxes, weight='bold')

    tlabel = r'time' + unit_t
    axes[4].set_xlabel(tlabel)

    fig.savefig(pdffile, bbox_inches='tight')
    # close the figure so that it neither accumulates nor stays pyplot's
    # current figure for the plotting functions called afterwards
    plt.close(fig)


def plot_Tprofs_decay():
    """Plot min, mean and max temperature profiles of the decay case.

    Four snapshots (0, 3, 50 and 78) of the 'decay14' case are drawn side
    by side as functions of depth.  Depth and profiles are cached in
    ``.npy``/``.npz`` files and the figure is skipped when its pdf already
    exists.  The scale dtime and the unit labels (unit_*) are module-level
    names.
    """
    case = 'decay14'
    direc = 'SP_decay'
    pdffile = pathlib.Path(f'Tprofs_{case}.pdf')

    if pdffile.exists():
        print('Figure file ', pdffile, ' exists. Skipping.')
        return

    data_dir = ROOT / direc / case

    sdat = StagyyData(data_dir)
    depthfile = pathlib.Path('depth.npy')
    if depthfile.exists():
        print('reading depth data from ', depthfile)
        z = np.load(depthfile)
    else:
        print('extract depth data')
        z = (1 - sdat.snaps[-1].rprofs['r'].values)
        np.save(depthfile, z)
    # make it dimensional
    z *= DEPTH

    fig, axes = plt.subplots(1, 4, sharey=True, figsize=(12, 3))
    ylab = r'depth' + unit_d
    axes[0].set_ylabel(ylab)
    axes[0].invert_yaxis()

    xlab = r'temperature' + unit_T

    # temperature range shown in every panel
    temp_lo = 34.9
    temp_hi = 37.6

    for axis, snap in zip(axes, sdat.snaps[0, 3, 50, 78]):
        time = snap.timeinfo['t'] * dtime
        prof = snap.rprofs
        isnap = snap.isnap
        datafile = pathlib.Path(f'Tprofs_{case}_{isnap}.npz')
        if datafile.exists():
            print('reading Tprofs from ', datafile)
            data = np.load(datafile)
            Tmin = data['Tmin']
            Tmean = data['Tmean']
            Tmax = data['Tmax']
        else:
            print('extracting Tprofs data for snap ', isnap)
            Tmin = prof['Tmin'].values
            Tmean = prof['Tmean'].values
            Tmax = prof['Tmax'].values
            np.savez(datafile, Tmin=Tmin, Tmean=Tmean, Tmax=Tmax)
        # profiles are cached dimensionless
        Tmin = Tmin * DELTA_TEMP + TSURF
        Tmean = Tmean * DELTA_TEMP + TSURF
        Tmax = Tmax * DELTA_TEMP + TSURF

        axis.plot(Tmin, z, Tmean, z, Tmax, z)
        axis.set_xlabel(xlab)
        axis.set_xlim(temp_lo, temp_hi)
        anot_time = r'time = {:.1f}'.format(time) + unit_t
        axis.text(0.01, 0.995, anot_time, ha='left', va='top',
                  transform=axis.transAxes)

    fig.savefig(pdffile, bbox_inches='tight')
    # do not leave the figure open once it is saved
    plt.close(fig)


def plot_decay_time():
    """Plot the decay of temperature contrasts and velocity with time.

    For the 'decay14' case, three stacked panels show the vertically
    integrated Tmax - Tmin, the RMS temperature anomaly and the RMS
    velocity; exponential fits (``expfun``) are overlaid on the last two.
    The per-snapshot diagnostics are cached in a ``.npy`` file and the
    figure is skipped when its pdf already exists.  The scales (dtime,
    dtemp, dvel) and unit labels (unit_*) are module-level names.
    """
    case = 'decay14'
    direc = 'SP_decay'
    pdffile = pathlib.Path(f't_dT_vmrs_{case}.pdf')

    if pdffile.exists():
        print('Figure file ', pdffile, ' exists. Skipping.')
        return

    data_dir = ROOT / direc / case
    sdat = StagyyData(data_dir)
    datafile = pathlib.Path(f't_mdT_rdT_{case}.npy')
    if datafile.exists():
        print('read data from ', datafile)
        data = np.load(datafile, allow_pickle=True)
    else:
        # vertical grid with the two boundaries (0 and 1) added
        z = np.zeros(sdat.snaps[-1].geom.nztot + 2)
        z[1:-1] = sdat.snaps[-1].rprofs['r'].values
        z[-1] = 1
        data = []
        isnap_before = None
        for snap in sdat.snaps[:]:
            # drop the previous snapshot to limit memory use
            if isnap_before is not None:
                del sdat.snaps[isnap_before]
            gc.collect()
            isnap_before = snap.isnap

            mdT = meandT(snap, z)
            rdT = rmsdT(snap, z)
            time = snap.timeinfo['t']
            data.append([time, mdT, rdT])
        data = np.array(data)
        # write the cache under the very name looked up above, otherwise
        # it is never found again
        np.save(datafile, data)

    fig, axes = plt.subplots(3, 1, sharex=True)

    mintime2 = 0.01  # start time for fitting the evolution
    print('mintime2 =', mintime2 * dtime)
    time2 = data[:, 0]
    dT = data[:, 1] * dtemp
    rmsT = data[:, 2] * dtemp

    # fit and plot rms of temperature anomaly
    poptt, _ = curve_fit(expfun, time2[time2 > mintime2],
                         rmsT[time2 > mintime2])
    # rescale
    time2 *= dtime
    poptt[1] /= dtime
    # print('poptt =', poptt)
    axes[1].semilogy(time2, rmsT, label=r'data')
    axes[1].text(0.01, 0.99, 'b', ha='left', va='top',
                 transform=axes[1].transAxes)
    axes[1].semilogy(
        time2, expfun(time2, *poptt),
        label=r'fit: {:.1e} exp(-{:.1e} t)'.format(poptt[0], poptt[1]))
    axes[1].legend()
    rTlabel = r'RMS$(T-\overline{T}(z))$' + unit_T
    axes[1].set_ylabel(rTlabel)

    # plot dT
    axes[0].plot(time2, dT,
                 label=r'$\langle T_{max}(z) - T_{min}(z)\rangle$')
    axes[0].text(0.01, 0.99, 'a', ha='left', va='top',
                 transform=axes[0].transAxes)
    dTlabel = r'$\langle T_{max}(z) - T_{min}(z)\rangle$' + unit_T
    axes[0].set_ylabel(dTlabel)

    # fit and plot rms velocity
    # fit with dimensionless var
    mintime = 0.01
    time = sdat.tseries['t'].values
    vrms = sdat.tseries['vrms'].values * dvel
    popt, _ = curve_fit(expfun, time[time > mintime], vrms[time > mintime])
    # then scales; a plain product leaves the series array, which may be
    # a view of data held by StagPy, untouched
    time = time * dtime
    popt[1] /= dtime
    print('popt =', popt)

    axes[2].semilogy(time, vrms, label=r'data')
    axes[2].text(0.01, 0.99, 'c', ha='left', va='top',
                 transform=axes[2].transAxes)
    axes[2].semilogy(
        time, expfun(time, *popt),
        label=r'fit: {:.1e} exp(-{:.1e} t)'.format(popt[0], popt[1]))
    axes[2].legend()

    tlabel = r'time' + unit_t
    vlabel = r'$V_{rms}$' + unit_v
    axes[2].set_xlabel(tlabel)
    axes[2].set_ylabel(vlabel)

    fig.savefig(pdffile, bbox_inches='tight')
    # do not leave the figure open once it is saved
    plt.close(fig)


def meandT(snap, z):
    """Integrate Tmax - Tmin of a snapshot over the vertical grid z.

    Args:
        snap: snapshot; its 'Tmax' and 'Tmin' radial profiles are read.
        z: vertical grid with nztot + 2 points, i.e. the profile levels
            plus one point at each end, where the difference is set to 0.

    Returns:
        the trapezoidal integral of Tmax - Tmin over z.
    """
    # np.trapz is deprecated since NumPy 2.0 in favour of np.trapezoid
    integrate = getattr(np, 'trapezoid', None) or np.trapz
    nztot = snap.sdat.par['geometry']['nztot']
    prof = snap.rprofs
    dT = np.zeros(nztot+2)
    dT[1:-1] = prof['Tmax'].values - prof['Tmin'].values
    return integrate(dT, z)


def rmsdT(snap, z):
    """RMS of the temperature anomaly with respect to the mean profile.

    The squared difference between the temperature field and the 'Tmean'
    radial profile is averaged over all axes but the last one, integrated
    over z and square-rooted, so that a single number is returned.

    Args:
        snap: snapshot; its 'T' field and 'Tmean' radial profile are read.
        z: vertical grid, either with as many points as the profile or
            with one extra point at each end, where the anomaly is taken
            as zero (the convention meandT uses).
    """
    # np.trapz is deprecated since NumPy 2.0 in favour of np.trapezoid
    integrate = getattr(np, 'trapezoid', None) or np.trapz
    dT = snap.fields['T'][..., 0] - snap.rprofs['Tmean'].values
    # mean of the squared anomaly at each level (last axis)
    mean_sq = np.mean(dT * dT, axis=tuple(range(dT.ndim - 1)))
    z = np.asarray(z)
    if z.size == mean_sq.size + 2:
        # grid including both boundaries: pad with zeros to match it
        padded = np.zeros(z.size)
        padded[1:-1] = mean_sq
        mean_sq = padded
    return np.sqrt(integrate(mean_sq, z))


def expfun(x, a, b):
    """Decaying exponential a exp(-b x), used as model for the fits."""
    decay = np.exp(-b * x)
    return a * decay


def plot_timevar():
    """Plot the time series of the case with time-varying conditions.

    Five stacked panels are drawn with ``plot_ax`` (heat flux, mean
    temperature, RMS velocity, effective Rayleigh number and viscosity
    contrast), each with an inset covering the last five periods.  The
    dimensionless series are cached in a ``.npz`` file and the figure is
    skipped when its pdf already exists.  The scales (dtime, dvel, dq),
    YEARND and the unit labels (unit_*) are module-level names.
    """
    direc = 'sp_ra1e+07__c_eta30000_FullPeriod'
    pdffile = pathlib.Path(f'time_TV_{direc}.pdf')
    if pdffile.exists():
        print('Figure file ', pdffile, ' exists. Skipping.')
        return
    data_dir = ROOT / direc

    sdat = StagyyData(data_dir)
    try:
        period = sdat.par['boundaries']['sublim_period']
    except KeyError:
        period = 248 * YEARND
    period = period * dtime  # in kyear

    datafile = pathlib.Path(f'time_{direc}.npz')

    if datafile.exists():
        print('reading for file ', datafile)
        data = np.load(datafile)
        time = data['time']
        vrms = data['vrms']
        qtop = data['qtop']
        Tmean = data['Tmean']
        raeff = data['raeff']
        rvisc = data['rvisc']
    else:
        time = sdat.tseries['t'].values
        vrms = sdat.tseries['vrms'].values
        qtop = sdat.tseries['Nutop'].values
        # kept dimensionless like the other series: the conversion to
        # kelvins is done once, below, for both branches.
        # NOTE: cache files written before this fix hold an already
        # converted Tmean and should be deleted.
        Tmean = sdat.tseries['Tmean'].values
        raeff = sdat.tseries['Raeff'].values * sdat.tseries['Tmean'].values
        rvisc = sdat.tseries['etamax'].values / sdat.tseries['etamin'].values
        np.savez(datafile, time=time, vrms=vrms, qtop=qtop, Tmean=Tmean,
                 raeff=raeff, rvisc=rvisc)
    # dimensional; plain products are used so that the arrays above,
    # which may be views of data held by StagPy, are not modified
    time = time * dtime
    vrms = vrms * dvel
    qtop = qtop * dq
    Tmean = Tmean * DELTA_TEMP + TSURF
    # last period
    tmax = time[-1]
    tmin = tmax - 5 * period

    fig, axes = plt.subplots(5, 1, sharex=True)

    # Surface heat flux
    qlabel = r'$q_{surf}$ ' + unit_q
    plot_ax(axes[0], time, qtop, tmin, tmax, qlabel, 'a', form='%.1f')

    # Mean temperature
    Tlabel = r'$\langle T \rangle$' + unit_T
    plot_ax(axes[1], time, Tmean, tmin, tmax, Tlabel, 'b', form='%.3f')

    # RMS velocity
    vlabel = r'$V_{rms}$' + unit_v
    plot_ax(axes[2], time, vrms, tmin, tmax, vlabel, 'c', form='%.2f')

    # Effective Ra
    plot_ax(axes[3], time, raeff, tmin, tmax, r'$\mathrm{Ra}_{eff}$', 'd',
            form='%.2e', logscale=True)

    # viscosity contrast
    plot_ax(axes[4], time, rvisc, tmin, tmax, r'$\eta_{max}/\eta_{min}$', 'e',
            form='%.2e', logscale=True)

    tlabel = r'time' + unit_t
    axes[4].set_xlabel(tlabel)

    fig.savefig(pdffile, bbox_inches='tight')
    # do not leave the figure open once it is saved
    plt.close(fig)


def plot_ax(ax, time, data, tmin, tmax, label, lett, form='%0.2e',
            logscale=False):
    """Draw a time series on ax along with an inset zooming on its end.

    Args:
        ax: axes receiving the series.
        time, data: abscissa and ordinate of the series.
        tmin, tmax: time limits of the inset.
        label: label of the y axis.
        lett: panel letter written in the upper left corner.
        form: format string of the y tick labels of the inset.
        logscale: if True, the main curve uses a logarithmic y axis (the
            inset is always linear).
    """
    draw_main = ax.semilogy if logscale else ax.plot
    draw_main(time, data)
    ax.set_ylabel(label)
    ax.text(0.01, 0.99, lett, ha='left', va='top',
            transform=ax.transAxes, weight='bold')
    # add inset
    axins = ax.inset_axes([0.8, 0.45, 0.18, 0.5])
    axins.plot(time, data)
    axins.set_xlim(xmin=tmin, xmax=tmax)
    # y range of the data after tmin, widened by 5% on each side
    tail = time > tmin
    ymin = np.amin(data[tail])
    ymax = np.amax(data[tail])
    margin = 0.05 * (ymax - ymin)
    axins.set_ylim(ymin - margin, ymax + margin)
    tick_format = ticker.FormatStrFormatter(form)
    axins.yaxis.set_major_formatter(tick_format)


if __name__ == "__main__":
    # Script entry point: produce the tables and figures in sequence.
    # The plotting functions visible in this file return early when
    # their output pdf already exists.
    plot_last_snaps()

    # first polygonal snapshot of each case, read from POLYG_FILE when
    # that file exists
    if POLYG_FILE.exists():
        polyg_reg = pd.read_csv(POLYG_FILE, index_col=0)
    else:
        polyg_reg = compute_polyg_reg()

    # global diagnostics, read from DIAG_FILE when that file exists
    if DIAG_FILE.exists():
        data = pd.read_csv(DIAG_FILE, index_col=0)
    else:
        # map case name -> time of the first polygonal snapshot
        tregime = dict(zip(polyg_reg.index, polyg_reg.poltime))
        data = compute_global_diagnostics(tregime, draw_plots=True)

    # fig 1a, example of polygonal regime
    map_topo_dim("sp_ra1e+06__c_eta100.0", plot_vel=False, annotation="a")

    # fig 1b, example of bottom heated convection
    map_topo_dim("BottomHeated", plot_vel=False, annotation="b")

    # write latex tables
    latex_table(data)
    latex_longtable()

    # fig 2a, b, c
    temps_mid_depth_dim()

    # fig 2d
    map_topo_dim("sp_ra1e+07__c_eta1000.0", plot_vel=True, annotation="d")

    # fig 2e, regime diagram
    regime_diagram(data, annotation="e")

    # fig 3, without the timescale, added by hand
    summary_constraints(data)

    # Fig 3 decomposed
    # summary_constraints_decomp(data)

    # Extended data

    # topo gradient - Extended data 2
    map_topo_grad("sp_ra1e+07__c_eta1000.0")
    map_topo_grad("sp_ra1e+07__c_eta3000.0")

    # time series - extended data 3
    time_series('sp_ra1e+07__c_eta1000.0')

    # theta vs RaH - extended data 5
    # NOTE(review): the next item is also labelled "extended data 5";
    # presumably this one is extended data 4 - confirm
    plot_T_RaH("sp_ra1e+07__c_eta1000.0", polyg_reg)

    # T and topo snapshots for case #14 - extended data 5
    map_topo_T("sp_ra1e+07__c_eta1000.0", dimensional=True)

    # T profiles for decaying convection - extended data 6
    plot_Tprofs_decay()

    # Time evolution for decaying convection - extended data 7
    plot_decay_time()

    # Time evolution for fluctuating case - extended data 8
    plot_timevar()

    # topo + T map for time varying calculation - Extended data 9
    # (last snapshot only)
    map_topo_T("sp_ra1e+07__c_eta30000_FullPeriod", dimensional=True,
               istart=-1)

    # time series for the Nubot5 case - extended data 10
    # NOTE(review): this comment used to read "Nubot = 1" while the case
    # name says Nubot5; confirm which case is extended data 10
    time_series('sp_ra1e+07__c_eta1000_Nubot5', plot_qbot=True)

    # time series for the Nubot1 case - extended data 11
    # NOTE(review): this comment used to read "Nubot = 5" while the case
    # name says Nubot1; confirm which case is extended data 11
    time_series('sp_ra1e+07__c_eta1000_Nubot1', plot_qbot=True)

    # topography maps of all snapshots of the Nubot cases, annotated
    # with time
    map_topo_dim("sp_ra1e+07__c_eta1000_Nubot1", plot_vel=False,
                 annotation='time', istart=0)
    map_topo_dim("sp_ra1e+07__c_eta1000_Nubot5", plot_vel=False,
                 annotation='time', istart=0)
