import abc


class Evaluation(metaclass=abc.ABCMeta):
    """Value-iteration routines computing value bounds, mean values and Q-values.

    ``model`` is used through the following interface (as seen in this class):
    ``states``, ``sink_states``, ``initial_states``, ``rewards[s]``,
    ``actions[s]``, ``values[s]`` (objects with ``lower``/``upper``/``mean``
    attributes; further bound attributes are written to them),
    ``qualities[s][a]`` (objects with ``lower``/``upper``/``mean``),
    ``VALUE_ITERATION_CUTOFF``, ``sampled_transition_matrix()``,
    ``build_exact_transition(ordering, ...)`` and
    ``get_instantiation(state, action, ordering)``.

    Exact transition matrices ("etm") are nested mappings of the form
    ``etm[state][action][successor] -> probability``.
    """

    def __init__(
        self, model, do_lower=True, do_upper=True, wait_for_corr_convergence=True
    ):
        # the model evaluated by all methods (see class docstring for interface)
        self.model = model
        # whether calculate_bounds_with_corr_bounds runs its lower-bound phase
        self.do_lower = do_lower
        # whether calculate_bounds_with_corr_bounds runs its upper-bound phase
        self.do_upper = do_upper
        # if True, correlated values must converge too before a phase stops
        self.wait_for_corr_convergence = wait_for_corr_convergence

    def _initial_values(self, default):
        """Map sink states to their reward and every other state to ``default``."""
        return {
            s: (self.model.rewards[s] if s in self.model.sink_states else default)
            for s in self.model.states
        }

    def _average_gap(self, high, low):
        """Return the mean of ``high[s] - low[s]`` over ``model.initial_states``."""
        total_diff = sum(high[s] - low[s] for s in self.model.initial_states)
        return total_diff / len(self.model.initial_states)

    def _expected_reward(self, state, transitions, bound):
        """Return ``rewards[state]`` plus the expectation of the successors'
        ``values[successor].<bound>`` under ``transitions``."""
        total = self.model.rewards[state]
        for successor, probability in transitions.items():
            total += probability * getattr(self.model.values[successor], bound)
        return total

    def calculate_mean_values(self, max_iterations=float("inf")):
        """Run value iteration on the sampled transition matrix and store the
        result in ``model.values[s].mean``.

        Iteration starts from the currently stored ``mean`` values and stops
        when the largest change of a sweep is below
        ``model.VALUE_ITERATION_CUTOFF`` or after ``max_iterations`` sweeps
        (at least one sweep is always performed).

        @param max_iterations: maximum number of value-iteration sweeps
        """
        mean_values = {s: self.model.values[s].mean for s in self.model.states}

        # the sampled matrix is built once and reused for every sweep
        etm = self.model.sampled_transition_matrix()
        iterations = 0
        while True:
            iterations += 1
            # VI sweep choosing the maximizing action in every state
            mean_values, _, max_main_change, _ = self.value_iteration_step(
                mean_values, etm
            )
            # ">=" so that at most max_iterations sweeps are performed, matching
            # the other solvers in this class
            if (
                max_main_change < self.model.VALUE_ITERATION_CUTOFF
                or iterations >= max_iterations
            ):
                break

        for s in self.model.states:
            self.model.values[s].mean = mean_values[s]

    def calculate_qualities(self):
        """Compute lower/upper/mean Q-values of every action in every non-sink
        state and store them in ``model.qualities[state][action]``.

        The lower quality uses the instantiation returned by
        ``model.get_instantiation`` for states ranked by ascending lower bound,
        the upper quality uses a ranking by descending upper bound, and the
        mean quality uses ``model.sampled_transition_matrix()``.
        """
        states = list(self.model.states)
        # rank (position) of every state when sorted by ascending lower bound
        ordering_lower = {
            s: rank
            for rank, s in enumerate(
                sorted(states, key=lambda s: self.model.values[s].lower)
            )
        }
        # rank of every state when sorted by descending upper bound
        ordering_upper = {
            s: rank
            for rank, s in enumerate(
                sorted(
                    states, key=lambda s: self.model.values[s].upper, reverse=True
                )
            )
        }
        etm_mean = self.model.sampled_transition_matrix()

        for state, actions in self.model.actions.items():
            if state in self.model.sink_states:
                continue

            for action in actions:
                # transition instantiations induced by the two state rankings
                transitions_lower = self.model.get_instantiation(
                    state, action, ordering_lower
                )
                transitions_upper = self.model.get_instantiation(
                    state, action, ordering_upper
                )
                transitions_mean = etm_mean[state][action]

                quality = self.model.qualities[state][action]
                quality.lower = self._expected_reward(
                    state, transitions_lower, "lower"
                )
                quality.upper = self._expected_reward(
                    state, transitions_upper, "upper"
                )
                quality.mean = self._expected_reward(state, transitions_mean, "mean")

    def value_iteration_step(
        self, values, etm, corr_values=None, corr_etm=None, maximizing_action=True
    ):
        """
        Perform one VI sweep according to the Bellman equations.

        ``values`` (and ``corr_values``) are updated in place, state by state,
        so states later in the sweep already see the new values of earlier
        ones. The same dict objects are returned.

        @param values: the state-values to be updated
        @param etm: exact transition matrix - what transition probabilities shall be used?
        @param corr_values: correlated state-values. these state values are updated according to the same action
            choices as the main values. this means these are not maximized for the entire MDP but only under one policy
        @param corr_etm: exact transition matrix for correlated state-values. these may differ from the main etm, e.g.,
            the main etm may be optimistic while the corr_etm may be pessimistic (or vice versa)
        @param maximizing_action: whether to pick the maximizing action in each state. if set to False, the minimizing
            action is picked, i.e., a worst-case analysis is performed.
        @return: state-values, correlated state-values, absolute value of maximum change in state-values,
            absolute value of maximum change in correlated state-values
        @raise ValueError: if exactly one of corr_values and corr_etm is given
        """
        if (corr_values is None) ^ (corr_etm is None):
            raise ValueError(
                "If correlated state values or the correlated exact transition matrix are given, "
                "the other one has to be present as well"
            )
        max_main_change = 0
        max_corr_change = 0
        for state, actions in etm.items():
            # sink states keep their fixed value and are never updated
            if state in self.model.sink_states:
                continue
            # NOTE: a non-sink state without actions keeps this +/-inf start
            # value, which then propagates into its updated value
            opt_action_value = float("-inf") if maximizing_action else float("inf")
            corr_action_value = 0
            for act, transitions in actions.items():
                act_value = sum(
                    values[successor] * probability
                    for successor, probability in transitions.items()
                )
                new_opt_action = (
                    (act_value > opt_action_value)
                    if maximizing_action
                    else (act_value < opt_action_value)
                )
                if new_opt_action:
                    opt_action_value = act_value
                    # correlated values follow the action chosen for the main values
                    if corr_values is not None:
                        corr_action_value = sum(
                            corr_values[successor] * probability
                            for successor, probability in corr_etm[state][act].items()
                        )
            opt_action_value = self.model.rewards[state] + opt_action_value
            if corr_values is not None:
                corr_action_value = self.model.rewards[state] + corr_action_value
            max_main_change = max(
                max_main_change, abs(values[state] - opt_action_value)
            )
            if corr_values is not None:
                max_corr_change = max(
                    max_corr_change, abs(corr_values[state] - corr_action_value)
                )
            values[state] = opt_action_value
            if corr_values is not None:
                corr_values[state] = corr_action_value
        return values, corr_values, max_main_change, max_corr_change

    def calculate_bounds(
        self, max_iterations=float("inf"), min_property=False, keep_old_bounds=False
    ):
        """Compute lower and upper bounds on the state values and store them in
        ``model.values``.

        Lower-bound phase: VI on matrices built with ``is_pessimistic=True``
        from states ordered by ascending value. It runs from below (non-sink
        states start at 0, or at the stored ``lower`` if ``keep_old_bounds``)
        and from above (non-sink states start at 1). The results are stored in
        ``values[s].lower`` and ``values[s].lower_ub``.

        Upper-bound phase: symmetric, with ``is_optimistic=True`` and a
        descending ordering. Non-sink states start at 1 (or the stored
        ``upper``) for ``values[s].upper`` and at 0 for ``values[s].upper_lb``.

        Each phase stops when the main iterate no longer changes, when the mean
        gap between the two iterates over ``model.initial_states`` is below
        ``model.VALUE_ITERATION_CUTOFF``, or after ``max_iterations`` sweeps.

        @param max_iterations: maximum number of sweeps per phase
        @param min_property: if True, pick the minimizing action in each state
        @param keep_old_bounds: warm-start the main iterates from the stored
            ``lower``/``upper`` values instead of 0/1
        @raise ZeroDivisionError: if ``model.initial_states`` is empty
        """
        is_max_property = not min_property

        # ---- lower bound ----
        if keep_old_bounds:
            min_values = {
                s: (
                    self.model.rewards[s]
                    if s in self.model.sink_states
                    else self.model.values[s].lower
                )
                for s in self.model.states
            }
        else:
            min_values = self._initial_values(0)
        min_values_ub = self._initial_values(1)

        iterations = 0
        while True:
            iterations += 1
            # pessimistic matrix from states ordered by ascending current value
            ordering = sorted(self.model.states, key=lambda s: min_values[s])
            etm = self.model.build_exact_transition(ordering, is_pessimistic=True)
            # value_iteration_step mutates min_values in place and returns the
            # same dict, so a fixpoint must be detected from the reported change
            # (comparing the returned dict with min_values is always true)
            min_values, _, min_change, _ = self.value_iteration_step(
                min_values, etm, maximizing_action=is_max_property
            )
            lower_converged = min_change == 0

            # same iteration, started from above
            ordering_ub = sorted(self.model.states, key=lambda s: min_values_ub[s])
            etm_ub = self.model.build_exact_transition(ordering_ub, is_pessimistic=True)
            min_values_ub, _, _, _ = self.value_iteration_step(
                min_values_ub, etm_ub, maximizing_action=is_max_property
            )

            diff = self._average_gap(min_values_ub, min_values)
            if (
                lower_converged
                or diff < self.model.VALUE_ITERATION_CUTOFF
                or iterations >= max_iterations
            ):
                break

        # ---- upper bound ----
        if keep_old_bounds:
            max_values = {
                s: (
                    self.model.rewards[s]
                    if s in self.model.sink_states
                    else self.model.values[s].upper
                )
                for s in self.model.states
            }
        else:
            max_values = self._initial_values(1)
        max_values_lb = self._initial_values(0)

        iterations = 0
        while True:
            iterations += 1
            # optimistic matrix from states ordered by descending current value
            ordering = sorted(
                self.model.states, key=lambda s: max_values[s], reverse=True
            )
            etm = self.model.build_exact_transition(ordering, is_optimistic=True)
            # see the lower-bound phase: the dict is updated in place
            max_values, _, max_change, _ = self.value_iteration_step(
                max_values, etm, maximizing_action=is_max_property
            )
            upper_converged = max_change == 0

            # same iteration, started from below
            ordering_lb = sorted(
                self.model.states, key=lambda s: max_values_lb[s], reverse=True
            )
            etm_lb = self.model.build_exact_transition(ordering_lb, is_optimistic=True)
            max_values_lb, _, _, _ = self.value_iteration_step(
                max_values_lb, etm_lb, maximizing_action=is_max_property
            )

            diff = self._average_gap(max_values, max_values_lb)
            if (
                upper_converged
                or diff < self.model.VALUE_ITERATION_CUTOFF
                or iterations >= max_iterations
            ):
                break

        for s in self.model.states:
            self.model.values[s].lower = min_values[s]
            self.model.values[s].upper = max_values[s]
            self.model.values[s].lower_ub = min_values_ub[s]
            self.model.values[s].upper_lb = max_values_lb[s]

    def calculate_bounds_with_corr_bounds(
        self, max_iterations=float("inf"), min_property=False
    ):
        """Compute lower/upper value bounds together with correlated bounds.

        Lower-bound phase (if ``do_lower``): VI on matrices built with
        ``is_pessimistic=True`` from an ascending ordering. Under the same
        action choices, a correlated value (non-sink states start at 1) is
        propagated on matrices built from a descending ordering. Results go to
        ``values[s].lower`` and ``values[s].corresponding_upper``.

        Upper-bound phase (if ``do_upper``): symmetric, with non-sink states
        starting at 1 for the main values and at 0 for the correlated ones.
        Results go to ``values[s].upper`` and ``values[s].corresponding_lower``.

        A skipped phase stores its initial values. A phase stops once the
        largest change is below ``model.VALUE_ITERATION_CUTOFF`` or after
        ``max_iterations`` sweeps. The correlated change is included only if
        ``wait_for_corr_convergence``.

        @param max_iterations: maximum number of sweeps per phase
        @param min_property: if True, pick the minimizing action in each state
        """
        is_max_property = not min_property

        min_values = self._initial_values(0)
        corr_max_values = self._initial_values(1)
        if self.do_lower:
            iterations = 0
            while True:
                iterations += 1
                # pessimistic matrix from states ordered by ascending value
                ordering = sorted(self.model.states, key=lambda s: min_values[s])
                etm = self.model.build_exact_transition(ordering, is_pessimistic=True)
                corr_ordering = sorted(
                    self.model.states,
                    key=lambda s: corr_max_values[s],
                    reverse=True,
                )
                corr_etm = self.model.build_exact_transition(corr_ordering)
                min_values, corr_max_values, max_main_change, max_corr_change = (
                    self.value_iteration_step(
                        min_values,
                        etm,
                        corr_max_values,
                        corr_etm,
                        maximizing_action=is_max_property,
                    )
                )
                max_change = (
                    max(max_main_change, max_corr_change)
                    if self.wait_for_corr_convergence
                    else max_main_change
                )
                if (
                    max_change < self.model.VALUE_ITERATION_CUTOFF
                    or iterations >= max_iterations
                ):
                    break

        max_values = self._initial_values(1)
        corr_min_values = self._initial_values(0)
        if self.do_upper:
            iterations = 0
            while True:
                iterations += 1
                # states ordered by descending current upper value
                ordering = sorted(
                    self.model.states, key=lambda s: max_values[s], reverse=True
                )
                # NOTE(review): calculate_bounds builds its upper-bound matrices
                # with is_optimistic=True, whereas this call passes
                # is_pessimistic=True -- confirm which flag is intended here.
                etm = self.model.build_exact_transition(ordering, is_pessimistic=True)
                corr_ordering = sorted(
                    self.model.states, key=lambda s: corr_min_values[s]
                )
                corr_etm = self.model.build_exact_transition(corr_ordering)
                max_values, corr_min_values, max_main_change, max_corr_change = (
                    self.value_iteration_step(
                        max_values,
                        etm,
                        corr_min_values,
                        corr_etm,
                        maximizing_action=is_max_property,
                    )
                )
                max_change = (
                    max(max_main_change, max_corr_change)
                    if self.wait_for_corr_convergence
                    else max_main_change
                )
                if (
                    max_change < self.model.VALUE_ITERATION_CUTOFF
                    or iterations >= max_iterations
                ):
                    break

        for s in self.model.states:
            self.model.values[s].lower = min_values[s]
            self.model.values[s].corresponding_upper = corr_max_values[s]
            self.model.values[s].upper = max_values[s]
            self.model.values[s].corresponding_lower = corr_min_values[s]
