"""Interval MDP representation for PAC-Learning"""

import math
import random
import numpy as np
import itertools

from scipy.special import betaincinv, ndtri, beta, erfcinv
from scipy.optimize import minimize, Bounds, LinearConstraint
from collections import defaultdict
from enum import Enum
from strategies.Evaluation import Evaluation
import time
import warnings

# Escalate every RuntimeWarning (e.g. numerical over/underflow reported by numpy/scipy) to an exception.
# GreyUTMDP._get_derivative relies on this: it wraps the closed-form Clopper-Pearson derivative in
# `try ... except RuntimeWarning` and falls back to a factor-by-factor product when the direct formula misbehaves.
# NOTE: this modifies the process-wide warnings filter as a side effect of importing this module.
warnings.filterwarnings("error", category=RuntimeWarning)


class DeltaDistribution(Enum):
    """How the model-wide error tolerance ``delta`` is split over the individual transitions."""

    # uniform schemes
    UNIFORM = 1
    UNIFORM_PRODUCT = 2
    # optimisation-based schemes; the names suggest minimising the summed interval width or the width of the
    # value interval, with the *_WITH_JACOBIAN variants supplying an analytic gradient (confirm in calculate_delta)
    MINIMIZE_INTERVAL_SUM_WITH_JACOBIAN = 3
    MINIMIZE_INTERVAL_SUM = 4
    MINIMIZE_INTERVAL_SUM_HEURISTIC = 5
    MINIMIZE_VALUE_INTERVAL = 6
    MINIMIZE_VALUE_INTERVAL_WITH_JACOBIAN = 7


class ConfidenceMethod(Enum):
    """Technique used to turn sample counts into a two-sided interval for a transition probability."""

    HOEFFDING = 1  # Hoeffding inequality
    CLOPPER_PEARSON = 2  # exact binomial interval via the inverse regularized incomplete beta function
    WILSON_CORRECTED = 3  # Wilson score interval with continuity correction
    # TODO: one-sided Hoeffding variants (HOEFFDING_UPPER / HOEFFDING_LOWER) were sketched but never implemented


