from mi_bmqe import as_incompat
import numpy as np

import matplotlib.pyplot as plt

def _plot_panel(
    xs,
    nh_bound,
    spm_bound,
    prior_bound,
    pgm_opt,
    *,
    xlabel,
    inset_xticks,
    filename,
    inset_yticks=None,
    size_value=16,
    width_value=2.2,
):
    """Draw one incompatibility panel (main plot + SPM-bound inset) and save it as a PDF.

    Every main-axes curve is ``as_incompat(values, spm_bound)`` plotted
    against ``xs``. The inset shows ``spm_bound`` itself.

    Args:
        xs: Sweep parameter values shared by all series (x-axis data).
        nh_bound: Series drawn solid and labelled "NH". Also one edge of the
            shaded band.
        spm_bound: SPM bound. It is the reference argument of every
            ``as_incompat`` call and the quantity plotted in the inset.
        prior_bound: Series drawn dotted and labelled "prior".
        pgm_opt: Series drawn dash-dotted. Also the other edge of the shaded band.
        xlabel: LaTeX x-axis label, used on both the main axes and the inset.
        inset_xticks: Tick positions for the inset x-axis.
        filename: Output PDF path.
        inset_yticks: Optional tick positions for the inset y-axis. When
            ``None``, matplotlib's automatic ticks are kept.
        size_value: Font size for main-axes labels and ticks (the inset uses 90%).
        width_value: Line width for the main curves (the inset uses 75%).
    """
    fig, ax = plt.subplots(figsize=(6, 4), constrained_layout=True)

    # Convert each series once; the NH and PGM-opt curves are reused for the band.
    nh_incompat = as_incompat(nh_bound, spm_bound)
    pgm_opt_incompat = as_incompat(pgm_opt, spm_bound)

    # NOTE(review): the curve labelled "PGM" is 2 * spm_bound rather than a
    # separately loaded series -- confirm this is the intended PGM quantity.
    curves = [
        (nh_incompat, "-", r"$\mathrm{NH}$"),
        (as_incompat(2 * spm_bound, spm_bound), "--", r"$\mathrm{PGM}$"),
        (as_incompat(prior_bound, spm_bound), ":", r"$\mathrm{prior}$"),
        # Previously mislabelled "prior" (duplicate of the curve above).
        (pgm_opt_incompat, "-.", r"$\mathrm{PGM}_{\mathrm{opt}}$"),
    ]
    for values, style, label in curves:
        ax.plot(
            xs,
            values,
            style,
            linewidth=width_value,
            color="black",
            label=label,
        )

    ax.set_xlabel(xlabel, fontsize=size_value)
    ax.set_ylabel(r"$\mathcal{I}$", fontsize=size_value)
    ax.tick_params(axis="both", which="both", labelsize=size_value)
    # set_ylim disables y-autoscaling, so the band drawn next cannot alter the
    # y-range; keep this call before fill_between to preserve the layout.
    ax.set_ylim(top=2.05)

    # Shade the region between the NH and PGM-opt curves, behind all lines.
    ax.fill_between(
        xs,
        nh_incompat,
        pgm_opt_incompat,
        color="lightgrey",
        alpha=1,
        zorder=0,
    )

    # Inset in the upper-right corner showing the raw SPM bound.
    inset_size = 0.9 * size_value
    ax_inset = fig.add_axes([0.72, 0.74, 0.23, 0.21])
    ax_inset.plot(
        xs,
        spm_bound,
        "-",
        linewidth=0.75 * width_value,
        color="black",
    )
    ax_inset.set_xlabel(xlabel, fontsize=inset_size)
    ax_inset.set_ylabel(r"$\mathcal{L}_{\mathrm{SPM}} $", fontsize=inset_size)
    ax_inset.tick_params(axis="both", which="both", labelsize=inset_size)
    ax_inset.set_xticks(inset_xticks)
    if inset_yticks is not None:
        ax_inset.set_yticks(inset_yticks)

    fig.savefig(filename, format="pdf", bbox_inches="tight")
    # Release the figure explicitly instead of leaving it in pyplot's registry.
    plt.close(fig)


def main():
    """Load ``planar_qubit.npz`` and write ``fig3a.pdf`` and ``fig3b.pdf``.

    ``fig3a.pdf`` plots the ``*_vs_alpha`` series against ``alphas``.
    ``fig3b.pdf`` plots the ``*_vs_w2`` series against ``w2s``.
    """
    # NpzFile keeps the archive open; the context manager closes it. Indexing
    # reads each array fully into memory, so the arrays stay valid afterwards.
    with np.load("planar_qubit.npz") as dat:
        alphas = dat["alphas"]
        spmbound_vs_alpha = dat["spmbound_vs_alpha"]
        nhbound_vs_alpha = dat["nhbound_vs_alpha"]
        pgm_opt_vs_alpha = dat["pgm_opt_vs_alpha"]
        msl_prior_vs_alpha = dat["msl_prior_vs_alpha"]
        w2s = dat["w2s"]
        spmbound_vs_w2 = dat["spmbound_vs_w2"]
        msl_prior_vs_w2 = dat["msl_prior_vs_w2"]
        nhbound_vs_w2 = dat["nhbound_vs_w2"]
        pgm_opt_vs_w2 = dat["pgm_opt_vs_w2"]

    plt.rcParams.update(
        {
            "text.usetex": True,
            "font.family": "Helvetica",
        }
    )

    # NOTE(review): the "alphas" data is labelled $\beta$ on the axes -- confirm
    # the naming is intentional.
    _plot_panel(
        alphas,
        nhbound_vs_alpha,
        spmbound_vs_alpha,
        msl_prior_vs_alpha,
        pgm_opt_vs_alpha,
        xlabel=r"$\beta$",
        inset_xticks=[0, 0.5, 1],
        inset_yticks=[0.3, 0.35, 0.4],
        filename="fig3a.pdf",
    )

    _plot_panel(
        w2s,
        nhbound_vs_w2,
        spmbound_vs_w2,
        msl_prior_vs_w2,
        pgm_opt_vs_w2,
        xlabel=r"$W_2$",
        inset_xticks=[0, 0.25, 0.5],
        filename="fig3b.pdf",
    )


if __name__ == "__main__":
    main()
