import numpy as np
import multiprocessing
import queue
import time
from collections.abc import Callable
from tqdm import tqdm

from multipar_bayes import pgm_fun, spm_fun, rpm_fun, nh_fun
from multipar_bayes.measurements import finite_local_measurement_bayesian_update
from mi_bmqe import bound_from_gain, as_incompat
from mi_bmqe.phase_dephasing import exact_state0, exact_first_moments_log

# Six-outcome single-qubit tomography POVM: the eigenprojectors of Z, X and Y,
# each weighted by 1/3, so the six elements sum to the 2x2 identity.
# Order: |0>, |1>, |+>, |->, then the two Y eigenstates, using the convention
# |±i> = (|0> ± i|1>)/sqrt(2).
SINGLE_QUBIT_TOM_POVMS = [
    np.array([[1, 0], [0, 0]]) / 3,  # |0><0| / 3
    np.array([[0, 0], [0, 1]]) / 3,  # |1><1| / 3
    np.array([[1, 1], [1, 1]]) / 6,  # |+><+| / 3
    np.array([[1, -1], [-1, 1]]) / 6,  # |-><-| / 3
    np.array([[1, 1j], [-1j, 1]]) / 6,  # |-i><-i| / 3
    np.array([[1, -1j], [1j, 1]]) / 6,  # |+i><+i| / 3
]


def λ_p(w1: float, w2: float) -> np.typing.NDArray:
    """Return the diagonal 2x2 matrix ``diag(w1**2 / 12, ln(w2)**2 / 12)``.

    Each diagonal entry is the variance of a uniform distribution whose
    support has width ``w1`` and ``ln(w2)`` respectively (a width-``a``
    uniform distribution has variance ``a**2 / 12``).
    NOTE(review): presumably this is the prior covariance of the two
    estimated parameters -- confirm against the derivation.
    """
    widths = (w1, np.log(w2))
    return np.diag([width**2 / 12 for width in widths])


def monitor_tqdm(queues, totals):
    """Show one tqdm progress bar per named queue until every total is reached.

    Producers put integer increments on their queue (e.g. ``1`` per finished
    item). All queues are polled roughly every 0.1 s. The function returns
    once, for every name in ``queues``, the accumulated increments reach
    ``totals[name]``.

    Parameters
    ----------
    queues : dict
        Name -> queue supporting ``get_nowait()`` that raises ``queue.Empty``
        when nothing is available.
    totals : dict
        Name -> expected total. Must contain every key of ``queues``; its
        insertion order fixes the bar positions.

    Notes
    -----
    Blocks indefinitely if a producer stops before reporting its full total.
    """
    bars = {
        name: tqdm(total=total, desc=name, position=i)
        for i, (name, total) in enumerate(totals.items())
    }
    done_counts = {name: 0 for name in queues}

    try:
        # Compare counters by name. Zipping the values of two dicts pairs
        # the wrong counters whenever their key order or key set differs.
        while any(done_counts[name] < totals[name] for name in done_counts):
            for name, q in queues.items():
                # Drain everything that arrived since the last tick so the
                # bars keep up with fast producers.
                while True:
                    try:
                        update = q.get_nowait()
                    except queue.Empty:
                        break
                    bars[name].update(update)
                    done_counts[name] += update
            time.sleep(0.1)
    finally:
        # Close the bars even if interrupted, so the terminal is left clean.
        for bar in bars.values():
            bar.close()


def evaluate_bound_worker(
    q_progress,
    q_result,
    f: Callable,
    ρ0s: list[np.typing.ArrayLike],
    ρ1ses: list[list[np.typing.ArrayLike]],
    *args,
) -> None:
    """Evaluate ``f`` over zipped argument lists and report via queues.

    For each position ``k``, calls ``f(ρ0s[k], ρ1ses[k], *(a[k] for a in args))``
    and keeps element ``[0]`` of its return value. Iteration stops at the
    shortest input list.

    After every evaluation:
      - ``1`` is put on ``q_progress``;
      - the results so far are checkpointed with ``np.save`` to
        ``tmp/bound_<f.__name__>.npy`` (``tmp/`` is created if missing).

    When done, the full list of results is put on ``q_result``. Nothing is
    returned. If ``f`` raises, nothing is put on ``q_result``.
    """
    import os

    checkpoint_dir = "tmp"
    # np.save raises FileNotFoundError when the directory is missing. In a
    # child process that would end the worker before it posts to q_result,
    # leaving any reader of q_result waiting forever.
    os.makedirs(checkpoint_dir, exist_ok=True)
    checkpoint_path = os.path.join(checkpoint_dir, f"bound_{f.__name__}")

    result = []
    for ρ0, ρ1s, *extra_args in zip(ρ0s, ρ1ses, *args):
        result.append(f(ρ0, ρ1s, *extra_args)[0])
        q_progress.put(1)
        # Re-saved each iteration so partial results survive a crash.
        np.save(checkpoint_path, result)
    q_result.put(result)


