import dataclasses as d
from .UTMDP import *
from collections import defaultdict


def mdp_from_tra(path_to_tra, path_to_lab):
    """Build a ``GreyUTMDP`` from a transition file and a label file.

    Transition file (``path_to_tra``): the first line is a header and is
    skipped. Every following non-blank line is split on single spaces into
    ``state action successor probability [extra fields...]``; extra fields
    are ignored and a decimal comma in the probability is accepted.
    (NOTE(review): this looks like PRISM's explicit ``.tra``/``.lab``
    format -- confirm.)

    Label file (``path_to_lab``): the first line declares labels as
    space-separated ``number=name`` pairs (quotes around the name are
    stripped). Every following non-blank line has the form
    ``state: number number ...``.

    Only states that have at least one outgoing transition in the
    transition file are added to the model. A state gets reward 1 if it is
    labelled ``goal`` and 0 otherwise, is an initial state if labelled
    ``init``, and is recorded as a sink state if it is a goal state, is
    labelled ``sink``, or has no successor other than itself.

    Args:
        path_to_tra: Path of the transition file.
        path_to_lab: Path of the label file.

    Returns:
        The populated ``GreyUTMDP`` instance.
    """
    model = GreyUTMDP()
    transitions = defaultdict(lambda: defaultdict(lambda: defaultdict(float)))
    labels = defaultdict(list)
    with open(path_to_tra, "r") as f:
        # The first line is a header and carries no transition.
        f.readline()
        for line in f:
            line = line.strip()
            if not line:
                # Tolerate blank (e.g. trailing) lines instead of failing
                # on tuple unpacking.
                continue
            state, action, successor, probability, *_ = line.split(" ")
            # Accept a decimal comma as well as a decimal point.
            probability = float(probability.replace(",", "."))
            transitions[state][action][successor] = probability
    with open(path_to_lab, "r") as f:
        # Header: maps each label number to its (unquoted) name.
        label_names = f.readline().strip().split(" ")
        label_name_dict = {}
        for label in label_names:
            label_num, label_name = label.split("=")
            label_name_dict[label_num] = label_name.replace('"', "").replace("'", "")
        for line in f:
            line = line.strip()
            if not line:
                continue
            state, state_labels = line.split(": ")
            state_labels = [label_name_dict[sl] for sl in state_labels.split(" ")]
            labels[state].extend(state_labels)
    for state, actions in transitions.items():
        reward = 1 if "goal" in labels[state] else 0
        model.add_state(state, list(actions.keys()), reward, False)
        if "init" in labels[state]:
            model.initial_states.append(state)
        only_self_successors = True
        # NOTE: the loop variable is deliberately not called ``transitions``
        # so the mapping being iterated by the outer loop is not rebound.
        for action, successors in actions.items():
            model.transitions[state][action] = set()
            for successor, probability in successors.items():
                if successor != state:
                    only_self_successors = False
                t = UncertainTransition(successor)
                t.true = probability
                model.transitions[state][action].add(t)
        if only_self_successors or reward == 1 or "sink" in labels[state]:
            model.sink_states.add(state)
    return model


def empirically_evaluate_strategy(environment, strategy, n):
    """Return the average reward collected by ``strategy`` over ``n`` runs.

    In each run the strategy picks an action for the current state until
    ``environment.perform_action`` returns a truthy value (treated as
    "goal reached"). The reward of every state an action is chosen in, plus
    the reward of the state the run ends in, is accumulated. The
    environment is restarted after every run.

    Args:
        environment: Object providing ``get_current_state``, ``get_reward``,
            ``get_current_actions``, ``perform_action``, ``get_successors``
            and ``restart``.
        strategy: Object providing ``choice(state, actions)``.
        n: Number of runs to simulate.

    Returns:
        The accumulated reward divided by ``n``.
    """
    total_reward = 0
    for _ in range(n):
        while True:
            current = environment.get_current_state()
            total_reward += environment.get_reward(current)
            available = environment.get_current_actions()
            chosen = strategy.choice(current, available)
            finished = environment.perform_action(chosen)
            # Sanity check: the environment must have moved to one of the
            # successors it reports for the chosen action.
            landed = environment.get_current_state()
            assert landed in environment.get_successors(current, chosen)
            if finished:
                break
        # The state the run ended in contributes its reward as well.
        total_reward += environment.get_reward(environment.get_current_state())
        environment.restart()
    return total_reward / n


@d.dataclass(frozen=True)
class SVG:
    """Immutable container for an SVG document held as text."""

    # The SVG markup as a single string.
    source: str

