import numpy as np
from collections.abc import Callable
from scipy.special import expi
import mpmath


def ρ0_jeffreys_terms(indices, w1: float, w2: float, ϖ: float = np.pi / 2) -> float:
    """Evaluate one closed-form ρ0 term for the exponent tuple ``(j, k, m, n)``.

    Every branch carries the weight ``(1 - cos ϖ)**n * (1 + cos ϖ)**j`` and a
    power-of-two denominator.  When ``k + m > 0`` the term additionally
    contains ``sin(ϖ)**(k + m)`` and a difference of exponential integrals
    ``Ei(-(k + m) * w2) - Ei(-(k + m) / w2)`` divided by ``log(w2)``.

    ``k == m`` is special-cased because the general expression divides by
    ``(k - m)``; the dedicated branches contain no ``w1`` dependence.

    NOTE(review): presumably these are the zeroth-moment (normalisation)
    entries of a state averaged over a Jeffreys prior of widths ``w1`` and
    ``w2`` -- confirm against the derivation.

    Args:
        indices: Iterable unpacking to the four exponents ``(j, k, m, n)``.
        w1: Width entering through ``sin((k - m) * w1 / 2) / ((k - m) * w1)``.
        w2: Width entering through the exponential integrals and ``log(w2)``.
        ϖ: Angle passed to the sine and cosine factors.

    Returns:
        The value of the term.  ``w2 == 1`` (for ``k + m > 0``) and
        ``w1 == 0`` (for ``k != m``) make a denominator vanish and are not
        handled here.
    """
    j, k, m, n = indices
    cos_val = np.cos(ϖ)
    weight = (1 - cos_val) ** n * (1 + cos_val) ** j

    if k != m:
        diff = k - m
        total = k + m
        numerator = (
            weight
            * np.sin(ϖ) ** total
            * (expi(-total * w2) - expi(-total / w2))
            * np.sin(diff * w1 / 2)
        )
        return numerator / (diff * w1 * np.log(w2) * 2 ** (j + k + m + n))

    if k == 0:
        # No sine power and no exponential integrals when k == m == 0.
        return weight / 2 ** (j + n)

    # k == m != 0: the extra "+ 1" in the exponent is the factor 1/2 that
    # sin(x * w1 / 2) / (x * w1) tends to as x -> 0.
    return (
        weight
        * np.sin(ϖ) ** (2 * m)
        * (expi(-2 * m * w2) - expi(-2 * m / w2))
        / (np.log(w2) * 2 ** (j + k + m + n + 1))
    )


def ρ1_jeffreys_terms(indices, w1: float, w2: float, ϖ: float = np.pi / 2) -> complex:
    """Evaluate one closed-form ρ1 term for the exponent tuple ``(j, k, m, n)``.

    For ``k != m`` the result is ``1j`` times a real expression, i.e. purely
    imaginary for real inputs (hence the ``complex`` return annotation).  The
    expression is odd under the exchange ``k <-> m``.

    NOTE(review): presumably the first moment with respect to the parameter
    whose prior width is ``w1`` -- confirm against the derivation.

    Args:
        indices: Iterable unpacking to the four exponents ``(j, k, m, n)``.
        w1: Width entering the ``cos``/``sin`` combination and the denominator.
        w2: Width entering through the exponential integrals and ``log(w2)``.
        ϖ: Angle passed to the sine and cosine factors.

    Returns:
        ``0`` when ``k == m``; otherwise the complex value of the term.
        ``w1 == 0`` and ``w2 == 1`` make the denominator vanish and are not
        handled here.
    """
    j, k, m, n = indices
    if k == m:
        # The general expression divides by (k - m)**2, so the diagonal case
        # is returned explicitly.
        return 0

    diff = k - m
    total = k + m
    cos_val = np.cos(ϖ)
    phase = diff * w1 / 2
    return (
        1j
        * (1 - cos_val) ** n
        * (1 + cos_val) ** j
        * np.sin(ϖ) ** total
        * (expi(-total * w2) - expi(-total / w2))
        * ((w1 / 2) * diff * np.cos(phase) - np.sin(phase))
        / (w1 * np.log(w2) * diff**2 * 2 ** (j + k + m + n))
    )