if __name__ == "__main__":
    # Evaluate the SPM/RPM/NHB/TOM/PGM quantities for copies = 1..max_copies,
    # one worker process per bound type, then save them to figure2.npz.
    w1 = np.pi / 2
    w2 = 5
    max_copies = 9

    copies = np.arange(1, max_copies + 1)

    state0s = [exact_state0(w1, w2, copies=c) for c in tqdm(copies)]
    first_momentses = [exact_first_moments_log(w1, w2, copies=c) for c in tqdm(copies)]
    # λ_p does not depend on the number of copies, so compute its trace once.
    prior_loss = np.trace(λ_p(w1, w2))
    lps = [prior_loss for _ in copies]
    prior_losses = lps

    tom_args = [(SINGLE_QUBIT_TOM_POVMS, c) for c in copies]
    # Bound name -> (function run by the worker, extra per-copy argument lists).
    # Insertion order fixes both start order and progress-bar position.
    bound_specs = {
        "SPM": (spm_fun, ()),
        "RPM": (rpm_fun, ()),
        "NHB": (nh_fun, ()),
        # NOTE(review): each TOM call receives the tuple (POVMs, c) as ONE
        # positional argument -- confirm this matches the signature of
        # finite_local_measurement_bayesian_update.
        "TOM": (finite_local_measurement_bayesian_update, (tom_args,)),
        # PGM also receives the prior loss; its result is used directly as the
        # bound below (no bound_from_gain).
        "PGM": (pgm_fun, (lps,)),
    }
    bound_types = list(bound_specs)

    progress_queues = {b: multiprocessing.Queue() for b in bound_types}
    result_queues = {b: multiprocessing.Queue() for b in bound_types}
    processes = {
        b: multiprocessing.Process(
            target=evaluate_bound_worker,
            args=(
                progress_queues[b],
                result_queues[b],
                fun,
                state0s,
                first_momentses,
                *extra,
            ),
        )
        for b, (fun, extra) in bound_specs.items()
    }

    for b in bound_types:
        processes[b].start()
    print("Started processes")

    monitor_tqdm(progress_queues, {b: len(copies) for b in bound_types})

    # Collect results BEFORE joining. A child that has put data on a queue does
    # not terminate until that data is consumed, so joining first can deadlock.
    res = {b: res_q.get() for b, res_q in result_queues.items()}
    for b, proc in processes.items():
        proc.join()
        if proc.exitcode != 0:
            raise RuntimeError(f"{b} worker exited with code {proc.exitcode}")

    spm_bound = bound_from_gain(res["SPM"], lps)
    rpm_bound = bound_from_gain(res["RPM"], lps)
    nhb_bound = bound_from_gain(res["NHB"], lps)
    tom_bound = bound_from_gain(res["TOM"], lps)
    pgm_bound = res["PGM"]

    # Each bound compared against the SPM bound via as_incompat.
    rpm_incompat = as_incompat(rpm_bound, spm_bound)
    nhb_incompat = as_incompat(nhb_bound, spm_bound)
    pgm_incompat = as_incompat(pgm_bound, spm_bound)
    tom_incompat = as_incompat(tom_bound, spm_bound)

    np.savez(
        "figure2.npz",
        copies=copies,
        prior_losses=prior_losses,
        spm_bound=spm_bound,
        rpm_bound=rpm_bound,
        nhb_bound=nhb_bound,
        pgm_bound=pgm_bound,
        tom_bound=tom_bound,
        rpm_incompat=rpm_incompat,
        nhb_incompat=nhb_incompat,
        pgm_incompat=pgm_incompat,
        tom_incompat=tom_incompat,
    )
