import numpy as np


def inequal_constraint(k):
    """Return True when ``k`` contains between zero and two ``"1"`` entries.

    ``k`` is anything exposing ``count`` (the rest of this module passes
    bitstrings such as ``"0110"``).
    """
    ones = k.count("1")
    return ones >= 0 and ones <= 2


def equal_constraint(k):
    """Return True when the weighted sum of the bits of ``k`` is zero.

    Characters are weighted by position, counted from the start of ``k``:
    index 0 -> 0, index 1 -> -1, index 2 -> -1, index 3 -> +2. Only four
    weights exist, so an input longer than four characters raises
    ``IndexError``.
    """
    coeffs = [0, -1, -1, 2]
    weighted_total = 0
    for position, bit in enumerate(k):
        weighted_total += coeffs[position] * int(bit)
    return weighted_total == 0


def binary_to_portfolio_ord(binary):
    """
    Convert a bitstring into the portfolio-selection list format.

    The portfolio selection is stored as a list of ints with the LSB first,
    so the characters of ``binary`` are read back-to-front and each one is
    converted to an int.
    """
    return [int(bit) for bit in reversed(binary)]


def compute_cost(x, instance):
    """
    Evaluate the cost of the selection bitstring ``x`` for ``instance``.

    With ``s`` the column vector obtained from ``x`` (reordered by
    ``binary_to_portfolio_ord``), the value returned is

        -mu^T s + risk_factor * s^T sigma s

    where ``mu``, ``sigma`` and ``risk_factor`` are read from ``instance``
    by key. (Presumably ``mu`` holds per-asset returns and ``sigma`` their
    covariance matrix -- confirm against the code that builds ``instance``.)

    Returns the scalar entry of the resulting 1x1 matrix.
    """
    selection = np.array([binary_to_portfolio_ord(x)]).T
    mu_column = np.expand_dims(instance["mu"], axis=1)

    linear_term = -mu_column.T @ selection
    # ``*`` and ``@`` share precedence and associate left-to-right, so the
    # row vector is scaled first and then multiplied through sigma.
    scaled_row = instance["risk_factor"] * selection.T
    quadratic_term = scaled_row @ instance["sigma"] @ selection

    total = linear_term + quadratic_term
    return total[0][0]


def compute_projected_ar(counts, instance, constraint_fn):
    """
    Normalise the shot-averaged cost against the feasible cost range.

    Parameters
    ----------
    counts : mapping of bitstring -> number of shots
        Every key is scored with ``compute_cost`` and weighted by its share
        of the total shot count.
    instance : mapping
        Problem data; ``instance["mu"]`` determines the number of assets and
        the whole mapping is forwarded to ``compute_cost``.
    constraint_fn : callable(bitstring) -> bool
        Selects which of the ``2**num_assets`` bitstrings are feasible.

    Returns
    -------
    ``(expected_cost - best) / (worst - best)`` where ``best``/``worst`` are
    the minimum/maximum cost over all feasible bitstrings. A value of 0
    means the expected cost equals the best feasible cost.

    Raises
    ------
    ValueError
        If no bitstring satisfies ``constraint_fn`` (``min`` of an empty
        list).
    """
    # ``instance`` is accessed by key everywhere else in this module
    # (see ``compute_cost``), so do the same here.
    num_assets = len(instance["mu"])
    num_shots = sum(counts.values())

    exp_cost = sum(
        compute_cost(k, instance) * v / num_shots for k, v in counts.items()
    )

    # Enumerate every bitstring of length ``num_assets`` and keep the cost
    # of those that satisfy the constraint.
    all_bitstrings = (
        bin(t)[2:].zfill(num_assets) for t in range(2**num_assets)
    )
    in_constraint_costs = [
        compute_cost(k, instance) for k in all_bitstrings if constraint_fn(k)
    ]

    best = min(in_constraint_costs)
    worst = max(in_constraint_costs)
    return (exp_cost - best) / (worst - best)


def plot_probability_curve(ax, data, constraint_fn, label="", ls="-"):
    """
    Draw the in-constraint probability of each data point with error bars.

    For every ``point`` in ``data`` the x value is ``point["n_x"]`` and the
    y value is the fraction of shots in ``point["raw_counts"]`` whose
    selection register satisfies ``constraint_fn``. The error bar is the
    standard error of the corresponding 0/1 outcome sample
    (``np.std`` of the outcomes divided by ``sqrt(num_shots)``).

    The curve is drawn on ``ax`` via ``ax.errorbar`` using ``label`` and
    the line style ``ls``.
    """
    xs = []
    fractions = []
    errors = []

    for point in data:
        xs.append(point["n_x"])

        selection_counts = process_raw_counts(point["raw_counts"])
        total_shots = sum(point["raw_counts"].values())
        hits = compute_counts_inconstraint(selection_counts, constraint_fn)
        misses = total_shots - hits

        fractions.append(hits / total_shots)

        # One entry per shot: 1 if it satisfied the constraint, else 0.
        outcomes = [1] * hits + [0] * misses
        errors.append(np.std(outcomes) / np.sqrt(total_shots))

    ax.errorbar(
        xs,
        fractions,
        yerr=np.array(errors).T,
        label=label,
        linestyle=ls,
        lw=1,
        capsize=4,
    )