def ρ2_jeffreys_terms_log(indices, w1: float, w2: float, ϖ: float = np.pi / 2) -> float:
    """Evaluate one closed-form ρ2 ("log") term for the exponents ``(j, k, m, n)``.

    All non-zero branches share the factor
    ``(1 - cos ϖ)**n * (1 + cos ϖ)**j * sin(ϖ)**(k + m) / 2**(j + k + m + n)``
    multiplied by a bracket that combines

    * the generalized hypergeometric function 3F3(1, 1, 1; 2, 2, 2; z),
      evaluated with ``mpmath.hyper`` at ``z = -(k + m) * w2`` and
      ``z = -(k + m) / w2``,
    * the exponential integral ``Ei`` (``scipy.special.expi``) at the same two
      arguments, and
    * the Euler-Mascheroni constant plus ``log(k + m)``.

    The ``k == m`` branch is the ``k -> m`` limit of the general branch
    (``sin(x * w1 / 2) / (x * w1) -> 1/2``), so it has no ``w1`` dependence.

    NOTE(review): presumably the first moment of the logarithm of the
    parameter whose prior width is ``w2`` -- confirm against the derivation.

    Args:
        indices: Iterable unpacking to the four exponents ``(j, k, m, n)``.
        w1: Width entering through ``sin((k - m) * w1 / 2) / (w1 * (k - m))``.
        w2: Width entering the hypergeometric / Ei arguments and ``log(w2)``.
        ϖ: Angle passed to the sine and cosine factors.

    Returns:
        ``0`` when ``k == m == 0``; otherwise the value of the term.  The
        expressions divide by ``log(w2)`` and (for ``k != m``) by ``w1``, so
        ``w2 == 1`` and ``w1 == 0`` are not handled here.
    """
    j, k, m, n = indices
    if k == m:
        if k == 0:
            # Returned explicitly: the formulas below take log(k + m).
            return 0
        return (
            (1 - np.cos(ϖ)) ** n
            * (1 + np.cos(ϖ)) ** j
            * np.sin(ϖ) ** (2 * k)
            * (
                # Hypergeometric part; k / 2 is (k + m) / 2 of the general
                # branch times the 1/2 from the k -> m limit.
                (k / 2)
                * (
                    w2 * float(mpmath.hyper([1, 1, 1], [2, 2, 2], -2 * k * w2))
                    - float(mpmath.hyper([1, 1, 1], [2, 2, 2], -2 * k / w2)) / w2
                )
                / np.log(w2)
                # Exponential-integral part with the Euler-gamma / log offset,
                # again carrying the extra 1/2 (hence the division by 4).
                + (
                    expi(-2 * k * w2)
                    + expi(-2 * k / w2)
                    - 2 * (np.euler_gamma + np.log(2 * k))
                )
                / 4
            )
        ) / (2 ** (j + 2 * k + n))
    return (
        (1 - np.cos(ϖ)) ** n
        * (1 + np.cos(ϖ)) ** j
        * np.sin(ϖ) ** (k + m)
        * (
            # Hypergeometric part.
            ((k + m) / 2)
            * (
                w2 * float(mpmath.hyper([1, 1, 1], [2, 2, 2], -(k + m) * w2))
                - float(mpmath.hyper([1, 1, 1], [2, 2, 2], -(k + m) / w2)) / w2
            )
            / np.log(w2)
            # Exponential-integral part and the Euler-gamma / log offset.
            + (expi(-(k + m) * w2) + expi(-(k + m) / w2)) / 2
            - (np.euler_gamma + np.log(k + m))
        )
        # w1 enters only through this sin(x * w1 / 2) / (w1 * x) factor.
        * np.sin((k - m) * w1 / 2)
        / (2 ** (j + k + m + n) * w1 * (k - m))
    )


def exact_state0(
    w1: float, w2: float, ϖ: float = np.pi / 2, copies: int = 1
) -> np.typing.NDArray:
    """Assemble the ``2**copies`` by ``2**copies`` array of ρ0 terms.

    Each entry is ``ρ0_jeffreys_terms`` evaluated at the exponent tuple that
    ``multi_tensor`` derives from the entry's row and column index.

    Args:
        w1: Forwarded to ``ρ0_jeffreys_terms``.
        w2: Forwarded to ``ρ0_jeffreys_terms``.
        ϖ: Forwarded to ``ρ0_jeffreys_terms``.
        copies: Number of copies; fixes the array dimension ``2**copies``.

    Returns:
        The array built by ``multi_tensor``.
    """

    def ρ0_entry(j, k, m, n):
        return ρ0_jeffreys_terms((j, k, m, n), w1, w2, ϖ)

    return multi_tensor(ρ0_entry, copies)