class GreyUTMDP:
    VALUE_ITERATION_CUTOFF = 0.0001

    def __init__(
        self,
        delta=0.1,
        pmin=10e-5,
        confidence_method=ConfidenceMethod.HOEFFDING,
        delta_distribution=DeltaDistribution.UNIFORM,
        small_support_improvement=True,
        min_property=False,
    ):
        """
        Create an empty interval MDP; states and transitions are added afterwards (e.g. by build_model).

        @param delta: error tolerance for the entire model
        @param pmin: minimum transition probability; used as the lower clamp of every probability interval.
        NOTE(review): the default 10e-5 equals 1e-4 (not 1e-5) - confirm this is the intended value
        @param confidence_method: method used to compute probability bounds (Hoeffding by default)
        @param delta_distribution: how to distribute the error tolerance over all transitions
        @param small_support_improvement: spend no delta on deterministic transitions and use the fact that for
        binary transitions the second transition is fully determined by the first one
        @param min_property: whether value iteration should look for a minimizing strategy
        """

        def lazy_table(leaf_type):
            # state -> action -> successor -> leaf_type(); every level is created on first access
            return defaultdict(lambda: defaultdict(lambda: defaultdict(leaf_type)))

        # structure of the MDP
        self.states = set()
        self.initial_states = []
        self.sink_states = set()
        self.removed_unreachable_states = set()
        self.actions = {}
        self.rewards = {}
        self.transitions = {}
        self.inverse_transitions = lazy_table(float)
        # bookkeeping of structural reductions (merged actions, contracted chains, merged states)
        self.action_mapping = {}
        self.chain_mapping = {}
        self.essential_mapping = {}
        # observations and value estimates
        self.samples = lazy_table(int)
        self.values = {}
        self.qualities = {}
        # error budget and interval transition matrices
        self.total_delta = delta
        self.transition_deltas = lazy_table(float)
        self.optimistic_transition_matrix = lazy_table(float)
        self.pessimistic_transition_matrix = lazy_table(float)
        # evaluation parameters
        self.confidence_method = confidence_method
        self.delta_distribution = delta_distribution
        self.small_support_improvement = small_support_improvement
        self.min_property = min_property
        self.pmin = pmin
        # presumably required sample counts for the desired accuracy; populated outside this constructor
        self.required_samples = defaultdict(dict)
        # evaluation helpers; created last because they receive the fully initialised model
        self.evaluation = Evaluation(self)

    def __str__(self):
        out = ""
        for state, actions in self.transitions.items():
            out += str(state) + " (reward " + str(self.rewards[state]) + ")\n"
            for action, transitions in actions.items():
                out += "\t" + str(action) + "\n"
                for t in transitions:
                    out += (
                        "\t\t"
                        + str(t.successor)
                        + ": ["
                        + str(t.lower)
                        + ", "
                        + str(t.upper)
                        + "] (sampled "
                        + str(self.samples[state][action][t.successor])
                        + " times)"
                        + ", true probability: "
                        + str(t.true)
                        + "\n"
                    )
        return out

    ##########################################
    # BUILDING THE MODEL FROM AN ENVIRONMENT #
    ##########################################

    def build_model(self, environment):
        """build an UTMDP model of the environment by exploring all states"""
        next_states = environment.get_initial_states()
        self.initial_states = environment.get_initial_states()
        while next_states:
            next_states = self.explore(environment, next_states)
        self.calculate_delta()

    def explore(self, environment, states):
        """
        Expand one BFS layer: add every state of ``states`` to the model and collect unseen successors.

        For each non-final state, one UncertainTransition is created per (action, successor) pair, with
        ``lower`` = pmin (or 1 if the action has a single successor), ``mean`` = 1 / number of successors
        and ``true`` = the probability reported by the environment.

        @return: set of successors that were not yet in ``self.states`` when encountered
        """
        frontier = set()
        for state in states:
            available_actions = environment.get_actions(state)
            reward = environment.get_reward(state)
            is_final = environment.is_final_state(state)
            self.add_state(state, available_actions, reward, is_final)
            if is_final:
                continue
            for act in available_actions:
                outgoing = set()
                self.transitions[state][act] = outgoing
                successor_probabilities = environment.get_successors_with_probabilities(
                    state, act
                )
                support_size = len(successor_probabilities)
                for successor, probability in successor_probabilities.items():
                    transition = UncertainTransition(successor)
                    transition.lower = self.pmin if support_size > 1 else 1
                    transition.mean = 1 / support_size
                    transition.true = probability
                    outgoing.add(transition)
                    if successor not in self.states:
                        frontier.add(successor)
        return frontier

    def add_state(self, state, actions, reward, is_goal):
        """
        Register ``state`` with fresh bookkeeping.

        @param state: the state to add
        @param actions: actions the environment offers in ``state``
        @param reward: reward of ``state``
        @param is_goal: whether ``state`` is final. Final states become sink states and get an empty action
        list; note that qualities, samples and transitions still get an (empty) entry for every action in
        ``actions`` even then.
        """
        self.states.add(state)
        self.actions[state] = [] if is_goal else actions
        self.values[state] = UncertainValue()
        self.qualities[state] = {}
        state_samples = defaultdict(lambda: defaultdict(int))
        self.samples[state] = state_samples
        for act in actions:
            self.qualities[state][act] = UncertainValue()
            state_samples[act] = defaultdict(int)
        self.rewards[state] = reward
        self.transitions[state] = {act: set() for act in actions}
        if is_goal:
            self.sink_states.add(state)

    ###################################
    # TRANSFORMING THE (INTERVAL) MDP #
    ###################################

    def reduce(self, state_actions_to_be_removed):
        for state, act in state_actions_to_be_removed:
            self.actions[state].remove(act)
            del self.qualities[state][act]
            del self.transitions[state][act]
        self.remove_unreachable_states()

    def remove_unreachable_states(self):
        reachable_states = set(self.initial_states)
        new_reachable_states = set(self.initial_states)
        while new_reachable_states:
            next_reachable_states = set()
            for state in new_reachable_states:
                for act in self.actions[state]:
                    successors = {t.successor for t in self.transitions[state][act]}
                    for successor in successors:
                        if successor not in reachable_states:
                            next_reachable_states.add(successor)
            reachable_states.update(next_reachable_states)
            new_reachable_states = next_reachable_states

        unreachable_states = set(self.states) - reachable_states
        for state in unreachable_states:
            self.remove_state(state)
        self.removed_unreachable_states.update(unreachable_states)

    def remove_state(self, state):
        self.states.remove(state)
        if state in self.sink_states:
            self.sink_states.remove(state)
        del self.actions[state]
        del self.values[state]
        del self.qualities[state]
        del self.transitions[state]
        del self.rewards[state]
        del self.samples[state]

    def handle_self_loops(self):
        """
        Remove self-loops from the model.

        First all actions of sink states are removed via reduce(), which also prunes states that are no longer
        reachable from the initial states. Then, for every non-sink state:
        - a deterministic self-loop (the action's only transition leads back to the state) removes the whole
          action for a max property; for a min property the state instead becomes a sink with value 0, the
          self-loop action is kept and all other actions of the state are removed;
        - a probabilistic self-loop is dropped and the true probabilities of that action are rescaled by
          1 / (1 - p_loop) so that the remaining transitions sum to one again.
        """
        sink_actions = set()
        for state in self.sink_states:
            for action in self.actions[state]:
                sink_actions.add((state, action))
        self.reduce(sink_actions)
        for state, actions in self.transitions.items():
            if state in self.sink_states:
                continue
            actions_to_remove = set()
            for action, transitions in actions.items():
                transitions_to_remove = set()
                for transition in transitions:
                    if transition.successor == state:
                        transitions_to_remove.add(transition)
                        if len(transitions) == 1:
                            # deterministic self loop
                            if self.min_property:
                                # if Pmin property we cannot remove deterministic self loop.
                                # instead, taking the self loop is a trivial minimizing strategy, so we can remove
                                # all other actions and add the state to the sink states
                                transitions_to_remove.remove(transition)
                                self.sink_states.add(state)
                                self.values[state].lower = 0
                                self.values[state].upper = 0
                                self.qualities[state][action].lower = 0
                                self.qualities[state][action].upper = 0
                                actions_to_remove.update(actions.keys())
                                actions_to_remove.remove(action)
                            else:
                                # if Pmax property, we can remove the loop
                                # NOTE(review): if this was the state's only action, the state is left without
                                # actions while not being a sink state - confirm later steps handle that.
                                actions_to_remove.add(action)
                        else:
                            # if not deterministic, remove loop and normalize probability vector
                            # (the self-loop transition itself is rescaled too, but it is removed below)
                            # NOTE(review): divides by zero if the self-loop's true probability is 1 while the
                            # action still has other transitions.
                            normalization_factor = 1 / (1 - transition.true)
                            for t in transitions:
                                t.true *= normalization_factor
                            # NOTE(review): last statement of the loop body, so this continue has no effect
                            continue
                # removal is deferred so the set is not modified while it is being iterated
                for t in transitions_to_remove:
                    transitions.remove(t)
            for action in actions_to_remove:
                self.actions[state].remove(action)
                del self.qualities[state][action]
                del self.transitions[state][action]

    def compute_inverse_transitions(self):
        self.inverse_transitions = defaultdict(
            lambda: defaultdict(lambda: defaultdict(float))
        )
        for state, actions in self.transitions.items():
            for action, transitions in actions.items():
                for transition in transitions:
                    self.inverse_transitions[transition.successor][state][action] = (
                        transition.true
                    )

    def collapse_prob01(self, use_property=True):
        self.handle_self_loops()
        self.compute_inverse_transitions()
        prob1_states = set()
        prob0_states = set()
        new_prob1_states = {g for g in self.sink_states if self.rewards[g] == 1}
        new_prob0_states = {g for g in self.sink_states if self.rewards[g] == 0}
        prob1_representative = next(iter(new_prob1_states))
        prob0_representative = next(iter(new_prob0_states))
        while new_prob1_states:
            candidates = (
                set(s for g in new_prob1_states for s in self.inverse_transitions[g])
                - prob1_states
            )
            prob1_states.update(new_prob1_states)
            next_prob1_states = set()
            for s in candidates:
                if s in self.sink_states:
                    continue
                # if max objective, one action where all successors are prob1 suffices, otherwise all actions
                # must always lead to prob1 state
                if use_property and not self.min_property:
                    if any(
                        all(
                            t.successor in prob1_states
                            or t.successor in new_prob1_states
                            for t in self.transitions[s][a]
                        )
                        for a in self.transitions[s]
                    ):
                        next_prob1_states.add(s)
                else:
                    if all(
                        t.successor in prob1_states or t.successor in new_prob1_states
                        for a in self.transitions[s]
                        for t in self.transitions[s][a]
                    ):
                        next_prob1_states.add(s)
            new_prob1_states = next_prob1_states

        while new_prob0_states:
            candidates = (
                set(s for g in new_prob0_states for s in self.inverse_transitions[g])
                - prob0_states
            )
            prob0_states.update(new_prob0_states)
            next_prob0_states = set()
            new_prob0_states = set()
            for s in candidates:
                if s in self.sink_states:
                    continue
                if use_property and self.min_property:
                    if any(
                        all(
                            t.successor in prob0_states
                            or t.successor in new_prob0_states
                            for t in self.transitions[s][a]
                        )
                        for a in self.transitions[s]
                    ):
                        next_prob0_states.add(s)
                else:
                    if all(
                        t.successor in prob0_states or t.successor in new_prob0_states
                        for a in self.transitions[s]
                        for t in self.transitions[s][a]
                    ):
                        next_prob0_states.add(s)
            new_prob0_states = next_prob0_states

        # merge states
        for s, acts in self.transitions.items():
            for a, ts in acts.items():
                for t in ts:
                    if t.successor in prob0_states:
                        t.successor = prob0_representative
                    elif t.successor in prob1_states:
                        t.successor = prob1_representative
        for s in prob0_states:
            if s == prob0_representative:
                continue
            self.remove_state(s)
            self.essential_mapping[s] = prob0_representative
        for s in prob1_states:
            if s == prob1_representative:
                continue
            self.remove_state(s)
            self.essential_mapping[s] = prob1_representative
        # clean up
        self.merge_same_target_transitions()
        self.handle_self_loops()
        self.remove_duplicate_transitions()
        self.compute_inverse_transitions()

    def merge_essential_states(self):
        """
        Merge states that are passed through deterministically into their successor ("essential" states).

        NOTE(review): `done` is initialised to True below, so the `while not done` loop is never entered and
        this method currently only removes self-loops and recomputes the inverse transitions. Before enabling
        the loop, check that:
        (1) `for state, mapping in self.essential_mapping[s]` iterates the single representative stored for
            `s` rather than (state, mapping) pairs - `self.essential_mapping.items()` was probably intended;
        (2) `is_det` only requires every action of the predecessor to have one transition, not that the
            predecessor has a single action, so a merge could drop alternative actions of the predecessor.
        """
        self.handle_self_loops()
        self.compute_inverse_transitions()
        done = True
        # find 1-step essential states, i.e. single-action deterministic transitions
        # note: in general this is not sufficient to find all essential states, but because we removed self-loops
        # this is guaranteed to exist for every state that is essential w.r.t. some other state

        # essential_state_merges[s] = t means s was merged into t.
        # we keep track of this so if we want to merge something into s, we merge into t instead.
        while not done:
            done = True
            essential_state_merges = dict()
            for state in self.states:
                if state in self.sink_states:
                    continue
                for predecessor, actions in self.inverse_transitions[state].items():
                    # only merge into states with at most one action
                    if (
                        predecessor in self.sink_states
                        or len(self.transitions[state].keys()) > 1
                    ):
                        continue
                    is_det = True
                    for act, successors in self.transitions[predecessor].items():
                        if len(successors) > 1:
                            is_det = False
                    if is_det:
                        done = False
                        state_to_merge_into = state
                        if state in essential_state_merges:
                            state_to_merge_into = essential_state_merges[state]
                        essential_state_merges[predecessor] = state_to_merge_into
                        # handle initial states
                        if predecessor in self.initial_states:
                            self.initial_states.remove(predecessor)
                            self.initial_states.append(state_to_merge_into)
                        # NOTE(review): sink predecessors were already skipped above, so this never triggers
                        if predecessor in self.sink_states:
                            self.sink_states.remove(predecessor)
                            self.sink_states.add(state_to_merge_into)
                        # redirect all transitions from prepre to predecessor towards state
                        # for now we might introduce duplicate transitions which we merge later
                        for prepre, actions in self.inverse_transitions[
                            predecessor
                        ].items():
                            for action, prob in actions.items():
                                for t in self.transitions[prepre][action]:
                                    if t.successor == predecessor:
                                        t.successor = state_to_merge_into
            for s, representative in essential_state_merges.items():
                self.essential_mapping[s] = representative
                for state, mapping in self.essential_mapping[s]:
                    if s == mapping:
                        self.essential_mapping[state] = representative
            # clean up model
            for s in essential_state_merges.keys():
                self.remove_state(s)
            self.remove_duplicate_transitions()
            self.merge_same_target_transitions()
            self.handle_self_loops()
            self.compute_inverse_transitions()

    def contract_chains(self):
        """
        Contract chains of single-action, single-predecessor states into their predecessor.

        A non-sink, non-initial state with exactly one action and exactly one incoming (predecessor, action)
        pair is removed. For every successor `succ` of that action, the predecessor's action gets a transition
        to `succ` with true probability p(pred, a, state) * p(state, action, succ); this is added to an existing
        transition to `succ` if there is one. chain_mapping records, for every removed state, the
        (predecessor, action) it was folded into; earlier entries that pointed at a removed state are re-pointed.
        Passes repeat until no candidate is found.
        """
        self.handle_self_loops()
        self.compute_inverse_transitions()
        done = False
        while not done:
            done = True
            states_to_remove = set()
            for state, actions in self.actions.items():
                if state in self.sink_states or state in self.initial_states:
                    continue
                if len(actions) == 1:
                    # number of (predecessor, action) pairs entering `state`
                    num_predecessors = sum(
                        len(acts) for s, acts in self.inverse_transitions[state].items()
                    )
                    if num_predecessors > 1:
                        continue
                    action = actions[0]
                    # set before the skip checks below, so a skipped candidate triggers another pass
                    done = False
                    # be careful not to remove any states yet that have transitions to states that are waiting to be
                    # removed - we skip those for now and handle them in a later iteration
                    skip = False
                    for t in self.transitions[state][action]:
                        if t.successor in states_to_remove:
                            skip = True
                    # also be careful not to remove any states yet that have transitions FROM states that are waiting
                    # to be removed - we skip those for now and handle them in a later iteration
                    for predecessor in self.inverse_transitions[state]:
                        if predecessor in states_to_remove:
                            skip = True
                    if skip:
                        continue
                    state_action_transitions = {
                        t.successor: t.true for t in self.transitions[state][action]
                    }
                    states_to_remove.add(state)
                    # iterate over all predecessors
                    # (inverse_transitions is only recomputed at the end of the pass; the skip checks above keep
                    # adjacent chain states from being contracted within the same pass)
                    for succ, leaving_prob in state_action_transitions.items():
                        for s, acts in self.inverse_transitions[state].items():
                            for a, entering_prob in acts.items():
                                transitions_to_remove = set()
                                for t in self.transitions[s][a]:
                                    if t.successor == state:
                                        transitions_to_remove.add(t)
                                for t in transitions_to_remove:
                                    self.transitions[s][a].remove(t)
                                joined_transition_probability = (
                                    leaving_prob * entering_prob
                                )
                                for t in self.transitions[s][a]:
                                    if t.successor == succ:
                                        t.true += joined_transition_probability
                                        break
                                else:
                                    # NOTE(review): only `true` is set on the new transition (no lower/mean)
                                    new_transition = UncertainTransition(succ)
                                    new_transition.true = joined_transition_probability
                                    self.transitions[s][a].add(new_transition)
            for state in states_to_remove:
                # NOTE(review): relies on unreachable states having been pruned (handle_self_loops -> reduce ->
                # remove_unreachable_states); a candidate without predecessors would fail this assertion.
                # Assertions are also stripped under `python -O`.
                assert len(self.inverse_transitions[state]) == 1
                predecessor, transitions = next(
                    iter(self.inverse_transitions[state].items())
                )
                assert len(transitions) == 1
                action = next(iter(transitions.keys()))
                self.chain_mapping[state] = predecessor, action
                # states previously contracted into `state` now map to its predecessor instead
                for s, (pred, a) in self.chain_mapping.items():
                    if pred == state:
                        self.chain_mapping[s] = (predecessor, action)
                self.remove_state(state)
            # redirecting may introduce self loops again - remove them
            self.remove_duplicate_transitions()
            self.compute_inverse_transitions()
            self.handle_self_loops()

    def structural_improvements(self):
        # merging essential states may create new chains [i think, not 100% sure] but definitely not the other way
        # around, so we first merge essential states and then chains since afterwards there are def no essential states
        self.collapse_prob01()
        self.merge_essential_states()
        self.contract_chains()

    def remove_duplicate_transitions(self):
        # remove duplicate actions from MDP. Since probabilities are unknown we can only remove deterministic
        # transitions with the same successor
        for state, actions in self.transitions.items():
            actions_to_remove = set()
            for a1, a2 in itertools.combinations(actions.keys(), 2):
                if a1 in actions_to_remove or a2 in actions_to_remove:
                    continue
                successors1 = self.transitions[state][a1]
                successors2 = self.transitions[state][a2]
                if len(successors1) == 1 and len(successors2) == 1:
                    succ1 = next(iter(successors1)).successor
                    succ2 = next(iter(successors2)).successor
                    if succ1 == succ2:
                        self.action_mapping[(state, a2)] = a1
                        actions_to_remove.add(a2)
            for a in actions_to_remove:
                del self.transitions[state][a]
                self.actions[state].remove(a)

    def merge_same_target_transitions(self):
        for state, actions in self.transitions.items():
            for action, transitions in actions.items():
                transitions_to_remove = set()
                for t1, t2 in itertools.combinations(transitions, 2):
                    if t1 in transitions_to_remove or t2 in transitions_to_remove:
                        continue
                    if t1.successor == t2.successor:
                        t1.true += t2.true
                        transitions_to_remove.add(t2)
                for t in transitions_to_remove:
                    transitions.remove(t)

    ############
    # SAMPLING #
    ############

    def get_sample(self, state, action):
        r = random.random()
        probability_sum = 0
        for transition in self.transitions[state][action]:
            probability_sum += transition.true
            if probability_sum > r:
                return transition.successor

    def get_num_prob_samples(self):
        num = 0
        for s, acts in self.transitions.items():
            for act, ts in self.transitions.items():
                if len(ts) > 1:
                    num += sum(n for succ, n in self.samples[s][act].items())
        return num

    #########################
    # BUILDING THE ESTIMATE #
    #########################

    def compute_required_samples(self, epsilon_per_transition):
        self.calculate_delta()
        # cache pre-computed values since all transitions with the same min delta require
        num_samples_per_delta = defaultdict(int)
        total_samples_required = 0
        for state, actions in self.transitions.items():
            if state in self.sink_states:
                continue
            for action, transitions in actions.items():
                if len(transitions) == 1:
                    continue
                # worst-case: the transition with the minimum delta has probability 1/2
                min_delta = max(
                    self.transition_deltas[state][action][t.successor]
                    for t in transitions
                )
                if min_delta in num_samples_per_delta:
                    total_samples_required += num_samples_per_delta[min_delta]
                else:
                    # binary search, initialize via exponential growth
                    num_samples_max = 1
                    sufficient = False
                    while not sufficient:
                        # worst case is uniform sampling, so we check whether uniform (rounded) gives epsilon-interval
                        num_successes = num_samples_max // 2
                        lower, upper = self._get_probability_bounds(
                            num_samples_max, num_successes, min_delta
                        )
                        sufficient = upper - lower <= epsilon_per_transition
                        if not sufficient:
                            num_samples_max *= 2
                    # set min from previous iteration
                    num_samples_min = num_samples_max // 2 + 1
                    # then check midpoint until convergence
                    while num_samples_max > num_samples_min + 1:
                        num_samples_to_check = (num_samples_min + num_samples_max) // 2
                        num_successes = num_samples_to_check // 2
                        lower, upper = self._get_probability_bounds(
                            num_samples_to_check, num_successes, min_delta
                        )
                        sufficient = upper - lower <= epsilon_per_transition
                        if sufficient:
                            num_samples_max = num_samples_to_check
                        else:
                            num_samples_min = num_samples_to_check + 1
                    total_samples_required += num_samples_max
                    num_samples_per_delta[min_delta] = num_samples_max
        return total_samples_required

    def _get_probability_bounds(self, num_samples, num_successes, delta):
        if num_samples == 0 or delta == 0:
            return self.pmin, 1
        if self.confidence_method == ConfidenceMethod.HOEFFDING:
            mean = num_successes / num_samples
            ci = math.sqrt(math.log(delta / 2) / (-2 * num_samples)) if delta > 0 else 1
            upper = min(mean + ci, 1)
            lower = max(mean - ci, self.pmin)
        elif self.confidence_method == ConfidenceMethod.CLOPPER_PEARSON:
            num_failures = num_samples - num_successes
            p_lower = (
                betaincinv(num_successes, num_failures + 1, delta / 2)
                if num_successes > 0
                else 0
            )
            p_upper = (
                betaincinv(num_successes + 1, num_failures, 1 - delta / 2)
                if num_failures > 0
                else 1
            )
            lower = max(self.pmin, p_lower)
            upper = min(p_upper, 1)
        elif self.confidence_method == ConfidenceMethod.WILSON_CORRECTED:
            z = ndtri(1 - delta / 2)
            p = num_successes / num_samples
            n = num_samples
            p_lower = (
                (
                    2 * n * p
                    + z**2
                    - (
                        z * math.sqrt(z**2 - 1 / n + 4 * n * p * (1 - p) + (4 * p - 2))
                        + 1
                    )
                )
                / (2 * (n + z**2))
                if z**2 - 1 / n + 4 * n * p * (1 - p) + (4 * p - 2) >= 0
                and z < float("inf")
                else 0
            )
            p_upper = (
                (
                    2 * n * p
                    + z**2
                    + (
                        z * math.sqrt(z**2 - 1 / n + 4 * n * p * (1 - p) - (4 * p - 2))
                        + 1
                    )
                )
                / (2 * (n + z**2))
                if z**2 - 1 / n + 4 * n * p * (1 - p) - (4 * p - 2) >= 0
                and z < float("inf")
                else 1
            )
            lower = max(self.pmin, p_lower)
            upper = min(p_upper, 1)
        else:
            raise NotImplementedError(
                f"You specified an invalid methods for computing probability intervals: "
                f"{self.confidence_method}"
            )
        return lower, upper

    def _get_derivative(self, delta, num_samples, num_successes, separate=False):
        # derivative of the probability interval size w.r.t. delta
        if num_samples == 0 or delta <= 0:
            if separate:
                return 0, 0
            return 0
        if self.confidence_method == ConfidenceMethod.HOEFFDING:
            if separate:
                return (
                    (
                        -1
                        / (2 * delta * math.sqrt(num_samples * math.log(4 / delta**2))),
                        -1
                        / (2 * delta * math.sqrt(num_samples * math.log(4 / delta**2))),
                    )
                    if delta > 0
                    else (0, 0)
                )
            return (
                -1 / (delta * math.sqrt(num_samples * math.log(4 / delta**2)))
                if delta > 0
                else 0
            )
        elif self.confidence_method == ConfidenceMethod.CLOPPER_PEARSON:
            a = num_successes
            b = num_samples - num_successes
            w_lower = betaincinv(a, b + 1, delta / 2) if a > 0 else 0
            w_upper = betaincinv(a + 1, b, 1 - delta / 2) if b > 0 else 1
            # For large a and b we can run into float issues as beta(a, b) is very small.
            # Hence we manually construct a list of all the factors and multiply them one by one,
            # alternating between factors >1 and <1 in such a way that the product stays near 1
            try:
                d_lower = (
                    1 / 2 * beta(a, b + 1) / (1 - w_lower) ** b / w_lower ** (a - 1)
                    if a > 0
                    else 0
                )
            except RuntimeWarning:
                d_lower = 1
                factors = (
                    [1 / i for i in range(1, a + b + 1)]
                    + [i for i in range(1, a)]
                    + [i for i in range(1, b + 1)]
                    + b * [1 / (1 - w_lower)]
                    + (a - 1) * [1 / w_lower]
                )
                factors.sort()
                while factors:
                    if d_lower >= 1:
                        d_lower *= factors.pop(0)
                    else:
                        d_lower *= factors.pop(-1)
                d_lower *= 1 / 2
            try:
                d_upper = (
                    -1 / 2 * beta(a + 1, b) / (1 - w_upper) ** (b - 1) / w_upper**a
                    if b > 0
                    else 0
                )
            except RuntimeWarning:
                d_upper = 1
                factors = (
                    [1 / i for i in range(1, a + b + 1)]
                    + [i for i in range(1, a + 1)]
                    + [i for i in range(1, b)]
                    + (b - 1) * [1 / (1 - w_upper)]
                    + a * [1 / w_upper]
                )
                factors.sort()
                while factors:
                    if d_upper >= 1:
                        d_upper *= factors.pop(0)
                    else:
                        d_upper *= factors.pop(-1)
                d_upper *= -1 / 2
            if separate:
                return d_upper, d_lower
            return d_upper - d_lower
        elif self.confidence_method == ConfidenceMethod.WILSON_CORRECTED:
            if num_samples == 0 or delta == 0:
                if separate:
                    return 0, 0
                return 0
            p = num_successes / num_samples
            n = num_samples
            z = ndtri(1 - delta / 2)
            # these terms have no direct meaning, but appear multiple times in the derivative, so we substitute
            x_lower = (
                math.sqrt(4 * n * (1 - p) * p - 1 / n + 4 * p + z**2 - 2)
                if 4 * n * (1 - p) * p - 1 / n + 4 * p + z**2 - 2 >= 0
                else None
            )
            x_upper = (
                math.sqrt(4 * n * (1 - p) * p - 1 / n - 4 * p + z**2 + 2)
                if 4 * n * (1 - p) * p - 1 / n - 4 * p + z**2 + 2 >= 0
                else None
            )
            # compute the derivative of each bound w.r.t. z
            # see https://www.wolframalpha.com/input?i=d%2Fdz+%282np%2Bz%5E2-%28%28z+sqrt%28z%5E2-1%2Fn+%2B+4np%281-p%29+%2B+%284p-2%29%29%29+%2B1%29%29+%2F+%282%28n%2Bz%5E2%29%29
            # and https://www.wolframalpha.com/input?i=d%2Fdz+%282np%2Bz%5E2%2B%28z+sqrt%28z%5E2-1%2Fn+%2B+4np%281-p%29+-+%284p-2%29%29%29+%2B1%29+%2F+%282%28n%2Bz%5E2%29%29
            d_z_lower = (
                (-(z**2) / x_lower - x_lower + 2 * z) / (2 * (n + z**2))
                - z * (-z * x_lower + 2 * n * p + z**2 - 1) / (n + z**2) ** 2
                if num_successes > 0 and z < float("inf") and x_lower is not None
                else 0
            )
            d_z_upper = (
                (z**2 / x_upper + x_upper + 2 * z) / (2 * (n + z**2))
                - z * (z * x_upper + 2 * n * p + z**2 + 1) / (n + z**2) ** 2
                if num_successes < num_samples
                and z < float("inf")
                and x_upper is not None
                else 0
            )
            # compute derivative of z w.r.t. delta
            d_z_delta = (
                -1 / 2 * math.sqrt(2 * math.pi) * math.e ** (erfcinv(2 * delta) ** 2)
            )
            # chain rule: derivative of bound w.r.t. delta is the product of [derivative of bound w.r.t. z] and
            # [derivative of z w.r.t. delta]
            if separate:
                return d_z_upper * d_z_delta, d_z_lower * d_z_delta
            return d_z_upper * d_z_delta - d_z_lower * d_z_delta
        else:
            raise NotImplementedError(
                f"Minimizing interval sums via gradient descent is not supported for the "
                f"confidence interval method {self.confidence_method.name}"
            )

    def calculate_delta(self):
        """
        Distribute the total error tolerance ``self.total_delta`` over the transitions.

        The per-transition tolerances are written to ``self.transition_deltas[state][action][successor]``.
        ``self.delta_distribution`` selects how the tolerance is split:
        - UNIFORM: every tolerance-consuming transition gets the same share (union bound)
        - UNIFORM_PRODUCT: the largest common share delta_t with prod_{s,a} (1 - |Post(s,a)| * delta_t) >= 1 - delta
        - MINIMIZE_INTERVAL_SUM(_WITH_JACOBIAN): numerically minimize the sum of all probability interval sizes
        - MINIMIZE_INTERVAL_SUM_HEURISTIC: shares proportional to (#samples + 200)^(-0.75), normalized to total_delta
        - MINIMIZE_VALUE_INTERVAL(_WITH_JACOBIAN): numerically minimize the value interval of the initial states

        With ``small_support_improvement``, deterministic actions spend no tolerance and both successors of a binary
        action get twice the share, since the interval of one successor determines the other.
        @raise NotImplementedError: if ``self.delta_distribution`` is not a supported distribution method
        """
        distribution = self.delta_distribution
        if distribution == DeltaDistribution.UNIFORM:
            self._distribute_delta_uniform()
        elif distribution == DeltaDistribution.UNIFORM_PRODUCT:
            self._distribute_delta_uniform_product()
        elif (
            distribution == DeltaDistribution.MINIMIZE_INTERVAL_SUM
            or distribution == DeltaDistribution.MINIMIZE_INTERVAL_SUM_WITH_JACOBIAN
        ):
            self._distribute_delta_minimize_interval_sum()
        elif distribution == DeltaDistribution.MINIMIZE_INTERVAL_SUM_HEURISTIC:
            self._distribute_delta_heuristic()
        elif (
            distribution == DeltaDistribution.MINIMIZE_VALUE_INTERVAL
            or distribution == DeltaDistribution.MINIMIZE_VALUE_INTERVAL_WITH_JACOBIAN
        ):
            self._distribute_delta_minimize_value_interval()
        else:
            raise NotImplementedError(
                "You specified an invalid methods for distributing the error tolerance"
            )

    def _non_sink_actions(self):
        """Yield (state, action, transitions) for every action of every non-sink state."""
        for state, actions in self.transitions.items():
            if state in self.sink_states:
                continue
            for action, transitions in actions.items():
                yield state, action, transitions

    def _delta_scale(self, state, action):
        """Return 2 for binary actions under small_support_improvement (their delta is doubled), 1 otherwise."""
        if self.small_support_improvement and len(self.transitions[state][action]) == 2:
            return 2
        return 1

    def _uniform_delta_share(self):
        """Return total_delta divided by the number of tolerance-consuming transitions (at least 1)."""
        if self.small_support_improvement:
            num_consuming = sum(
                len(transitions)
                for _, _, transitions in self._non_sink_actions()
                if len(transitions) > 1
            )
        else:
            num_consuming = sum(
                len(transitions)
                for actions in self.transitions.values()
                for transitions in actions.values()
            )
        # avoid division by 0 error
        return self.total_delta / max(1, num_consuming)

    def _distribute_delta_uniform(self):
        """Give every tolerance-consuming transition the same share of the total error tolerance."""
        delta_per_transition = self._uniform_delta_share()
        for state, action, transitions in self._non_sink_actions():
            if len(transitions) == 1 and self.small_support_improvement:
                continue
            scale = self._delta_scale(state, action)
            for transition in transitions:
                self.transition_deltas[state][action][transition.successor] = (
                    scale * delta_per_transition
                )

    def _distribute_delta_uniform_product(self):
        """
        Find the largest common tolerance delta_t with prod_{s,a} (1 - |Post(s,a)| * delta_t) >= 1 - total_delta
        by binary search and assign it to all transitions.
        """
        # first count how often each branching factor appears to simplify the product
        branching_counts = defaultdict(int)
        for _, _, transitions in self._non_sink_actions():
            if len(transitions) == 1 and self.small_support_improvement:
                continue
            # A binary action under small_support_improvement uses a single interval with tolerance 2 * delta_t
            # (see the assignment below), so it still contributes the factor (1 - 2 * delta_t).
            branching_counts[len(transitions)] += 1
        # then do a binary search on delta_t, starting from the union-bound share
        delta_min = self._uniform_delta_share()
        delta_max = self.total_delta
        # to avoid float issues, we compute the log of both sides:
        # sum_{s,a} ln(1 - |Post(s,a)| * delta_t) >= ln(1-delta)
        rhs = math.log(1 - self.total_delta)
        while delta_max - delta_min >= 10e-10:
            delta_mid = (delta_max + delta_min) / 2
            # catch cases where delta is too large beforehand
            if any(
                1 - branching_factor * delta_mid <= 0
                for branching_factor in branching_counts
            ):
                delta_max = delta_mid
                continue
            lhs = sum(
                count * math.log(1 - branching_factor * delta_mid)
                for branching_factor, count in branching_counts.items()
            )
            if lhs >= rhs:
                delta_min = delta_mid
            else:
                delta_max = delta_mid
        # set the deltas
        for state, action, transitions in self._non_sink_actions():
            if len(transitions) == 1 and self.small_support_improvement:
                # deterministic transitions spend no error tolerance
                for transition in transitions:
                    self.transition_deltas[state][action][transition.successor] = 1
                continue
            scale = self._delta_scale(state, action)
            for transition in transitions:
                self.transition_deltas[state][action][transition.successor] = (
                    scale * delta_min
                )

    def _probabilistic_transition_list(self):
        """List (state, action, successor) for all transitions of non-sink actions with more than one successor."""
        return [
            (state, action, transition.successor)
            for state, action, transitions in self._non_sink_actions()
            if len(transitions) > 1
            for transition in transitions
        ]

    def _minimize_deltas(self, transition_list, objective, jac):
        """
        Minimize ``objective`` over one delta per entry of ``transition_list`` subject to
        0 <= delta_j <= total_delta and sum_j delta_j <= total_delta, then store the (rescaled) result.
        """
        num_deltas = len(transition_list)
        # initial guess is uniform, half all deltas to avoid float imprecisions
        initial_guess = 1 / 2 * self.total_delta * np.ones(num_deltas) / num_deltas
        # constraint that the sum of all deltas is at most the total delta
        cons = LinearConstraint(
            lb=0,
            ub=self.total_delta,
            A=np.ones(num_deltas),
            keep_feasible=True,
        )
        # all deltas are bounded between 0 and the total delta
        bounds = Bounds(lb=0, ub=self.total_delta, keep_feasible=True)
        # for possible parameters see the documentation of the trust-constr method
        # https://docs.scipy.org/doc/scipy/reference/optimize.minimize-trustconstr.html
        res = minimize(
            objective,
            initial_guess,
            bounds=bounds,
            constraints=cons,
            jac=jac,
            method="trust-constr",
        )
        print(
            f"Finished delta distribution. SciPy minimize did {res.nit} iterations with a total of "
            f"{res.nfev} function evaluations and {res.njev} Jacobian evaluations. Sum of deltas: {sum(res.x)}."
        )
        for (state, action, successor), delta in zip(transition_list, res.x):
            # the optimization variables are the undoubled deltas
            delta *= self._delta_scale(state, action)
            self.transition_deltas[state][action][successor] = max(0, delta)

    def _distribute_delta_minimize_interval_sum(self):
        """Distribute the tolerance such that the sum of all probability interval sizes is minimized."""
        transition_list = self._probabilistic_transition_list()
        if not transition_list:
            # no probabilistic transitions: there is nothing to optimize over
            return
        # sample counts and delta scales do not change during the optimization, so compute them once
        scales = [self._delta_scale(s, a) for s, a, _ in transition_list]
        num_samples = [sum(self.samples[s][a].values()) for s, a, _ in transition_list]
        num_successes = [self.samples[s][a][succ] for s, a, succ in transition_list]

        # define the objective function
        def sum_of_interval_sizes(d):
            total = 0
            for scale, n, k, delta in zip(scales, num_samples, num_successes, d):
                lower, upper = self._get_probability_bounds(n, k, scale * delta)
                total += upper - lower
            return total

        # define the jacobian of the objective function; by the chain rule a scaled delta contributes the factor scale
        def jacobian(d):
            return np.array(
                [
                    scale * self._get_derivative(scale * delta, n, k)
                    for scale, n, k, delta in zip(scales, num_samples, num_successes, d)
                ]
            )

        # set jacobian explicitly if specified by configuration
        jac = (
            jacobian
            if self.delta_distribution
            == DeltaDistribution.MINIMIZE_INTERVAL_SUM_WITH_JACOBIAN
            else None
        )
        self._minimize_deltas(transition_list, sum_of_interval_sizes, jac)

    def _distribute_delta_heuristic(self):
        """Fit c*(x+200)^(-0.75) per action with x samples, where c normalizes the sum of deltas to total_delta."""
        delta_sum = 0
        for state, actions in self.samples.items():
            for action, samples in actions.items():
                if len(samples) == 1:
                    continue
                transition_delta = self._delta_scale(state, action) * (
                    sum(samples.values()) + 200
                ) ** (-0.75)
                for successor in samples:
                    self.transition_deltas[state][action][successor] = transition_delta
                    delta_sum += transition_delta
        # normalize
        if delta_sum == 0:
            return
        normalization_factor = self.total_delta / delta_sum
        for state, actions in self.samples.items():
            for action, samples in actions.items():
                if len(samples) == 1:
                    continue
                for successor in samples:
                    self.transition_deltas[state][action][successor] *= (
                        normalization_factor
                    )

    @staticmethod
    def _bound_derivative(probability, p_lower, p_upper, du, dl):
        """Return du if probability sits on the upper interval bound, dl if on the lower bound, 0 otherwise."""
        if probability == p_upper:
            return du
        if probability == p_lower:
            return dl
        return 0

    def _greedy_strategy(self, bound):
        """Map every state in self.actions to its action with the highest quality.<bound> ("upper"/"lower")."""
        return {
            s: max(self.qualities[s], key=lambda a: getattr(self.qualities[s][a], bound))
            for s in self.actions
        }

    def _propagate_value_derivative(self, state, reward, transition_matrix, strategy):
        """
        Value iteration on the chain induced by ``strategy`` and ``transition_matrix`` in which only ``state``
        collects ``reward`` per step; return the average of the resulting values over the initial states.
        """
        derivatives = {s: 0 for s in self.states}
        derivatives[state] = reward
        max_change = float("inf")
        while max_change > GreyUTMDP.VALUE_ITERATION_CUTOFF:
            max_change = 0
            for s in self.states:
                if s in self.sink_states:
                    continue
                new_value = sum(
                    probability * derivatives[successor]
                    for successor, probability in transition_matrix[s][strategy[s]].items()
                )
                if s == state:
                    new_value += reward
                max_change = max(max_change, abs(derivatives[s] - new_value))
                derivatives[s] = new_value
        return sum(derivatives[s] for s in self.initial_states) / len(self.initial_states)

    def _distribute_delta_minimize_value_interval(self):
        """Distribute the tolerance such that the average value interval size of the initial states is minimized."""
        transition_list = self._probabilistic_transition_list()
        if not transition_list:
            # no probabilistic transitions: there is nothing to optimize over
            return
        scales = [self._delta_scale(s, a) for s, a, _ in transition_list]

        # objective function
        def value_interval_initial_state(d):
            for (state, action, succ), scale, delta in zip(transition_list, scales, d):
                self.transition_deltas[state][action][succ] = scale * delta
            self.calculate_probability_bounds()
            self.evaluation.calculate_bounds(min_property=self.min_property)
            interval_size_sum = sum(
                self.values[s].upper - self.values[s].lower
                for s in self.initial_states
            )
            return interval_size_sum / len(self.initial_states)

        # define the jacobian of the objective function
        # NOTE(review): it reads the current bounds, values, qualities and transition matrices, i.e. presumably the
        # state left behind by the most recent objective evaluation
        def jacobian(d):
            result = []
            for (state, action, succ), scale, delta in zip(transition_list, scales, d):
                num_samples = sum(self.samples[state][action].values())
                num_successes = self.samples[state][action][succ]
                # evaluate at the delta actually used for the interval; the chain rule contributes the factor scale
                du, dl = self._get_derivative(
                    scale * delta, num_samples, num_successes, separate=True
                )
                du *= scale
                dl *= scale
                p_lower = p_upper = None
                for t in self.transitions[state][action]:
                    if t.successor == succ:
                        p_lower, p_upper = t.lower, t.upper
                # Upper Bound: only relevant if the action is taken by the optimistic strategy.
                # dp(s,a,s')/ddelta is du or dl depending on which bound the optimistic instantiation uses.
                upper_derivative = 0
                if self.qualities[state][action].upper >= self.values[state].upper:
                    dp = self._bound_derivative(
                        self.optimistic_transition_matrix[state][action][succ],
                        p_lower,
                        p_upper,
                        du,
                        dl,
                    )
                    upper_derivative = self._propagate_value_derivative(
                        state,
                        dp * self.values[succ].upper,
                        self.optimistic_transition_matrix,
                        self._greedy_strategy("upper"),
                    )
                # Lower Bound: analogous, based on the pessimistic instantiation
                lower_derivative = 0
                if self.qualities[state][action].lower >= self.values[state].lower:
                    dp = self._bound_derivative(
                        self.pessimistic_transition_matrix[state][action][succ],
                        p_lower,
                        p_upper,
                        du,
                        dl,
                    )
                    lower_derivative = self._propagate_value_derivative(
                        state,
                        dp * self.values[succ].lower,
                        self.pessimistic_transition_matrix,
                        self._greedy_strategy("lower"),
                    )
                result.append(upper_derivative - lower_derivative)
            return np.array(result)

        # set jacobian explicitly if specified by configuration
        jac = (
            jacobian
            if self.delta_distribution
            == DeltaDistribution.MINIMIZE_VALUE_INTERVAL_WITH_JACOBIAN
            else None
        )
        self._minimize_deltas(transition_list, value_interval_initial_state, jac)

    def calculate_probability_bounds(self):
        for state, actions in self.transitions.items():
            if state in self.sink_states:
                continue
            for action, transitions in actions.items():
                if len(transitions) == 1:
                    continue
                else:
                    for transition in transitions:
                        num_samples = sum(self.samples[state][action].values())
                        num_successes = self.samples[state][action][
                            transition.successor
                        ]
                        delta = self.transition_deltas[state][action][
                            transition.successor
                        ]
                        lower, upper = self._get_probability_bounds(
                            num_samples, num_successes, delta
                        )
                        transition.lower = lower
                        transition.upper = upper

    ##############################
    # SOLVING THE (INTERVAL) MDP #
    ##############################

    def reset_values(self):
        for state, reward in self.rewards.items():
            self.values[state].corresponding_upper = reward
            self.values[state].mean = reward
            self.values[state].corresponding_lower = reward
            self.values[state].upper = reward
            for a in self.actions[state]:
                self.qualities[state][a].lower = reward
                self.qualities[state][a].mean = reward
                self.qualities[state][a].upper = reward

    def build_exact_transition(
        self, ordering, is_optimistic=False, is_pessimistic=False
    ):
        exact_transition_matrix = dict()
        # precompute indices for O(n) speedup
        state_position_dict = {state: i for i, state in enumerate(ordering)}
        for state in self.states:
            if state in self.sink_states:
                continue

            exact_transition_matrix[state] = dict()

            for action in self.actions[state]:
                # build line in transition matrix
                exact_state_transitions = self.get_instantiation(
                    state, action, state_position_dict
                )
                # insert matrix line
                exact_transition_matrix[state][action] = exact_state_transitions
        if is_optimistic:
            self.optimistic_transition_matrix = exact_transition_matrix
        if is_pessimistic:
            self.pessimistic_transition_matrix = exact_transition_matrix
        return exact_transition_matrix

    def sampled_transition_matrix(self):
        exact_transition_matrix = dict()
        for state in self.states:
            if state in self.sink_states:
                continue
            exact_transition_matrix[state] = dict()
            for action in self.actions[state]:
                # build line in transition matrix
                exact_transition_matrix[state][action] = {
                    t.successor: t.mean for t in self.transitions[state][action]
                }
        return exact_transition_matrix

    def get_instantiation(self, state, action, state_ordering):
        """
        calculate the worst-case transition instantiation for a given state-action pair
        @param state: the state for which an instantiation of the transition function is computed
        @param action: the action for which an instantiation of the transition function is computed
        @param state_ordering: a dict {state: int} s.t. a lower value implies there shall be more probability allocated
        to that state
        """
        transitions = self.transitions[state][action]
        # sort successors s.t. from most to least desirable
        sorted_transitions = sorted(
            transitions, key=lambda t: state_ordering[t.successor]
        )
        exact_state_transitions = dict()
        # allocate minimum probabilities
        for transition in transitions:
            exact_state_transitions[transition.successor] = transition.lower
        remaining_total = 1 - sum(exact_state_transitions.values())
        for transition in sorted_transitions:
            remaining_successor = transition.upper - transition.lower
            if remaining_total > remaining_successor:
                exact_state_transitions[transition.successor] = transition.upper
                remaining_total -= remaining_successor
            else:
                exact_state_transitions[transition.successor] = (
                    transition.lower + remaining_total
                )
                break
        return exact_state_transitions

    def evaluate_strategy(self, strategy):
        # build transition matrix
        transition_matrix = dict()

        for s in self.states:
            if s in self.sink_states:
                continue
            transition_matrix[s] = defaultdict(float)
            action_distribution = strategy[s]
            action_weight_sum = sum(action_distribution.values())
            for act, weight in action_distribution.items():
                action_probability = weight / action_weight_sum
                for transition in self.transitions[s][act]:
                    successor_probability = action_probability * transition.true
                    transition_matrix[s][transition.successor] += successor_probability

        # value iteration
        expected_reward = {s: r for s, r in self.rewards.items()}
        max_change = 1
        while max_change > self.VALUE_ITERATION_CUTOFF:
            max_change = 0
            # loop over all states
            for state in self.states:
                # ignore goal states as their reward is correctly set to their reward
                if state in self.sink_states:
                    continue
                updated_value = self.rewards[state] + sum(
                    expected_reward[successor] * probability
                    for successor, probability in transition_matrix[state].items()
                )
                max_change = max(
                    max_change, abs(expected_reward[state] - updated_value)
                )
                expected_reward[state] = updated_value
        return expected_reward

    def get_optimal_strategy(self, max_iterations=float("inf")):
        # value iteration
        expected_reward = {s: r for s, r in self.rewards.items()}
        optimal_strategy = {s: None for s in self.states}
        max_change = 1
        iteration = 0
        while max_change > self.VALUE_ITERATION_CUTOFF and iteration <= max_iterations:
            iteration += 1
            max_change = 0
            # loop over all states
            for state in self.states:
                # ignore goal states as their reward is correctly set to their reward
                if state in self.sink_states:
                    continue
                max_action_value = float("-inf")
                best_action_in_state = None
                for act in self.actions[state]:
                    act_value = sum(
                        expected_reward[transition.successor] * transition.true
                        for transition in self.transitions[state][act]
                    )
                    if act_value > max_action_value:
                        max_action_value = act_value
                        best_action_in_state = act
                max_action_value = self.rewards[state] + max_action_value
                optimal_strategy[state] = best_action_in_state
                max_change = max(
                    max_change, abs(expected_reward[state] - max_action_value)
                )
                expected_reward[state] = max_action_value
        return expected_reward, optimal_strategy


