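"""
    Author: Julia Maia
    Date: June 2023

    Creates the figures for the paper "The Mantle Viscosity Structure of
    Venus" by Maia, Wieczorek and Plesa.

    Usage:  python create_figures.py figure_name [figure_format]
            e.g. python create_figures.py multitaper_grid pdf

    figure_name is one of multitaper_grid, localized_spectra,
    viscosity_profiles or density_anomalies. figure_format is the file
    extension of the output (e.g. pdf or png) and defaults to pdf when
    omitted.
"""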

import os
import sys

import matplotlib.pyplot as plt
import numpy as np

import cmcrameri.cm as cmcra
import pygmt

# Input data and output figures live in the ``data`` and ``figures``
# directories one level above the current working directory.
PATH_DATA = os.path.join(os.pardir, "data")
PATH_OUT = os.path.join(os.pardir, "figures")


def plot_multitaper_grid(format):
    """
    Draw the multitaper localization grid as a global map.

    The grid ``multitaper_localization.grd`` from PATH_DATA is shown in a
    12 cm wide Robinson projection, coloured with the "oslo" palette
    (truncated to [0.2, 0.9]) over the value range [0, 1]. The map is saved
    as ``multitaper_grid.<format>`` in PATH_OUT.
    """
    grid_file = os.path.join(PATH_DATA, "multitaper_localization.grd")
    out_file = os.path.join(PATH_OUT, f"multitaper_grid.{format}")

    figure = pygmt.Figure()

    # grdimage below uses the palette created here.
    pygmt.makecpt(cmap="oslo", series=[0, 1], truncate=[0.2, 0.9])

    figure.grdimage(grid=grid_file, projection="N0/12c", frame=["ag", "WsNE"])
    figure.savefig(out_file)


def plot_localized_spectra(format):
    """
    Plot localized admittance and correlation spectra.

    Reads the whitespace-delimited table ``admitcor.txt`` from PATH_DATA.
    Column 0 is the spherical harmonic degree. Columns 1 and 3 are drawn
    against the left-hand (admittance) axis and columns 2 and 4 against the
    right-hand (correlation) axis. Within each pair, the first column is drawn
    dashed and semi-transparent and the second solid. The figure is saved as
    ``localized_spectra.<format>`` in PATH_OUT.
    """
    table = np.loadtxt(os.path.join(PATH_DATA, "admitcor.txt"))
    degree = table[:, 0]

    admittance_color = "#273f87"
    correlation_color = "#E76F51"
    dashed_style = {"alpha": 0.65, "lw": 2, "ls": (0, (1, 0.5))}

    fig = plt.figure()
    ax_admittance = fig.gca()
    ax_correlation = ax_admittance.twinx()
    # NOTE(review): a twin x-axis is created but nothing is plotted on it, so
    # it only adds an extra top axis with default limits; confirm intended.
    ax_admittance.twiny()

    # Tint each y-axis label with the colour of the curves it refers to.
    ax_admittance.yaxis.label.set_color(admittance_color)
    ax_correlation.yaxis.label.set_color(correlation_color)

    for ax, color, (dashed_col, solid_col) in (
        (ax_admittance, admittance_color, (1, 3)),
        (ax_correlation, correlation_color, (2, 4)),
    ):
        ax.plot(degree, table[:, dashed_col], color=color, **dashed_style)
        ax.plot(degree, table[:, solid_col], color=color, lw=2)

    ax_admittance.set(
        xlim=(1, 100),
        ylim=(0, 100),
        xlabel=r"Spherical harmonic degree",
        ylabel=r"Admittance, mGal km$^{-1}$",
    )
    ax_correlation.set(ylim=(0, 1), ylabel=r"Correlation")

    fig.savefig(os.path.join(PATH_OUT, f"localized_spectra.{format}"))


def _nearest_bin_indices(values, grid):
    """
    Return the index of the ``grid`` point closest to each entry of ``values``.

    Distance is the absolute difference of the raw values, and ties go to the
    lower index (``np.argmin`` semantics). The result has the same shape as
    ``values``.
    """
    return np.argmin(np.abs(values[..., np.newaxis] - grid), axis=-1)


def _viscosity_depth_histogram(bin_indices, interface_depths, n_depths, n_bins):
    """
    Return the fraction of samples falling in each (depth, viscosity-bin) cell.

    ``bin_indices`` is (n_samples, n_interfaces + 1) and holds the
    viscosity-grid index of each layer. ``interface_depths`` is
    (n_samples, n_interfaces) and holds the interface depths in km, which
    are truncated to integers. Layers are assumed to be ordered top to
    bottom with increasing interface depths; verify against the sampler.

    Each sample is expanded into an ``n_depths``-long profile of bin indices
    (one entry per km). The returned (n_depths, n_bins) array holds per-cell
    counts divided by n_samples.
    """
    n_samples, n_interfaces = interface_depths.shape

    profiles = np.zeros((n_samples, n_depths), dtype=int)
    for profile, bins, interfaces in zip(
        profiles, bin_indices, interface_depths.astype(int)
    ):
        # The top layer runs down to the first interface, middle layers sit
        # between consecutive interfaces, and the last layer extends to the
        # bottom of the depth axis.
        profile[: interfaces[0]] = bins[0]
        for j in range(1, n_interfaces):
            profile[interfaces[j - 1] : interfaces[j]] = bins[j]
        profile[interfaces[n_interfaces - 1] :] = bins[n_interfaces]

    # NOTE(review): counts start at 1 rather than 0, so every cell carries an
    # extra 1 / n_samples. Confirm this offset is intended.
    counts = np.ones((n_depths, n_bins))
    depth_index = np.broadcast_to(np.arange(n_depths), profiles.shape)
    # np.add.at accumulates repeated (depth, bin) pairs, unlike fancy-index +=.
    np.add.at(counts, (depth_index, profiles), 1)
    return counts / n_samples


def plot_viscosity_profiles(format):
    """
    Plot the distribution of sampled mantle viscosity profiles versus depth.

    Reads ``samples_nominal.txt`` from PATH_DATA. The file is comma-separated
    with one sample per row: columns 0-3 hold the normalized viscosities of
    the four layers and columns 4-6 the depths (km) of the three viscosity
    interfaces.

    Each sample's viscosities are snapped to a log-spaced grid of 27 values
    from 1e-4 to 1e4. Each sample is then expanded into a 1 km depth profile
    from the surface to the core-mantle boundary. The profiles are
    accumulated into a depth x viscosity histogram, drawn as filled
    contours, and the figure is saved as ``viscosity_profiles.<format>`` in
    PATH_OUT.
    """
    # ndmin=2 keeps the 2-D indexing below valid for a single-sample file.
    samples = np.loadtxt(
        os.path.join(PATH_DATA, "samples_nominal.txt"), delimiter=",", ndmin=2
    )
    etans = samples[:, :4]  # normalized viscosity values
    depths = samples[:, 4:7]  # depths of the viscosity interfaces (km)
    # Column 7 (presumably the mass-sheet depth) is not used in this figure.

    # Depth axis from the surface to the core-mantle boundary in 1 km steps:
    # (surface radius - core radius), both in metres, converted to km.
    dCint = int((6051877 - 3250000) / 1000)
    depths_arr = np.arange(dCint)

    etarange = [-4, 4]

    # 2-D mesh of viscosity x depth
    leneta = 27
    etasampling = np.logspace(etarange[0], etarange[1], leneta)
    X, Y = np.meshgrid(etasampling, depths_arr)

    # NOTE(review): the nearest grid point is chosen by linear distance on a
    # log-spaced grid, which favours the lower bin. Confirm this is intended.
    bin_indices = _nearest_bin_indices(etans, etasampling)
    dim2prof_pc = _viscosity_depth_histogram(bin_indices, depths, dCint, leneta)

    # Filled-contour plot. Cells above vmax are left unfilled and show the
    # black axes background.
    fig = plt.figure(figsize=(4, 7))
    ax = fig.gca()
    vmax = 0.5
    levels = np.linspace(0, vmax, 11)
    im = ax.contourf(
        X,
        Y,
        dim2prof_pc,
        levels=np.round(levels, 2),
        cmap=cmcra.oslo_r,
        vmax=vmax,
        vmin=0,
    )
    ax.patch.set_facecolor("k")

    ax.set(
        ylim=(dCint, 0),
        xlim=(10 ** etarange[0], 10 ** etarange[1]),
        xscale="log",
        xlabel="Relative viscosity",
        ylabel="Depth (km)",
        xticks=[1e-2, 1e0, 1e2],
    )

    cax = fig.add_axes([1, 0.15, 0.07, 0.7])
    fig.colorbar(im, cax=cax, label="Fraction of models")

    PATH_OUTPUT = os.path.join(PATH_OUT, f"viscosity_profiles.{format}")

    fig.savefig(PATH_OUTPUT, bbox_inches="tight")


def plot_density_anomalies(format):
    """
    Map the mass-sheet density anomalies.

    Draws ``density_mass-sheet.grd`` from PATH_DATA in a 12 cm wide Robinson
    projection, coloured with the reversed "vik" palette over [-1e7, 1e7],
    with a labelled colour bar centred below the map. The result is saved as
    ``density_anomalies.<format>`` in PATH_OUT.
    """
    figure = pygmt.Figure()

    # grdimage below uses the palette created here.
    pygmt.makecpt(cmap="vik", reverse=True, series=[-1e7, 1e7])

    grid_file = os.path.join(PATH_DATA, "density_mass-sheet.grd")
    figure.grdimage(grid=grid_file, projection="N0/12c", frame=["ag", "WsNE"])

    figure.colorbar(
        frame=['x+l"mass-sheet, kg/m2"'],
        position="JBC+o0c/0.2c+w7c/0.3c",
    )

    figure.savefig(os.path.join(PATH_OUT, f"density_anomalies.{format}"))


if __name__ == "__main__":
    # Command-line figure label -> function that draws it.
    FIGURES = {
        "multitaper_grid": plot_multitaper_grid,
        "localized_spectra": plot_localized_spectra,
        "viscosity_profiles": plot_viscosity_profiles,
        "density_anomalies": plot_density_anomalies,
    }

    # Validate the argument count before indexing sys.argv[1], so a missing
    # label gives a clear message instead of an IndexError. sys.exit(msg)
    # prints to stderr and exits with a non-zero status.
    if len(sys.argv) < 2:
        sys.exit("Expecting label of figure as argument. Exiting.")

    figure_name = sys.argv[1]
    # The optional second argument is the output format (file extension).
    FORMAT = sys.argv[2] if len(sys.argv) > 2 else "pdf"

    if figure_name not in FIGURES:
        sys.exit(
            f"Figure name not found (choose from: {', '.join(FIGURES)}). Exiting."
        )

    FIGURES[figure_name](FORMAT)
