"""
Wrapper classes that allow exploration of MDPs (Markov decision processes).
"""

from models.Environment import Environment, GreyBoxEnvironment
import random
import json

# stormpy is an optional dependency: if it is missing we only print a notice so
# that this module can still be imported. The name `stormpy` is left undefined
# in that case, so code paths that use it will fail when they are reached.
try:
    import stormpy
except ModuleNotFoundError:
    print(
        "It seems you have not installed stormpy. Parsing MDP files will not be possible; "
        "only MDPs implemented natively in Python will be available."
    )


class Simulator(Environment):
    """Environment that simulates an MDP built by stormpy from a PRISM program.

    States are the stormpy state objects collected once in ``__init__``;
    actions are referred to by the ``id`` of the entries in ``state.actions``.
    """

    def __init__(self, mdp_file):
        """Build the sparse model and start in a randomly chosen initial state.

        Args:
            mdp_file: argument passed to ``stormpy.parse_prism_program``.
        """
        # build MDP
        program = stormpy.parse_prism_program(mdp_file)
        options = stormpy.BuilderOptions()
        options.set_add_out_of_bounds_state()
        options.set_build_state_valuations()
        options.set_build_choice_labels()
        options.set_build_all_reward_models()
        self.mdp = stormpy.build_sparse_model_with_options(program, options)

        self.initial_states = []
        self.state_id_map = dict()

        # Collect the initial ids once so the membership test below does not
        # re-read (and linearly scan) ``mdp.initial_states`` for every state.
        initial_ids = set(self.mdp.initial_states)

        # self.mdp.states produces *different* objects every time it is called
        # here, we collect a dict that stores all state-id pairs so we always refer to the same objects
        for s in self.mdp.states:
            if s.id in initial_ids:
                self.initial_states.append(s)
            self.state_id_map[s.id] = s

        self.state = random.choice(self.initial_states)

    def get_initial_states(self):
        """Return the list of initial state objects collected at construction."""
        return self.initial_states

    def is_final_state(self, state):
        """Return True if *state* carries the label ``"goal"``."""
        return "goal" in self.mdp.labeling.get_labels_of_state(state)

    def get_reward(self, state):
        """Return the sum of the state rewards of *state* over all reward models."""
        return sum(
            rewards.get_state_reward(state)
            for rewards in self.mdp.reward_models.values()
        )

    def state_id_to_name(self, s_id):
        """Return the JSON-decoded state valuation of the state with id *s_id*."""
        valuations = self.mdp.state_valuations
        return json.loads(valuations.get_json(s_id))

    def restart(self):
        """Reset the current state to a randomly chosen initial state."""
        s = random.choice(self.initial_states)
        self._set_state_by_id(s.id)

    def get_current_state(self):
        """Return the state object the simulator is currently in."""
        return self.state

    def _set_state_by_id(self, state_id):
        """Make the state with id *state_id* the current state."""
        self.state = self._get_state_by_id(state_id)

    def _get_state_by_id(self, state_id):
        """Return the stored state object for *state_id*.

        Raises:
            KeyError: if no state with that id was collected.
        """
        return self.state_id_map[state_id]

    def print_state(self):
        """Print the current state and the names of its available actions."""
        print(f"You are in state {self.state}")
        print("Your possible actions are:")
        for i, act_id in enumerate(self.get_current_actions()):
            print(f"Option {i}: {self.act_id_to_name(self.state.id, act_id)}")

    def get_actions(self, state):
        """Return the ids of the actions available in *state*."""
        # since the SparseModelAction object is created on the fly we need to refer to actions by their id
        return [a.id for a in state.actions]

    def get_current_actions(self):
        """Return the ids of the actions available in the current state."""
        return self.get_actions(self.state)

    def _get_action_by_id(self, state_id, act_id):
        """Return the action object with id *act_id* of the state *state_id*.

        Raises:
            KeyError: if the state id is unknown or the state has no action
                with that id (previously the latter silently returned None).
        """
        s = self._get_state_by_id(state_id)
        for a in s.actions:
            if a.id == act_id:
                return a
        raise KeyError(f"state {state_id} has no action with id {act_id}")

    def act_id_to_name(self, s_id, act_id):
        """Return the label of action *act_id* in state *s_id*, or "" if unlabeled.

        Raises:
            ValueError: if the choice carries more than one label (the
                single-element unpacking below fails).
        """
        choice_index = self.mdp.get_choice_index(s_id, act_id)
        labels = self.mdp.choice_labeling.get_labels_of_choice(choice_index)
        if not labels:
            return ""
        (act_label,) = labels
        return act_label

    def perform_action(self, action_id):
        """Sample a successor for *action_id*, move there, and report the outcome.

        Returns:
            True if the new current state is labeled ``"goal"``.
        """
        # we could use random.choices() here but it is slower by a factor of ~2
        r = random.random()
        t_sum = 0
        t = None
        for tgt, prob in self.get_transitions(self.state, action_id):
            # Remember the target before testing the threshold: if the
            # cumulative sum never exceeds r, the last target is used.
            t = tgt
            t_sum += prob
            if t_sum > r:
                break
        s = self._get_state_by_id(t)
        self.state = s
        return self.is_final_state(s)

    def get_transitions(self, state, action_id):
        """Yield ``(target_state_id, value)`` pairs for *action_id* in *state*."""
        action = self._get_action_by_id(state.id, action_id)
        for t in action.transitions:
            yield t.column, t.value()


class GreyBoxSimulator(Simulator, GreyBoxEnvironment):
    """
    Simulator for the grey-box setting: the agent learns while already knowing
    all possible states and transitions (without probabilities) beforehand.
    """

    def get_successors(self, state, action_id):
        """Return the target state of every transition of *action_id* in *state*."""
        chosen = self._get_action_by_id(state.id, action_id)
        successors = []
        for transition in chosen.transitions:
            successors.append(self._get_state_by_id(transition.column))
        return successors

    def get_successors_with_probabilities(self, state, action_id):
        """Return a dict mapping each target state of *action_id* to its transition value."""
        chosen = self._get_action_by_id(state.id, action_id)
        distribution = {}
        for transition in chosen.transitions:
            target = self._get_state_by_id(transition.column)
            distribution[target] = transition.value()
        return distribution