def plot_probability_bars(
    axs,
    point,
    instance,
    constraint_fn,
    color,
    label="",
    offset_scale=1,
    offset_shift=0,
    vlines=False,
):
    """
    Draw one bar per selection bitstring showing its measured probability.

    Bitstrings that violate ``constraint_fn`` are drawn first, followed by
    those that satisfy it; within each group bars are ordered by decreasing
    ``compute_cost``. Bar heights are ``count / num_shots`` taken from
    ``point["raw_counts"]``.

    Parameters
    ----------
    axs : axes-like object providing ``bar``, ``vlines``, ``set_xticks``,
        ``set_ylim`` and ``set_yticks``.
    point : mapping with a ``"raw_counts"`` entry (bitstring -> shots).
    instance : problem data forwarded to ``compute_cost``.
    constraint_fn : callable(bitstring) -> bool.
    color : bar colour.
    label : legend label, attached only to the first in-constraint bar.
    offset_scale, offset_shift : bar ``t`` is placed at
        ``offset_scale * t + offset_shift`` (in-constraint bars get an
        additional ``+ 1``).
    vlines : when true, draw a dashed vertical line at the boundary between
        the two groups.
    """
    num_shots = sum(point["raw_counts"].values())

    processed_counts = process_raw_counts(point["raw_counts"])

    out_constraint_pts = sorted(
        [
            (val, compute_cost(bit, instance))
            for bit, val in processed_counts.items()
            if not constraint_fn(bit)
        ],
        key=lambda k: -k[1],
    )
    for t, (count, _) in enumerate(out_constraint_pts):
        axs.bar(offset_scale * t + offset_shift, count / num_shots, color=color)

    in_constraint_pts = sorted(
        [
            (val, compute_cost(bit, instance))
            for bit, val in processed_counts.items()
            if constraint_fn(bit)
        ],
        key=lambda k: -k[1],
    )

    # Index of the first in-constraint bar. Derived from the list length
    # rather than the previous loop variable so it is defined (as 0) even
    # when every bitstring satisfies the constraint.
    boundary = len(out_constraint_pts)

    if vlines:
        axs.vlines(
            offset_scale * boundary + offset_shift,
            0,
            0.5,
            color="black",
            lw=1,
            ls="--",
        )

    added_label = False
    for t, (count, _) in enumerate(in_constraint_pts, boundary):
        axs.bar(
            offset_scale * t + offset_shift + 1,
            count / num_shots,
            color=color,
            # Label a single bar so the legend gets exactly one entry.
            label=label if not added_label else "",
        )
        added_label = True

    axs.set_xticks([])
    axs.set_ylim(0, 0.4)
    axs.set_yticks([0, 0.2, 0.4])


def process_raw_counts(counts):
    """
    Aggregate raw measurement counts by their first register.

    Keys of ``counts`` are split on ``" "``; the part before the first
    space is the selection register. The width of that register is taken
    from the first key. The result maps *every* bitstring of that width to
    the summed count of the raw keys sharing that selection (0 when a
    bitstring was never observed).
    """
    first_key = list(counts)[0]
    width = len(first_key.split(" ")[0])

    # Start from a zero count for every possible selection bitstring.
    totals = {}
    for value in range(2**width):
        totals[f"{value:b}".zfill(width)] = 0

    for raw_key, shots in counts.items():
        # first register contains binary variables corresponding to assets
        register = raw_key.split(" ")[0]
        totals[register] += shots

    return totals


def map_counts_from_tket(counts, num_assets):
    """
    Re-key ``counts`` into ``"<selection> <rest>"`` strings.

    Each key of ``counts`` is a sequence of bit values. It is reversed and
    joined into a string; the last ``num_assets`` characters of that string
    then become the first register and the remaining leading characters the
    second register, separated by a single space. Values are carried over
    unchanged.
    """
    # Stage 1: reversed sequence -> flat string.
    flat_counts = {}
    for bits, shots in counts.items():
        flat_key = "".join(str(bit) for bit in reversed(bits))
        flat_counts[flat_key] = shots

    # Stage 2: move the trailing ``num_assets`` characters to the front.
    remapped = {}
    for flat_key, shots in flat_counts.items():
        selection = flat_key[-num_assets:]
        remainder = flat_key[:-num_assets]
        remapped[selection + " " + remainder] = shots
    return remapped


def compute_counts_inconstraint(processed_counts, constraint_fn):
    """Return the total count over all keys for which ``constraint_fn`` is true."""
    total = 0
    for bitstring, shots in processed_counts.items():
        if constraint_fn(bitstring):
            total += shots
    return total
