import copy
import random

from .Evaluation import *


class Strategy:
    """Action-selection strategy over a model.

    Holds a stochastic policy (``state -> {action: weight}``), tracks which
    actions have not yet been tried in each state, feeds new sample counts
    back into ``self.model`` and triggers re-evaluation via ``Evaluation``.
    """

    def __init__(self, model):
        """Create a strategy for ``model``.

        ``model`` must expose ``actions`` (a mapping ``state -> actions``);
        it is also handed to ``Evaluation``.
        """
        self.model = model
        # for every state, save possible actions and a weight for each action
        self.policy = dict()
        # Actions not yet tried per state; ``choice`` consumes these first.
        # ``list(...)`` (instead of a shallow copy) guarantees a fresh,
        # mutable container even if the model stores actions as tuples,
        # because ``choice`` calls ``remove`` on it.
        self.unexplored = {
            state: list(actions) for state, actions in self.model.actions.items()
        }
        # keep track of the episode number
        self.passed_episodes = 0
        self.evaluation_method = Evaluation(model)

    def set_policy_random(self):
        """Give every state a random probability distribution over its actions.

        Weights are the gaps between sorted uniform cut points in [0, 1], so
        they are non-negative and sum to 1. A state with a single action gets
        weight 1 for it; a state without actions gets an empty distribution.
        """
        for state, actions in self.model.actions.items():
            self.policy[state] = dict()
            num_actions = len(actions)
            if num_actions == 1:
                self.policy[state][actions[0]] = 1
                continue
            # num_actions - 1 random cut points plus both interval ends
            separators = sorted(
                [random.random() for _ in range(num_actions - 1)] + [0, 1]
            )
            for action, low, high in zip(actions, separators, separators[1:]):
                self.policy[state][action] = high - low

    def set_policy_uniform(self):
        """Give every state the uniform distribution over its actions."""
        for state, actions in self.model.actions.items():
            # the guard avoids a ZeroDivisionError; an empty action list
            # yields an empty distribution either way
            weight = 1 / len(actions) if actions else 0.0
            self.policy[state] = {action: weight for action in actions}

    def choice(self, state, actions):
        """Pick the action to play in ``state``.

        Untried actions of ``state`` are returned first (uniformly at random,
        each at most once). After that the action is drawn according to the
        policy weights for ``state``. If the policy has no (or an empty) entry
        for ``state``, one of ``actions`` is chosen uniformly.

        :param state: current state.
        :param actions: fallback candidates when the policy has no entry.
        :return: the chosen action.
        """
        # ``.get`` so states unseen at construction time don't raise KeyError
        untried = self.unexplored.get(state)
        if untried:
            picked = random.choice(untried)
            untried.remove(picked)
            return picked
        distribution = self.policy.get(state)
        if not distribution:
            return random.choice(actions)
        return random.choices(
            population=list(distribution),
            weights=list(distribution.values()),
            k=1,
        )[0]

    def get_deterministic_policy(self):
        """Return a deterministic version of the current policy.

        Each state maps to ``{best_action: 1}``, where ``best_action`` has the
        largest weight (ties go to the first such action in iteration order).
        States with an empty distribution map to ``{}`` instead of raising.
        """
        return {
            state: {max(distribution, key=distribution.get): 1} if distribution else {}
            for state, distribution in self.policy.items()
        }

    def update_memory(self, new_experience):
        """Add observed transition counts to ``self.model.samples``.

        :param new_experience: nested mapping
            ``state -> action -> successor -> count``.

        Actions found in ``self.model.action_mapping`` are redirected, possibly
        over several hops, to their representative before counting. The count
        is always recorded. Afterwards a warning is printed if the state, the
        action, or the successor is not part of the model.
        """
        for state, actions in new_experience.items():
            for action, results in actions.items():
                # if action was removed because it behaves the same as another action, count sample
                # as if the equivalent representative was sampled
                # (resolved once per action instead of once per successor)
                while (state, action) in self.model.action_mapping:
                    action = self.model.action_mapping[(state, action)]
                # Successors reachable via (state, action). Built lazily, and only
                # once state and action are known to exist in the model.
                reachable = None
                for successor, num in results.items():
                    self.model.samples[state][action][successor] += num
                    if state not in self.model.states:
                        print(
                            f"Warning: Added sample from state {state} that does not exist in model"
                        )
                    elif action not in self.model.actions[state]:
                        print(
                            f"Warning: Added sample for action {action} that does not exist in state {state}"
                        )
                    else:
                        if reachable is None:
                            reachable = {
                                t.successor
                                for t in self.model.transitions[state][action]
                            }
                        if successor not in reachable:
                            print(
                                f"Warning: Added sample to successor {successor} that is not reachable from state {state} via action {action}"
                            )

    def update_values(
        self,
        only_probabilities=False,
        keep_old_bounds=False,
        with_corr_bounds=True,
        max_iterations=float("inf"),
    ):
        """Refresh the model's delta/probability bounds and optionally its values.

        :param only_probabilities: if true, stop after recomputing delta and
            probability bounds.
        :param keep_old_bounds: forwarded to ``calculate_bounds``; it has no
            effect when ``with_corr_bounds`` is true.
        :param with_corr_bounds: use ``calculate_bounds_with_corr_bounds``
            instead of ``calculate_bounds``.
        :param max_iterations: iteration cap forwarded to the evaluation.
        """
        # update MDP delta distribution, probabilities and possibly values
        self.model.calculate_delta()
        self.model.calculate_probability_bounds()
        if only_probabilities:
            return
        if with_corr_bounds:
            self.evaluation_method.calculate_bounds_with_corr_bounds(
                max_iterations=max_iterations, min_property=self.model.min_property
            )
        else:
            self.evaluation_method.calculate_bounds(
                max_iterations=max_iterations,
                min_property=self.model.min_property,
                keep_old_bounds=keep_old_bounds,
            )
