import numpy as np
from multipar_bayes import nh_fun

# The three 2x2 Pauli-type matrices, in the order X, Y-like, Z.
# X and Z hold integer entries; the middle matrix holds complex entries.
# NOTE(review): the middle matrix is [[0, 1j], [-1j, 0]], i.e. the negative of
# the textbook sigma_y = [[0, -1j], [1j, 0]]; confirm this sign convention is
# intended.
PAULI_SU = [
    np.array(matrix)
    for matrix in (
        [[0, 1], [1, 0]],
        [[0, 1j], [-1j, 0]],
        [[1, 0], [0, -1]],
    )
]


def _prior_variances(w1, w2, alpha):
    """Return ``(v1, v2)`` with ``v_i = w_i**2 / (1 + 2 * alpha)``."""
    denominator = 1 + 2 * alpha
    return w1**2 / denominator, w2**2 / denominator


def _bounds_at(ρ0, w1, w2, alpha):
    """Evaluate every plotted quantity at a single ``(w1, w2, alpha)`` point.

    With ``(v1, v2) = _prior_variances(w1, w2, alpha)`` this returns the tuple
    ``(msl_prior, spmbound, nhbound, pgm_opt)`` where

    * ``msl_prior = v1 + v2``
    * ``spmbound = msl_prior - (v1**2 + v2**2)``  (closed form)
    * ``pgm_opt = msl_prior - (v1**3 + v2**3)``   (closed form)
    * ``nhbound = msl_prior - nh_fun(ρ0, ρ1s)[0]`` with
      ``ρ1s = [0.5 * v1 * PAULI_SU[0], 0.5 * v2 * PAULI_SU[1]]``.

    NOTE(review): only the first element of ``nh_fun``'s return value is used;
    presumably it is the optimal value of the underlying optimisation.
    """
    v1, v2 = _prior_variances(w1, w2, alpha)

    msl_prior = v1 + v2
    spmbound = msl_prior - (v1**2 + v2**2)
    pgm_opt = msl_prior - (v1**3 + v2**3)

    # Only the first two matrices (X and Y-like) enter: the problem is planar.
    ρ1s = [0.5 * v1 * PAULI_SU[0], 0.5 * v2 * PAULI_SU[1]]
    nhbound = msl_prior - nh_fun(ρ0, ρ1s)[0]

    # The closed-form spmbound could be cross-checked numerically as
    # msl_prior - spm_fun(ρ0, ρ1s)[0] (spm_fun is not imported in this script).
    return msl_prior, spmbound, nhbound, pgm_opt


def _sweep(ρ0, points):
    """Evaluate ``_bounds_at`` for each ``(w1, w2, alpha)`` triple in ``points``.

    Points are evaluated in iteration order. Returns four 1-D numpy arrays,
    ``(msl_prior, spmbound, nhbound, pgm_opt)``, each with one entry per point
    (empty arrays if ``points`` is empty).
    """
    columns = ([], [], [], [])
    for w1, w2, alpha in points:
        for column, value in zip(columns, _bounds_at(ρ0, w1, w2, alpha)):
            column.append(value)
    return tuple(np.array(column) for column in columns)


def main():
    """Compute the bound curves for two parameter sweeps and save them.

    The arrays are written with ``np.savez`` under the name ``"planar_qubit"``,
    using the same keys as the plotting side expects.
    """
    ρ0 = 0.5 * np.eye(2)

    # Sweep 1: as a function of alpha for fixed widths w1 and w2.
    w1_alpha_sweep = 0.85
    w2_alpha_sweep = 0.51
    # Alternative choices:
    #   w1 = 0.9,  w2 = 0.434
    #   w1 = 0.99, w2 = 0.1
    alphas = np.linspace(0.01, 1.0, 50)
    (
        msl_prior_vs_alpha,
        spmbound_vs_alpha,
        nhbound_vs_alpha,
        pgm_opt_vs_alpha,
    ) = _sweep(ρ0, [(w1_alpha_sweep, w2_alpha_sweep, alpha) for alpha in alphas])

    # Sweep 2: as a function of the prior width w2, for fixed w1 and alpha.
    # The upper end of the w2 range keeps w1**2 + w2**2 <= 1.
    alpha_w2_sweep = 0.07
    w1_w2_sweep = 0.83
    w2s = np.linspace(0.01, np.sqrt(1 - w1_w2_sweep**2), 50)
    (
        msl_prior_vs_w2,
        spmbound_vs_w2,
        nhbound_vs_w2,
        pgm_opt_vs_w2,
    ) = _sweep(ρ0, [(w1_w2_sweep, w2, alpha_w2_sweep) for w2 in w2s])

    np.savez(
        "planar_qubit",
        alphas=alphas,
        spmbound_vs_alpha=spmbound_vs_alpha,
        nhbound_vs_alpha=nhbound_vs_alpha,
        pgm_opt_vs_alpha=pgm_opt_vs_alpha,
        msl_prior_vs_alpha=msl_prior_vs_alpha,
        w2s=w2s,
        spmbound_vs_w2=spmbound_vs_w2,
        msl_prior_vs_w2=msl_prior_vs_w2,
        nhbound_vs_w2=nhbound_vs_w2,
        pgm_opt_vs_w2=pgm_opt_vs_w2,
    )


if __name__ == "__main__":
    main()