def exact_first_moments_log(
    w1: float, w2: float, ϖ: float = np.pi / 2, copies: int = 1
) -> list[np.typing.NDArray]:
    """Assemble the ρ1 and ρ2 ("log") arrays for ``copies`` copies.

    The return annotation reflects what is actually returned: a Python list
    holding two ``2**copies`` by ``2**copies`` arrays, not a single array.

    Args:
        w1: Forwarded to the term functions.
        w2: Forwarded to the term functions.
        ϖ: Forwarded to the term functions.
        copies: Number of copies; fixes the array dimension ``2**copies``.

    Returns:
        ``[ρ1_array, ρ2_array]`` where the first is built from
        ``ρ1_jeffreys_terms`` and the second from ``ρ2_jeffreys_terms_log``,
        each via ``multi_tensor``.
    """

    def ρ1_entry(j, k, m, n):
        return ρ1_jeffreys_terms((j, k, m, n), w1, w2, ϖ)

    def ρ2_entry(j, k, m, n):
        return ρ2_jeffreys_terms_log((j, k, m, n), w1, w2, ϖ)

    # Order matters to callers: ρ1 first, then ρ2.
    return [multi_tensor(ρ1_entry, copies), multi_tensor(ρ2_entry, copies)]


def num_digits_both(i: int, j: int) -> int:
    """Count the binary digits that are 1 in both ``i`` and ``j``.

    Uses Python's own integer ``&`` and ``bin`` rather than a numpy scalar
    round trip, which is cheaper per call and is not limited to fixed-width
    integer sizes.

    Args:
        i: First integer.
        j: Second integer.

    Returns:
        The number of ``"1"`` characters in the binary representation of
        ``i & j``.
    """
    return bin(i & j).count("1")


def num_powers(i: int, j: int, copies: int) -> tuple[int, int, int, int]:
    """Classify the bit positions of ``i`` and ``j`` into four counts.

    ``2**copies - 1`` is the all-ones mask for ``copies`` bits; subtracting an
    index from it flips that index's bits (for indices in ``0 .. 2**copies - 1``).

    Args:
        i: First index.
        j: Second index.
        copies: Number of bit positions considered.

    Returns:
        A tuple of counts of positions where the bits of ``(i, j)`` are
        ``(0, 0)``, ``(0, 1)``, ``(1, 0)`` and ``(1, 1)``, in that order.
    """
    all_ones = 2**copies - 1
    i_flipped = all_ones - i
    j_flipped = all_ones - j

    both_clear = num_digits_both(i_flipped, j_flipped)
    only_j = num_digits_both(i_flipped, j)
    only_i = num_digits_both(i, j_flipped)
    both_set = num_digits_both(i, j)
    return (both_clear, only_j, only_i, both_set)


def make_cache(
    copies: int, func: Callable[[int, int, int, int], float]
) -> dict[tuple[int, int, int, int], float]:
    """Tabulate ``func`` over all non-negative 4-tuples summing to at most ``copies``.

    Args:
        copies: Upper bound on ``i + j + k + m``.
        func: Callable evaluated once per tuple as ``func(i, j, k, m)``.

    Returns:
        A dict mapping each tuple ``(i, j, k, m)`` to ``func(i, j, k, m)``,
        inserted in lexicographic order of the tuple.
    """
    table = {}
    limit = copies + 1
    for i in range(limit):
        for j in range(limit - i):
            for k in range(limit - i - j):
                for m in range(limit - i - j - k):
                    table[(i, j, k, m)] = func(i, j, k, m)
    return table


def multi_tensor(func: Callable[[int, int, int, int], float], copies: int):
    """Build a ``2**copies`` by ``2**copies`` array from a function of four counts.

    ``func`` is tabulated once with ``make_cache``; the entry at
    ``(row, col)`` is then the cached value for ``num_powers(row, col, copies)``.

    Args:
        func: Callable taking the four counts returned by ``num_powers``.
        copies: Number of copies; fixes the array dimension ``2**copies``.

    Returns:
        The array of looked-up values, with dtype inferred by ``np.array``.
    """
    table = make_cache(copies, func)
    side = 2**copies

    # Row-major flat list first, then reshaped into the square array.
    flat_entries = [
        table[num_powers(row, col, copies)]
        for row in range(side)
        for col in range(side)
    ]
    return np.array(flat_entries).reshape((side, side))