class UncertainTransition:
    """
    A transition to ``successor`` with uncertain probability.

    Attributes:
        lower, upper: probability interval, initially the trivial interval [0, 1]
        mean: point estimate of the probability (read by sampled_transition_matrix)
        true: the true probability (read by evaluate_strategy and get_optimal_strategy)
        successor: the target state
    """

    def __init__(self, successor):
        self.lower, self.mean, self.true, self.upper = 0, 0, 0, 1
        self.successor = successor


class UncertainValue:
    """
    Bounds and estimates of the value of a state, initialised to the trivial range [0, 1].

    Attributes:
        lower, upper: lower / upper bound of the value
        lower_ub: an upper bound for the true value of ``lower`` used in value iteration
        upper_lb: analogous to ``lower_ub`` for ``upper``
        corresponding_upper: the upper bound of the value under the strategy that yields the highest lower bound
        corresponding_lower: analogous to ``corresponding_upper``
        mean: point estimate of the value
    """

    def __init__(self):
        self.lower, self.lower_ub = 0, 1
        self.corresponding_upper = 1
        self.mean, self.upper, self.upper_lb = 0, 1, 0
        # NOTE(review): a lower-bound-like quantity starting at 1 (not 0) looks inconsistent -- confirm intended
        self.corresponding_lower = 1
