import matplotlib.pyplot as plt
import matplotlib
import numpy as np
from strategies.utils.UTMDP import GreyUTMDP, ConfidenceMethod


# Shared GreyUTMDP instance, used only to reach its probability-bound
# computation (`_get_probability_bounds`); `ratio_cp` switches its
# `confidence_method` attribute between calls.
model = GreyUTMDP()


def max_samples(epsilon, delta, bounds=None):
    """Return the smallest sample count whose confidence interval is at most epsilon wide.

    The interval is evaluated for the worst case of uniform sampling, i.e.
    half of the samples (rounded down) being successes. The search assumes
    the interval width does not increase with the number of samples.

    Args:
        epsilon: Maximum admissible width ``upper - lower`` of the interval.
            Must be positive.
        delta: Confidence parameter forwarded unchanged to the bound function.
        bounds: Optional callable ``bounds(num_samples, num_successes, delta)``
            returning ``(lower, upper)``. Defaults to the module-level
            ``model._get_probability_bounds``.

    Returns:
        The smallest integer ``n >= 1`` found for which the width is
        ``<= epsilon``.

    Raises:
        ValueError: If ``epsilon`` is not positive (no sample count could
            ever satisfy it, so the search would never terminate).
    """
    if epsilon <= 0:
        raise ValueError(f"epsilon must be positive, got {epsilon}")
    if bounds is None:
        bounds = model._get_probability_bounds

    def _is_sufficient(num_samples):
        # worst case is uniform sampling: half of the (whole) samples succeed
        lower, upper = bounds(num_samples, num_samples // 2, delta)
        return upper - lower <= epsilon

    # exponential growth to find a sufficient upper end for the search
    hi = 1
    while not _is_sufficient(hi):
        hi *= 2
    # hi // 2 was insufficient (or hi == 1), so the answer lies in [lo, hi]
    lo = hi // 2 + 1
    # integer binary search: only whole sample counts are tested and returned
    while lo < hi:
        mid = (lo + hi) // 2
        if _is_sufficient(mid):
            hi = mid
        else:
            lo = mid + 1
    return hi


def ratio_cp(x, y):
    """Compare the sample counts required under two confidence methods.

    Args:
        x: Interval-width target (``epsilon``) passed to ``max_samples``.
        y: Confidence parameter (``delta``) passed to ``max_samples``.

    Returns:
        ``max_samples(x, y)`` computed with Hoeffding bounds divided by
        ``max_samples(x, y)`` computed with Clopper-Pearson bounds.

    Side effect: reassigns the shared ``model.confidence_method`` and leaves
    it set to ``ConfidenceMethod.CLOPPER_PEARSON``.
    """
    required = []
    for method in (ConfidenceMethod.HOEFFDING, ConfidenceMethod.CLOPPER_PEARSON):
        model.confidence_method = method
        required.append(max_samples(x, y))
    hoeffding_n, clopper_pearson_n = required
    return hoeffding_n / clopper_pearson_n


# Evaluation grid: 100 log-spaced values per axis,
# epsilon in [1e-3, 0.5] and delta in [1e-3, 0.1].
X = np.logspace(np.log10(1e-3), np.log10(0.5), 100)
Y = np.logspace(np.log10(1e-3), np.log10(0.1), 100)
X, Y = np.meshgrid(X, Y)

# Z[idx] is the Hoeffding / Clopper-Pearson sample-count ratio at
# (X[idx], Y[idx]). Iterate over the grid's real 2-D shape: using
# len(X) for rows and len(Y) for columns is only correct for square grids.
Z = np.zeros(X.shape)
for idx in np.ndindex(Z.shape):
    Z[idx] = ratio_cp(X[idx], Y[idx])

fig = plt.figure()
ax = plt.axes(projection="3d")
ax.contour3D(X, Y, Z, 50, cmap="viridis")
# Colorbar spanning the full range of computed ratios.
fig.colorbar(
    matplotlib.cm.ScalarMappable(
        norm=matplotlib.colors.Normalize(vmin=np.amin(Z), vmax=np.amax(Z)),
        cmap="viridis",
    ),
    ax=ax,
    fraction=0.03,
    pad=0.1,
)
ax.set_xlabel("epsilon")
ax.set_ylabel("delta")
ax.set_zlabel("ratio")

# Save the plot to disk (it is not shown interactively) and release the figure.
fig.savefig("ratio_3d.png")
plt.close(fig)

print("Created 3D plot")
