import matplotlib.pyplot as plt
import numpy as np
from strategies.utils.UTMDP import GreyUTMDP, ConfidenceMethod


# Shared settings for the plots below.
# `d` is handed to GreyUTMDP as `delta` and is embedded in every output file name.
d = 1e-2
# Instantiated only to reach its probability-bound computation
# (`_get_probability_bounds` / `total_delta`, both used by max_samples).
model = GreyUTMDP(delta=d)

# Fixed success rate for the epsilon sweep, fixed interval width for the phat sweep.
phat = 0.5
epsilon = 0.01


def max_samples(epsilon, phat, bounds_model=None, tol=0.01):
    """Return (approximately) the smallest sample count whose interval width is <= epsilon.

    The sample count is searched as a continuous quantity: first by doubling
    from 1 until the interval is narrow enough, then by bisection between the
    last insufficient and the first sufficient count.

    Args:
        epsilon: Maximal allowed width ``upper - lower`` of the interval.
            Must be positive; otherwise the doubling phase would never stop.
        phat: Empirical success rate. For ``n`` samples the number of
            successes passed to the bound computation is ``n * phat``.
        bounds_model: Object providing ``_get_probability_bounds(n, successes,
            delta)`` and ``total_delta``. Defaults to the module-level ``model``.
        tol: Absolute precision of the returned sample count.

    Returns:
        A sample count ``n`` whose interval is at most ``epsilon`` wide. If a
        single sample already suffices, ``1`` is returned. Otherwise, assuming
        the width is non-increasing in the sample count, the true threshold
        lies within ``tol`` below the returned value.

    Raises:
        ValueError: If ``epsilon`` is not positive.
    """
    if epsilon <= 0:
        raise ValueError(f"epsilon must be positive, got {epsilon}")
    if bounds_model is None:
        bounds_model = model

    def is_sufficient(num_samples):
        # Successes are taken as num_samples * phat exactly (not rounded), so
        # non-integer counts may be passed to the bound computation.
        num_successes = num_samples * phat
        lower, upper = bounds_model._get_probability_bounds(
            num_samples, num_successes, bounds_model.total_delta
        )
        return upper - lower <= epsilon

    # Exponential growth until we have a sufficient upper bracket.
    num_samples_max = 1
    while not is_sufficient(num_samples_max):
        num_samples_max *= 2
    if num_samples_max == 1:
        # Already sufficient with a single sample; nothing to refine.
        return num_samples_max

    # The previous doubling step was insufficient. Invariant for the bisection:
    # num_samples_min is insufficient, num_samples_max is sufficient. Moving the
    # lower bracket by exactly the midpoint (no "+ 1") keeps the threshold
    # inside the bracket, so the final precision really is `tol`.
    num_samples_min = num_samples_max / 2
    while num_samples_max - num_samples_min > tol:
        num_samples_to_check = (num_samples_min + num_samples_max) / 2
        if is_sufficient(num_samples_to_check):
            num_samples_max = num_samples_to_check
        else:
            num_samples_min = num_samples_to_check
    return num_samples_max


################################
### Plot for varying epsilon ###
################################

# Target interval widths. The x-axis below is logarithmic, so the grid is
# spaced geometrically; a linear grid would leave only one point in [1e-3, 1e-2].
eps_values = np.geomspace(1e-3, 0.9, 100)

# Required sample counts at the fixed module-level phat, per confidence method.
model.confidence_method = ConfidenceMethod.HOEFFDING
y1 = np.array([max_samples(eps, phat) for eps in eps_values])
model.confidence_method = ConfidenceMethod.CLOPPER_PEARSON
y2 = np.array([max_samples(eps, phat) for eps in eps_values])
# Hoeffding sample count divided by Clopper-Pearson sample count.
eps_ratio = y1 / y2

plt.figure(figsize=(5, 4))
plt.xscale("log")
plt.plot(eps_values, eps_ratio, label="Ratio")
plt.xlabel("desired precision ε")
plt.ylabel("ratio of required samples")
plt.legend()

# Save the figure, then release it so it does not stay open in pyplot.
plt.savefig(f"ratio_eps_delta{d}.png")
plt.close()

# Save the plotted data as CSV with columns: epsilon, ratio.
np.savetxt(
    f"ratio_eps_delta{d}.csv",
    np.column_stack((eps_values, eps_ratio)),
    delimiter=",",
    fmt="%.7f",
)

print("Created epsilon vs delta plot")

#############################
### Plot for varying phat ###
#############################

# Empirical success rates from 0 to 1 in steps of 0.001.
phat_values = np.linspace(0, 1, 1001)

# Required sample counts at the fixed module-level epsilon, one array per
# confidence method, evaluated in order: Hoeffding (z1), then Clopper-Pearson (z2).
per_method = []
for method in (ConfidenceMethod.HOEFFDING, ConfidenceMethod.CLOPPER_PEARSON):
    model.confidence_method = method
    per_method.append(np.array([max_samples(epsilon, p) for p in phat_values]))
z1, z2 = per_method
phat_ratio = z1 / z2

# Linear x-axis, logarithmic y-axis for the ratio.
plt.figure(figsize=(5, 4))
plt.xscale("linear")
plt.yscale("log")
plt.plot(phat_values, phat_ratio, label="Ratio")
plt.xlabel("sample success rate p̂")
plt.ylabel("ratio of required samples")
plt.legend()

# Write the image, then clear the current figure.
plt.savefig(f"ratio_phat_delta{d}.png")
plt.clf()

# Same data as CSV, columns: phat, ratio.
np.savetxt(
    f"ratio_phat_delta{d}.csv",
    np.column_stack((phat_values, phat_ratio)),
    delimiter=",",
    fmt="%.7f",
)

print("Created epsilon vs phat plot")

# The zoomed phat plot below is disabled by default. The script previously
# stopped here via exit(), a site-module helper meant for the interactive shell
# that would also terminate any importer. Set to True to produce the plot.
PLOT_ZOOMED_PHAT = False

if PLOT_ZOOMED_PHAT:
    ####################################################
    ### Plot for varying phat (only from 0.1 to 0.9) ###
    ####################################################

    plt.figure(figsize=(5, 4))

    # Empirical success rates from 0.1 to 0.9 in steps of 0.001.
    phat_values = np.linspace(0.1, 0.9, 801)

    # Required sample counts at the fixed module-level epsilon, per method.
    model.confidence_method = ConfidenceMethod.HOEFFDING
    z1 = np.array([max_samples(epsilon, p) for p in phat_values])
    model.confidence_method = ConfidenceMethod.CLOPPER_PEARSON
    z2 = np.array([max_samples(epsilon, p) for p in phat_values])

    plt.plot(phat_values, z1 / z2, label="Ratio")

    plt.xlabel("sample success rate p̂")
    plt.ylabel("ratio of required samples")
    plt.legend()

    # Save the figure, then release it.
    plt.savefig(f"ratio_phat_delta{d}_zoomed.png")
    plt.close()