def visualize_racetrack(racetrack, samples):
    """Render the racetrack as an SVG heat map of sampled visits.

    Every tile of the bounding box of ``racetrack.track`` and
    ``racetrack.goal_zone`` is drawn as a 10x10 rectangle: start-zone
    tiles yellow, goal-zone tiles green, other track tiles red with an
    opacity that grows with how often the tile was sampled as a successor
    position, and everything else grey.

    Args:
        racetrack: Object providing ``track`` and ``goal_zone`` (sequences
            of ``(x, y)`` tiles that can be concatenated with ``+``),
            ``start_zone`` and ``OFF_TRACK``.
        samples: Nested mapping ``state -> action -> successor -> count``
            where ``state`` and ``successor`` are pairs whose first element
            is a position.

    Returns:
        An ``SVG`` instance wrapping the generated markup.
    """
    tiles = racetrack.track + racetrack.goal_zone
    max_x = max(x for x, y in tiles)
    min_x = min(x for x, y in tiles)
    max_y = max(y for x, y in tiles)
    min_y = min(y for x, y in tiles)
    height_px = 10 * (max_y - min_y + 1)
    width_px = 10 * (max_x - min_x + 1)
    # The rectangles below are positioned at 10 * tile coordinate, so the
    # viewBox origin has to be expressed in the same (scaled) units;
    # otherwise the picture is shifted/clipped whenever min_x or min_y != 0.
    track_svg = [
        '<?xml version="1.0" encoding="UTF-8"?>',
        f"""<svg
                 xmlns="http://www.w3.org/2000/svg"
                 version="1.1" baseProfile="full"
                 width="{width_px}" height="{height_px}"
                 viewBox="{min_x * 10} {min_y * 10} {width_px} {height_px}">""",
    ]

    # Number of sampled transitions that ended on each position.
    track_count = defaultdict(int)

    for state, actions in samples.items():
        for action, successor_count in actions.items():
            for successor, count in successor_count.items():
                successor_pos, _ = successor
                # Off-track and start-zone arrivals are not part of the
                # heat map.
                if (
                    successor_pos == racetrack.OFF_TRACK
                    or successor_pos in racetrack.start_zone
                ):
                    continue
                track_count[successor_pos] += count

    # ``default=0`` keeps this working when nothing was counted at all.
    max_track_count = max(track_count.values(), default=0)

    for x in range(min_x, max_x + 1):
        for y in range(min_y, max_y + 1):
            color = "grey"
            opacity = 1
            if (x, y) in racetrack.track:
                if (x, y) in racetrack.start_zone:
                    color = "yellow"
                    opacity = 0.7
                else:
                    color = "red"
                    # The 3/5 exponent compresses the range so rarely
                    # visited tiles remain visible. Without any counted
                    # visit the tile is fully transparent (avoids 0 / 0).
                    opacity = (
                        (track_count.get((x, y), 0) / max_track_count) ** (3 / 5)
                        if max_track_count
                        else 0
                    )
            elif (x, y) in racetrack.goal_zone:
                color = "green"
                opacity = 0.7
            track_svg.append(
                f"""<rect
                    x="{x * 10}"
                    y="{y * 10}"
                    width="10"
                    height="10"
                    fill="{color}"
                    stroke="black"
                    stroke-width="1pt"
                    fill-opacity="{opacity}"
                    />
                """
            )
    track_svg.append("</svg>")
    return SVG("\n".join(track_svg))


def get_choice_entropy(environment, samples):
    """Compute the normalized entropy of the sampled action choices per state.

    For every state in ``samples`` the number of samples per action is
    turned into a distribution over actions, and its entropy is taken with
    the logarithm base equal to the number of actions the environment
    reports for that state, so a uniform choice over all actions yields 1.

    Args:
        environment: Object providing ``get_actions(state)``.
        samples: Nested mapping ``state -> action -> successor -> count``.

    Returns:
        Dict mapping each state in ``samples`` to its choice entropy. The
        entropy is 0 for states with a single available action and for
        states without any samples.
    """
    # ``math`` is not imported explicitly at the top of this module; import
    # it here instead of relying on it leaking in through a star import.
    import math

    entropy = dict()
    for state, actions in samples.items():
        # Number of samples collected for each sampled action.
        num_samples = [
            sum(successor_count.values()) for successor_count in actions.values()
        ]
        total_num_samples = sum(num_samples)
        num_choices = len(environment.get_actions(state))
        if num_choices == 1 or total_num_samples == 0:
            # A single action leaves no choice; without samples there is no
            # distribution to measure (and normalizing would divide by 0).
            entropy[state] = 0
            continue
        # Actions without samples contribute nothing (0 * log 0 := 0) and
        # must be skipped because math.log(0) raises ValueError.
        choice_distribution = [n / total_num_samples for n in num_samples if n]
        entropy[state] = -sum(
            p * math.log(p, num_choices) for p in choice_distribution
        )
    return entropy
