import numpy as np
from tqdm import tqdm
from multipar_bayes import nh_fun, spm_fun, sqpm_fun
from mi_bmqe import as_incompat


def ρ0(d: int, α: complex) -> np.typing.NDArray:
    return np.diag([np.abs(α) ** 2] + [1 for _ in range(d)]) / (d + np.abs(α) ** 2)


def ρj(d: int, j: int, α: complex, n: int) -> np.typing.NDArray:
    r = np.zeros((d + 1, d + 1), dtype=np.complex128)
    coeff = α / (n * (d + np.abs(α) ** 2))
    r[0, j] = -1j * coeff
    r[j, 0] = 1j * coeff
    return r


def nh_loss(d: int, α: complex, n: int, L) -> float:
    """Return ``λ`` minus the first value produced by ``nh_fun``.

    ``λ`` is ``π² / (3 n²)``.  ``nh_fun`` is called with ``ρ0(d, α)``, the
    list ``[ρj(d, j, α, n) for j = 1..d]`` and ``L``; it must return exactly
    four values, of which only the first is used.

    Args:
        d: Forwarded to ``ρ0``/``ρj``; also the number of ``ρj`` matrices.
        α: Forwarded to ``ρ0``/``ρj``.
        n: Forwarded to ``ρj`` and used in ``λ``.
        L: Passed through unchanged to ``nh_fun``.
    """
    zeroth_moment = ρ0(d, α)
    first_moments = []
    for j in range(1, d + 1):
        first_moments.append(ρj(d, j, α, n))
    λ = np.pi**2 / (3 * n**2)
    bound, _, _, _ = nh_fun(zeroth_moment, first_moments, L)
    return λ - bound


def pgm_loss(d: int, α: complex, n: int, L) -> float:
    """Return ``2 λ - 2 g`` where ``g`` is the first value from ``sqpm_fun``.

    ``λ`` is ``π² / (3 n²)``.  ``sqpm_fun`` is called with ``ρ0(d, α)``, the
    list ``[ρj(d, j, α, n) for j = 1..d]`` and ``L``; it must return exactly
    three values, of which only the first is used.

    Args:
        d: Forwarded to ``ρ0``/``ρj``; also the number of ``ρj`` matrices.
        α: Forwarded to ``ρ0``/``ρj``.
        n: Forwarded to ``ρj`` and used in ``λ``.
        L: Passed through unchanged to ``sqpm_fun``.
    """
    zeroth_moment = ρ0(d, α)
    first_moments = []
    for j in range(1, d + 1):
        first_moments.append(ρj(d, j, α, n))
    λ = np.pi**2 / (3 * n**2)
    gain, _, _ = sqpm_fun(zeroth_moment, first_moments, L)
    return 2 * λ - 2 * gain


def prior_loss(d: int, α: complex, n: int, L) -> float:
    """Return ``π² / (3 n²)``.

    Only ``n`` is used; ``d``, ``α`` and ``L`` are accepted so the signature
    matches the other ``*_loss`` functions in this file.
    """
    return np.pi**2 / (3 * n**2)


def pgm_star_loss(d: int, α: complex, n: int, L) -> float:
    """Return ``λ`` minus a closed-form gain.

    ``λ`` is ``π² / (3 n²)`` and the gain is
    ``2 |α|² / (n² (d + |α|²)²)``.  ``L`` is accepted but not used.

    Args:
        d: Enters the gain through ``(d + |α|²)²``.
        α: Amplitude; only its modulus is used.
        n: Enters both ``λ`` and the gain through ``n²``.
        L: Unused.
    """
    λ = np.pi**2 / (3 * n**2)
    intensity = abs(α) ** 2
    denominator = (n**2) * ((d + intensity) ** 2)
    gain = 2 * intensity / denominator
    return λ - gain


def spm_loss(d: int, α: complex, n: int, L) -> float:
    """Return ``λ`` minus the first value produced by ``spm_fun``.

    ``λ`` is ``π² / (3 n²)``.  ``spm_fun`` is called with ``ρ0(d, α)``, the
    list ``[ρj(d, j, α, n) for j = 1..d]`` and ``L``; it must return exactly
    three values, of which only the first is used.

    Args:
        d: Forwarded to ``ρ0``/``ρj``; also the number of ``ρj`` matrices.
        α: Forwarded to ``ρ0``/``ρj``.
        n: Forwarded to ``ρj`` and used in ``λ``.
        L: Passed through unchanged to ``spm_fun``.
    """
    zeroth_moment = ρ0(d, α)
    first_moments = []
    for j in range(1, d + 1):
        first_moments.append(ρj(d, j, α, n))
    λ = np.pi**2 / (3 * n**2)
    gain, _, _ = spm_fun(zeroth_moment, first_moments, L)
    return λ - gain


if __name__ == "__main__":
    # Sweep settings: n is fixed and d runs over 2..max_d inclusive.
    n = 1
    max_d = 17
    d_values = np.arange(2, max_d + 1)

    loss_nh, loss_pgm, loss_prior, loss_pgm_star, loss_spm = [], [], [], [], []

    # (loss function, list collecting its values); the tuple order is the
    # order in which the functions are evaluated for every d.
    evaluators = (
        (nh_loss, loss_nh),
        (pgm_loss, loss_pgm),
        (prior_loss, loss_prior),
        (pgm_star_loss, loss_pgm_star),
        (spm_loss, loss_spm),
    )

    progress = tqdm(d_values)
    for d in progress:
        progress.set_description(f"Calculating {d:>2} phases")
        # α is the fourth root of d; L is the d × d identity divided by d.
        α = d ** (1 / 4)
        L = np.eye(d) / d
        for evaluate, collected in evaluators:
            collected.append(evaluate(d, α, n, L))

    # Every other loss curve is compared against the spm losses.
    incompat_nh, incompat_pgm, incompat_pgm_star, incompat_prior = (
        as_incompat(losses, loss_spm)
        for losses in (loss_nh, loss_pgm, loss_pgm_star, loss_prior)
    )

    results = {
        "incomp_nh": incompat_nh,
        "incomp_pgm": incompat_pgm,
        "incomp_prior": incompat_prior,
        "incomp_pgm_star": incompat_pgm_star,
        "loss_spm": loss_spm,
    }
    np.savez("figure1.npz", **results)
