import subprocess
import time
import pathlib

from strategies import Strategy
from strategies.utils import UTMDP, utils
import random
from collections import defaultdict
import copy
import logging
import argparse
import numpy as np
import matplotlib.pyplot as plt


def prepare_model(
    environment,
    path_to_mdp_file,
    delta,
    collapse_mecs,
    use_objective_for_mec=True,
    property_name=None,
    is_min_property=False,
):
    """Build a model and normalise its transition structure.

    If ``environment == "MDP"``, the model file is first pre-processed by the
    external ``probabilistic-models`` tool. The resulting ``.tra``/``.lab``
    files are then read back with ``utils.mdp_from_tra``. Otherwise
    *environment* is evaluated as a Python expression and handed to
    ``UTMDP.GreyUTMDP.build_model``.

    Args:
        environment: ``"MDP"`` to load from file; otherwise a Python expression
            (evaluated with ``eval``) naming the environment to build.
        path_to_mdp_file: Path of the model file (used only for ``"MDP"``).
            Properties are read from ``<dir>/<name up to first dot>.props``.
            Tool output goes to ``<dir>/transformed/<name without last suffix>``
            plus ``""``/``"_obj"``/``"_mec"``.
        delta: Confidence parameter passed to ``GreyUTMDP`` (non-MDP branch).
        collapse_mecs: Whether the tool should collapse end components.
        use_objective_for_mec: With *collapse_mecs*, pass ``--collapse ALL``
            (True) or ``--collapse MEC`` (False).
        property_name: Optional ``--property`` argument for the tool.
        is_min_property: Stored on the model as ``min_property``.

    Returns:
        The model after ``merge_same_target_transitions``,
        ``remove_duplicate_transitions`` and ``handle_self_loops``.

    Raises:
        subprocess.CalledProcessError: If the pre-processing tool exits non-zero.
    """
    if environment == "MDP":
        mdp_path = pathlib.Path(path_to_mdp_file)
        # Derive paths from the file *name* only, so dots or a missing "/" in the
        # directory part cannot corrupt them.
        output_path = str(mdp_path.parent / "transformed" / mdp_path.stem)
        property_path = str(
            mdp_path.parent / (mdp_path.name.split(".", 1)[0] + ".props")
        )
        file_name_extension = (
            "" if not collapse_mecs else ("_obj" if use_objective_for_mec else "_mec")
        )
        collapse_param = (
            "NONE" if not collapse_mecs else ("ALL" if use_objective_for_mec else "MEC")
        )
        output_base = output_path + file_name_extension

        print("Pre-processing MDP...\n")
        pathlib.Path(output_path).parent.mkdir(parents=True, exist_ok=True)
        command = [
            "sh",
            "models/probabilistic-models-1.0/bin/probabilistic-models",
            "process",
            "--model",
            path_to_mdp_file,
            "--properties",
            property_path,
        ]
        if property_name is not None:
            command += ["--property", property_name]
        command += ["--collapse", collapse_param, "--output", output_base]
        # The tool's stdout is captured (not echoed); check_output raises on failure.
        subprocess.check_output(command)

        # NOTE(review): `delta` is not passed to mdp_from_tra here — confirm the
        # loaded model picks up the intended confidence delta.
        model = utils.mdp_from_tra(output_base + ".tra", output_base + ".lab")
        model.min_property = is_min_property
        print("MDP was successfully built!\n")
    else:
        model = UTMDP.GreyUTMDP(delta=delta, min_property=is_min_property)
        # SECURITY NOTE(review): eval() executes arbitrary code; `environment`
        # must only ever come from a trusted source (e.g. a fixed set of names).
        environment = eval(environment)
        model.build_model(environment)
    model.merge_same_target_transitions()
    model.remove_duplicate_transitions()
    model.handle_self_loops()
    return model


def _average_over_initial_states(model, attribute):
    """Return the mean of ``model.values[s].<attribute>`` over ``model.initial_states``."""
    initial_states = model.initial_states
    return sum(getattr(model.values[init], attribute) for init in initial_states) / len(
        initial_states
    )


def _write_transition_deltas(model, output_file):
    """Write a ``"<samples> <phat> <delta>"`` line for each probabilistic transition.

    State-action pairs with a single successor are skipped. ``samples`` is the
    total number of samples for the state-action pair. ``phat`` is the empirical
    frequency of the successor, or 0 when there are no samples.
    """
    for state, actions in model.transitions.items():
        for action, transitions in actions.items():
            if len(transitions) == 1:
                continue
            action_samples = model.samples[state][action]
            samples = sum(action_samples.values())
            for transition in transitions:
                phat = action_samples[transition.successor] / samples if samples else 0
                transition_delta = model.transition_deltas[state][action][
                    transition.successor
                ]
                output_file.write(f"{samples} {phat} {transition_delta}\n")


def do_learning(
    environment,
    property_name,
    is_min_property,
    path_to_mdp_file,
    path_to_logfile,
    delta,
    num_episodes,
):
    """Sample the model over several episodes and log value bounds after each one.

    Each episode works as follows.

    * Sampling: ``10 * (#probabilistic transitions)`` paths are drawn under the
      policy set by ``strategy.set_policy_random()``. Each path starts in a
      random initial state and stops at a sink state or after
      ``5 * #transitions`` steps. The paths are fed to
      ``strategy.update_memory``.
    * Evaluation: for every (delta distribution, confidence method)
      combination, values are recomputed. A line is then logged to
      ``<path_to_logfile[:-4]>_<method><path_to_logfile[-4:]>``. It contains the
      bounds averaged over initial states, the summed transition-interval
      widths, and the recomputation time.

    After the final episode, per-transition deltas are additionally written to
    ``<log dir>/deltas/<log name[:-4]>_<method>.dat``.
    """
    delta_methods_to_use = [
        UTMDP.DeltaDistribution.UNIFORM,
        # UTMDP.DeltaDistribution.UNIFORM_PRODUCT,
        # UTMDP.DeltaDistribution.MINIMIZE_INTERVAL_SUM_WITH_JACOBIAN,
        # UTMDP.DeltaDistribution.MINIMIZE_INTERVAL_SUM_HEURISTIC,
        # UTMDP.DeltaDistribution.MINIMIZE_VALUE_INTERVAL,
        # UTMDP.DeltaDistribution.MINIMIZE_VALUE_INTERVAL_WITH_JACOBIAN
    ]
    confidence_methods_to_use = [
        UTMDP.ConfidenceMethod.HOEFFDING,
        # UTMDP.ConfidenceMethod.CLOPPER_PEARSON
    ]
    method_combinations = [
        (delta_distribution_method, probability_calculation_method)
        for delta_distribution_method in delta_methods_to_use
        for probability_calculation_method in confidence_methods_to_use
    ]

    #########
    # Model #
    #########
    model = prepare_model(
        environment,
        path_to_mdp_file,
        delta,
        collapse_mecs=True,
        use_objective_for_mec=False,
        property_name=property_name,
        is_min_property=is_min_property,
    )
    strategy = Strategy.Strategy(model=model)
    print_model_parameters(model)

    model.small_support_improvement = False

    ##################
    # Preparing logs #
    ##################
    # One logger (and log file) per method combination; the handlers are closed
    # again at the end so repeated calls do not leak file handles.
    handlers = []
    for delta_distribution_method, probability_calculation_method in method_combinations:
        method_name = (
            delta_distribution_method.name + "_" + probability_calculation_method.name
        )
        logger = logging.getLogger(method_name)
        handler = logging.FileHandler(
            path_to_logfile[:-4] + "_" + method_name + path_to_logfile[-4:],
            mode="w",
        )
        logger.setLevel(logging.INFO)
        logger.addHandler(handler)
        handlers.append((logger, handler))
        logger.info("episode \tlower \tc_upper \tc_lower \tupper \tepsilon \ttime\n")

    num_transitions = sum(
        len(transitions)
        for actions in model.transitions.values()
        for transitions in actions.values()
    )
    num_prob_transitions = sum(
        len(transitions)
        for actions in model.transitions.values()
        for transitions in actions.values()
        if len(transitions) > 1
    )

    max_num_steps = 5 * num_transitions
    strategy.set_policy_random()

    try:
        for i in range(num_episodes):
            ##################
            # Gather samples #
            ##################
            print("episode", i + 1, "of", num_episodes)

            new_experience = defaultdict(lambda: defaultdict(lambda: defaultdict(int)))
            for _ in range(10 * num_prob_transitions):
                # TODO discuss how to acquire data s.t. size is sensible with MDP in question
                state = random.choice(model.initial_states)
                steps = 0
                goal_state_reached = False
                while not goal_state_reached and steps < max_num_steps:
                    steps += 1
                    actions = model.actions[state]
                    choice = strategy.choice(state, actions)
                    new_state = model.get_sample(state, choice)
                    goal_state_reached = new_state in model.sink_states
                    new_experience[state][choice][new_state] += 1
                    state = new_state
            print("done sampling")
            strategy.update_memory(new_experience)

            ##############
            # Evaluating #
            ##############
            for (
                delta_distribution_method,
                probability_calculation_method,
            ) in method_combinations:
                method_name = (
                    delta_distribution_method.name
                    + "_"
                    + probability_calculation_method.name
                )
                print(
                    f"delta distribution: {delta_distribution_method.name}, "
                    f"probability interval calculation: {probability_calculation_method.name}"
                )
                start_time = time.time()
                model.delta_distribution = delta_distribution_method
                model.confidence_method = probability_calculation_method
                strategy.update_values(only_probabilities=False, keep_old_bounds=True)
                total_time = time.time() - start_time

                avg_lower = _average_over_initial_states(strategy.model, "lower")
                avg_corr_upper = _average_over_initial_states(
                    strategy.model, "corresponding_upper"
                )
                avg_corr_lower = _average_over_initial_states(
                    strategy.model, "corresponding_lower"
                )
                avg_upper = _average_over_initial_states(strategy.model, "upper")
                sum_epsilon = sum(
                    transition.upper - transition.lower
                    for actions in model.transitions.values()
                    for transitions in actions.values()
                    for transition in transitions
                )

                # log deltas after final episode
                if i == num_episodes - 1:
                    log_file = pathlib.Path(path_to_logfile)
                    deltas_dir = log_file.parent / "deltas"
                    deltas_dir.mkdir(parents=True, exist_ok=True)
                    deltas_file = deltas_dir / (
                        log_file.name[:-4] + "_" + method_name + ".dat"
                    )
                    with open(deltas_file, "w") as f:
                        _write_transition_deltas(model, f)

                ###########
                # Logging #
                ###########
                logging.getLogger(method_name).info(
                    f"{i + 1} \t{avg_lower:.5f} \t{avg_corr_upper:.5f} \t{avg_corr_lower:.5f}\t{avg_upper:.5f} "
                    f"\t{sum_epsilon:.3f} \t{total_time:.4f}"
                )

                print(f"{avg_upper-avg_lower} -- interval [{avg_lower}, {avg_upper}]")
    finally:
        for logger, handler in handlers:
            logger.removeHandler(handler)
            handler.close()

def _count_transitions(model, probabilistic_only=False):
    """Count the transitions of *model*.

    With *probabilistic_only*, only transitions of state-action pairs that have
    more than one successor are counted.
    """
    return sum(
        len(transitions)
        for actions in model.transitions.values()
        for transitions in actions.values()
        if not probabilistic_only or len(transitions) > 1
    )


def _resolve_action(model, state, action):
    """Follow ``model.action_mapping`` to the representative of a merged action."""
    while (state, action) in model.action_mapping:
        action = model.action_mapping[(state, action)]
    return action


def _record_sample(model, state, action, successor, num):
    """Add *num* observations of ``state --action--> successor`` to ``model.samples``.

    The action is first redirected to its representative. If an equivalent action
    was removed, the sample counts as if the representative had been sampled.
    The sample is always recorded. A warning is printed if it does not match a
    state, action, or transition of *model*.
    """
    action = _resolve_action(model, state, action)
    model.samples[state][action][successor] += num
    if state not in model.states:
        print(f"Warning: Added sample from state {state} that does not exist in model")
    elif successor not in model.states:
        print(
            f"Warning: Added sample to state {successor} that does not exist in model"
        )
    elif action not in model.actions[state]:
        print(
            f"Warning: Added sample for action {action} that does not exist in state {state}"
        )
    elif not any(t.successor == successor for t in model.transitions[state][action]):
        print(
            f"Warning: Added sample to successor {successor} that is not reachable from state {state} via action {action}"
        )


def _add_samples_chain_contracted(model, experience):
    """Translate baseline *experience* onto a model transformed by ``contract_chains``.

    Contracted states are those with exactly one incoming transition and one
    available action.
    (a) Transitions *into* a contracted state are ignored, because we do not
        know where they were merged to.
    (b) Transitions *out of* a contracted state are attributed to the
        state-action pair it was merged into (``model.chain_mapping``).
    """
    for state, actions in experience.items():
        for action, results in actions.items():
            for successor, num in results.items():
                if successor not in model.states:
                    continue
                # Local copies: remapping must not leak into later iterations.
                source, source_action = state, action
                if source not in model.states:
                    source, source_action = model.chain_mapping[source]
                if source == successor:
                    continue
                _record_sample(model, source, source_action, successor, num)


def _add_samples_nwr_quotient(model, experience):
    """Translate baseline *experience* onto a model transformed by
    ``collapse_prob01`` + ``merge_essential_states``.

    For a removed state, the samples either do not help any method (prob0/prob1
    states) or are guaranteed to lead to its representative (essential states).
    (a) Transitions *from* removed states are ignored.
    (b) Transitions *to* removed states are redirected through
        ``model.essential_mapping``.
    """
    for state, actions in experience.items():
        for action, results in actions.items():
            for successor, num in results.items():
                if state not in model.states:
                    continue
                target = successor
                if target not in model.states:
                    target = model.essential_mapping[target]
                _record_sample(model, state, action, target, num)


def _add_samples_fully_reduced(model, experience):
    """Translate baseline *experience* onto a model with both reductions applied.

    The essential-state mapping is handled first because that transformation is
    applied first; the chain mapping is handled afterwards.
    """
    for state, actions in experience.items():
        for action, results in actions.items():
            for successor, num in results.items():
                # prob0/prob1 states: their samples don't help any method
                if state in model.removed_unreachable_states:
                    continue
                source, source_action, target = state, action, successor
                # essential states
                if source not in model.states and source in model.essential_mapping:
                    continue
                if target not in model.states and target in model.essential_mapping:
                    target = model.essential_mapping[target]
                # contracted chains
                if target not in model.states and target in model.chain_mapping:
                    continue
                if source not in model.states and source in model.chain_mapping:
                    source, source_action = model.chain_mapping[source]
                if source == target:
                    continue
                _record_sample(model, source, source_action, target, num)


def do_ablation(
    environment,
    property_name,
    is_min_property,
    path_to_mdp_file,
    path_to_logfile,
    delta,
    precision,
    full,
):
    """Compare how many sample runs each model variant needs to reach *precision*.

    All variants are fed the *same* sampled paths. The paths are drawn on the
    baseline model under ``set_policy_uniform()``, and each stops at a sink state
    or after ``5 * #baseline transitions`` steps.

    Every ``batchsize`` runs, each unfinished variant recomputes its values. A
    variant stops once the averaged bound gap over the initial states is at most
    *precision*. Progress lines are redrawn in place with ANSI cursor codes. A CSV
    summary (one row per variant) is written to *path_to_logfile*.

    Variants:
    - ``baseline``: Uniform delta + Hoeffding, no structural improvements.
    - ``full``: all improvements (``structural_improvements()``, uniform-product
      delta, Clopper-Pearson, small-support improvement).
    - Only if *full* is true, ablations, each named after the improvement it
      leaves out:
      - ``cp``, ``small_supp``, ``independence``: ``full`` without
        Clopper-Pearson, the small-support improvement, or the multiplicative
        delta split, respectively.
      - ``structure``: no structural improvements.
      - ``chains``: no chain contraction; only prob0/1 collapse and essential-state
        merging are applied.
      - ``nwr``: only chain contraction is applied.
    """
    batchsize_factor = 1

    start_time = time.time()
    baseline_model = prepare_model(
        environment,
        path_to_mdp_file,
        delta,
        use_objective_for_mec=False,
        property_name=property_name,
        collapse_mecs=True,
        is_min_property=is_min_property,
    )
    baseline_model.handle_self_loops()
    baseline_model.delta_distribution = UTMDP.DeltaDistribution.UNIFORM
    baseline_model.confidence_method = UTMDP.ConfidenceMethod.HOEFFDING
    baseline_model.small_support_improvement = False
    baseline_strategy = Strategy.Strategy(baseline_model)
    baseline_strategy.set_policy_uniform()
    building_time_baseline = time.time() - start_time

    full_model = copy.deepcopy(baseline_model)
    start_time = time.time()
    full_model.structural_improvements()
    full_model.delta_distribution = UTMDP.DeltaDistribution.UNIFORM_PRODUCT
    full_model.confidence_method = UTMDP.ConfidenceMethod.CLOPPER_PEARSON
    full_model.small_support_improvement = True
    full_strategy = Strategy.Strategy(full_model)
    building_time_full = time.time() - start_time + building_time_baseline

    setups = {
        "baseline": {
            "id": 0,
            "model": baseline_model,
            "strategy": baseline_strategy,
            "building_time": building_time_baseline,
        },
        "full": {
            "id": 1,
            "model": full_model,
            "strategy": full_strategy,
            "building_time": building_time_full,
        },
    }

    if full:
        cp_model = copy.deepcopy(full_model)
        cp_model.confidence_method = UTMDP.ConfidenceMethod.HOEFFDING
        cp_strategy = Strategy.Strategy(cp_model)

        small_supp_model = copy.deepcopy(full_model)
        small_supp_model.small_support_improvement = False
        small_supp_strategy = Strategy.Strategy(small_supp_model)

        independence_model = copy.deepcopy(full_model)
        independence_model.delta_distribution = UTMDP.DeltaDistribution.UNIFORM
        independence_strategy = Strategy.Strategy(independence_model)

        structure_model = copy.deepcopy(baseline_model)
        structure_model.confidence_method = UTMDP.ConfidenceMethod.CLOPPER_PEARSON
        structure_model.delta_distribution = UTMDP.DeltaDistribution.UNIFORM_PRODUCT
        structure_model.small_support_improvement = True
        structure_strategy = Strategy.Strategy(structure_model)

        chain_model = copy.deepcopy(structure_model)
        start_time = time.time()
        chain_model.collapse_prob01()
        chain_model.merge_essential_states()
        chain_strategy = Strategy.Strategy(chain_model)
        building_time_chains = time.time() - start_time + building_time_baseline

        nwr_model = copy.deepcopy(structure_model)
        start_time = time.time()
        nwr_model.contract_chains()
        nwr_strategy = Strategy.Strategy(nwr_model)
        building_time_nwr = time.time() - start_time + building_time_baseline

        # cp / small_supp / independence are copies of full_model and therefore
        # include the cost of structural_improvements() in their building time.
        setups["cp"] = {
            "id": 2,
            "model": cp_model,
            "strategy": cp_strategy,
            "building_time": building_time_full,
        }
        setups["small_supp"] = {
            "id": 3,
            "model": small_supp_model,
            "strategy": small_supp_strategy,
            "building_time": building_time_full,
        }
        setups["independence"] = {
            "id": 4,
            "model": independence_model,
            "strategy": independence_strategy,
            "building_time": building_time_full,
        }
        setups["chains"] = {
            "id": 5,
            "model": chain_model,
            "strategy": chain_strategy,
            "building_time": building_time_chains,
        }
        setups["nwr"] = {
            "id": 6,
            "model": nwr_model,
            "strategy": nwr_strategy,
            "building_time": building_time_nwr,
        }
        setups["structure"] = {
            "id": 7,
            "model": structure_model,
            "strategy": structure_strategy,
            "building_time": building_time_baseline,
        }

    for name, parts in setups.items():
        model = parts["model"]
        # num_runs == -1 marks setups that have not yet reached the desired precision
        parts["num_runs"] = -1
        parts["epsilon"] = 1
        parts["num_transitions"] = _count_transitions(model)
        parts["num_prob_transitions"] = _count_transitions(
            model, probabilistic_only=True
        )
        # The two setups without the small-support improvement batch over all
        # transitions; the others batch over probabilistic transitions only.
        batch_base = (
            parts["num_transitions"]
            if name in ("baseline", "small_supp")
            else parts["num_prob_transitions"]
        )
        parts["batchsize"] = max(batch_base // batchsize_factor, 1)

    print(f"Desired precision: {precision}")
    print("Starting sampling process...\n")

    print(
        "\n".join(
            f"{name} achieved precision 1"
            for name in sorted(setups, key=lambda x: setups[x]["id"])
        )
    )

    # Essentially run sample_until_precision() for all models above, but on
    # shared data (i.e. the same samples).
    cursor_up = "\033[F"
    cursor_down = "\033[B"
    max_num_steps = 5 * setups["baseline"]["num_transitions"]  # alternative: num states
    num_runs = 0
    sampling_time = 0

    while True:
        start_time = time.time()
        new_experience = defaultdict(lambda: defaultdict(lambda: defaultdict(int)))
        num_runs += 1
        state = random.choice(baseline_model.initial_states)
        steps = 0
        goal_state_reached = False
        while not goal_state_reached and steps < max_num_steps:
            steps += 1
            actions = baseline_model.actions[state]
            choice = baseline_strategy.choice(state, actions)
            new_state = baseline_model.get_sample(state, choice)
            goal_state_reached = new_state in baseline_model.sink_states
            new_experience[state][choice][new_state] += 1
            state = new_state
        sampling_time += time.time() - start_time

        #######################
        # Update and Evaluate #
        #######################
        for name, parts in setups.items():
            if parts["num_runs"] != -1:
                continue
            model = parts["model"]
            strategy = parts["strategy"]
            if name in ("baseline", "structure"):
                # same state/action space as the sampled baseline model
                strategy.update_memory(new_experience)
            elif name == "nwr":
                _add_samples_chain_contracted(model, new_experience)
            elif name == "chains":
                _add_samples_nwr_quotient(model, new_experience)
            else:
                _add_samples_fully_reduced(model, new_experience)

            if num_runs % parts["batchsize"] != 0:
                continue

            strategy.update_values(with_corr_bounds=False, keep_old_bounds=True)
            initial_states = model.initial_states
            avg_lower = sum(model.values[init].lower for init in initial_states) / len(
                initial_states
            )
            avg_upper = sum(model.values[init].upper for init in initial_states) / len(
                initial_states
            )
            epsilon = avg_upper - avg_lower
            parts["epsilon"] = epsilon

            # redraw this setup's progress line in place
            pos = len(setups) - parts["id"]
            print(
                f"{pos*cursor_up}\r{name} achieved precision {epsilon}{pos*cursor_down}",
                end="",
            )

            if epsilon <= precision:
                parts["num_runs"] = num_runs
                parts["lower"] = avg_lower
                parts["upper"] = avg_upper
                parts["sampling_time"] = sampling_time
                print(
                    f"{pos * cursor_up}\r{name} achieved precision {epsilon} "
                    f"-- done in {parts['num_runs']} sample runs{pos*cursor_down}",
                    end="",
                )
                # rerun VI from scratch for timing
                # NOTE(review): keep_old_bounds=True looks like it reuses the
                # previous bounds rather than starting from scratch — confirm.
                start_time = time.time()
                strategy.update_values(with_corr_bounds=False, keep_old_bounds=True)
                parts["solving_time"] = time.time() - start_time

        if all(x["num_runs"] >= 0 for x in setups.values()):
            break

    num_runs_full = setups["full"]["num_runs"]
    with open(path_to_logfile, "w") as f:
        f.write(
            "name,num_transitions,num_runs,improvement_factor,epsilon,lower,upper,building_time,sampling_time,solving_time,total_time\n"
        )
        for name, parts in setups.items():
            f.write(
                f"{name},{parts['num_transitions']},{parts['num_runs']},{parts['num_runs'] / num_runs_full},"
                f"{parts['epsilon']},{parts['lower']},{parts['upper']},"
                f"{parts['building_time']},{parts['sampling_time']},{parts['solving_time']},"
                f"{parts['building_time']+parts['sampling_time']+parts['solving_time']}\n"
            )

    print()
    print(f"See {path_to_logfile} for logs")
    print("Done!")


def do_property_smc_analysis(
    environment,
    property_name,
    is_min_property,
    path_to_mdp_file,
    path_to_logfile,
    delta,
    precision,
):
    """
    Compare the standard SMC method with the Smart SMC method on a single property.

    Two models are built from the same input:

    * standard: additive (``UNIFORM``) delta distribution, Hoeffding bounds,
      self-loops handled, no small-support improvement. MECs (but not their
      attractors) are collapsed here nonetheless to ensure convergence of the
      value iteration.
    * smart: multiplicative (``UNIFORM_PRODUCT``) delta distribution,
      Clopper-Pearson bounds, small-support improvement, structural improvements,
      MECs collapsed w.r.t. the objective, followed by ``merge_essential_states()``.

    For each model, batches of sample runs are drawn under a uniform policy until
    the width of the [lower, upper] interval averaged over the initial states is
    at most ``precision`` (see ``sample_until_precision``). Value iteration is then
    re-run from scratch so its runtime can be reported separately. One
    tab-separated line per method is written to ``path_to_logfile``, and the ratio
    of sample runs needed by both methods is printed.

    :param environment: model source passed to ``prepare_model`` (``"MDP"`` reads
        the model from ``path_to_mdp_file``).
    :param property_name: name of the property in the accompanying ``.props`` file.
    :param is_min_property: whether the property is to be minimised.
    :param path_to_mdp_file: path to the model file (used when environment is "MDP").
    :param path_to_logfile: file that is (over)written with the results.
    :param delta: error tolerance passed to the models.
    :param precision: target width of the averaged [lower, upper] interval.
    """

    ##########
    # Models #
    ##########
    start_time = time.time()
    model_standard = prepare_model(
        environment,
        path_to_mdp_file,
        delta,
        collapse_mecs=True,
        use_objective_for_mec=False,
        property_name=property_name,
        is_min_property=is_min_property,
    )
    model_standard.delta_distribution = UTMDP.DeltaDistribution.UNIFORM
    model_standard.handle_self_loops()
    model_standard.confidence_method = UTMDP.ConfidenceMethod.HOEFFDING
    model_standard.small_support_improvement = False
    strategy_standard = Strategy.Strategy(model=model_standard)
    strategy_standard.set_policy_uniform()
    print("--- Original model ---")
    print_model_parameters(model_standard)
    print()
    building_time_standard = time.time() - start_time

    start_time = time.time()
    model_smart = prepare_model(
        environment,
        path_to_mdp_file,
        delta,
        collapse_mecs=True,
        use_objective_for_mec=True,
        property_name=property_name,
        is_min_property=is_min_property,
    )
    model_smart.handle_self_loops()
    model_smart.delta_distribution = UTMDP.DeltaDistribution.UNIFORM_PRODUCT
    model_smart.confidence_method = UTMDP.ConfidenceMethod.CLOPPER_PEARSON
    model_smart.small_support_improvement = True
    model_smart.structural_improvements()
    strategy_smart = Strategy.Strategy(model=model_smart)
    strategy_smart.set_policy_uniform()
    print("--- Reduced model ---")
    print_model_parameters(model_smart)
    print()
    building_time_smart = time.time() - start_time

    model_smart.merge_essential_states()

    ##################
    # Preparing logs #
    ##################

    # The file handler lives on the root logger only for the duration of this
    # call: it is removed and closed in ``finally`` so the log file is flushed
    # and closed, and repeated calls do not keep writing to stale files.
    logger = logging.getLogger()
    handler = logging.FileHandler(path_to_logfile, mode="w")
    logger.setLevel(logging.INFO)
    logger.addHandler(handler)
    try:
        logger.info(
            "method \tnum_sample_runs \tnum_transitions \tlower \tupper \tepsilon"
            "\tmodel_building_time \tsampling_time \tsolving_time \ttotal_time"
        )

        ###################
        # Standard method #
        ###################

        # All transition triples count for the standard model.
        num_transitions = sum(
            len(transitions)
            for actions in model_standard.transitions.values()
            for transitions in actions.values()
        )
        batchsize = max(1, num_transitions // 3)

        print(f"Trying to obtain precision level {precision}, {batchsize=}...")

        epsilon, avg_lower, avg_upper, sampling_time_standard, num_runs_standard = (
            sample_until_precision(
                model_standard, strategy_standard, precision, batchsize
            )
        )

        # Re-run VI from scratch to determine the runtime of VI if the samples
        # had been given by an oracle.
        start_time = time.time()
        strategy_standard.update_values(with_corr_bounds=False, keep_old_bounds=False)
        evaluation_time_standard = time.time() - start_time
        sum_time = (
            building_time_standard + sampling_time_standard + evaluation_time_standard
        )

        logger.info(
            f"standard \t{num_runs_standard} \t{num_transitions} \t{avg_lower} \t{avg_upper} \t{epsilon} "
            f"\t{building_time_standard} \t{sampling_time_standard} \t{evaluation_time_standard} \t{sum_time}"
        )

        ########################
        # Smart SMC comparison #
        ########################

        # Only triples of actions with more than one successor count here.
        num_prob_transitions_smart = sum(
            (len(transitions) if len(transitions) > 1 else 0)
            for actions in model_smart.transitions.values()
            for transitions in actions.values()
        )
        batchsize = max(1, num_prob_transitions_smart // 3)
        print(f"Replicating precision level in improved setting, {batchsize=}...")

        epsilon, avg_lower, avg_upper, sampling_time_smart, num_runs_smart = (
            sample_until_precision(model_smart, strategy_smart, precision, batchsize)
        )

        # Same oracle-timing re-run of VI as for the standard method.
        start_time = time.time()
        strategy_smart.update_values(with_corr_bounds=False, keep_old_bounds=False)
        evaluation_time_smart = time.time() - start_time
        sum_time = building_time_smart + sampling_time_smart + evaluation_time_smart

        logger.info(
            f"smart \t{num_runs_smart} \t{num_prob_transitions_smart} \t{avg_lower} \t{avg_upper} \t{epsilon} "
            f"\t{building_time_smart} \t{sampling_time_smart} \t{evaluation_time_smart} \t{sum_time}"
        )

        # num_runs_smart >= batchsize >= 1, so the division below is safe.
        print(
            f"Used {num_runs_smart} runs as opposed to {num_runs_standard} runs in standard method "
            f"(factor {num_runs_standard / num_runs_smart})"
        )
    finally:
        logger.removeHandler(handler)
        handler.close()


def sample_until_precision(model, strategy, precision, batchsize):
    """Draw sample runs in batches until the initial-state interval is narrow enough.

    Each batch simulates ``batchsize`` runs. A run starts in a uniformly random
    element of ``model.initial_states``. It follows ``strategy.choice`` and
    ``model.get_sample`` until a state in ``model.sink_states`` is reached or
    ``2 * (number of transition triples)`` steps have been taken. After every
    batch, value iteration is re-run from scratch via
    ``strategy.update_values(with_corr_bounds=False, keep_old_bounds=False)``.
    The lower and upper bounds are then averaged over the initial states.

    NOTE(review): there is no iteration cap, so the loop only ends once the
    interval width reaches ``precision``. A ``batchsize`` < 1 draws no samples.

    :param model: model providing ``transitions``, ``actions``,
        ``initial_states``, ``sink_states``, ``get_sample`` and ``values``.
    :param strategy: policy providing ``choice`` and ``update_values``.
    :param precision: stop as soon as ``avg_upper - avg_lower <= precision``.
    :param batchsize: number of runs simulated between two VI evaluations.
    :return: ``(epsilon, avg_lower, avg_upper, sampling_time, num_runs)``.
        ``sampling_time`` is the wall-clock time in seconds spent simulating
        runs (excluding VI). ``num_runs`` is the total number of runs.
    """
    num_transitions = sum(
        len(transitions)
        for actions in model.transitions.values()
        for transitions in actions.values()
    )
    # Cap on the length of a single run (alternative: num states); invariant
    # across runs, so computed once.
    max_num_steps = 2 * num_transitions
    num_runs = 0
    sampling_time = 0

    while True:
        # Observations are not collected here; presumably model.get_sample
        # records them in the model so update_values sees them (TODO confirm).
        start_time = time.time()
        for _ in range(batchsize):
            num_runs += 1
            state = random.choice(model.initial_states)
            steps = 0
            goal_state_reached = False
            while not goal_state_reached and steps < max_num_steps:
                steps += 1
                choice = strategy.choice(state, model.actions[state])
                state = model.get_sample(state, choice)
                goal_state_reached = state in model.sink_states
        sampling_time += time.time() - start_time

        ##############
        # Evaluating #
        ##############

        strategy.update_values(with_corr_bounds=False, keep_old_bounds=False)

        initial_states = model.initial_states
        avg_lower = sum(model.values[init].lower for init in initial_states) / len(
            initial_states
        )
        avg_upper = sum(model.values[init].upper for init in initial_states) / len(
            initial_states
        )
        epsilon = avg_upper - avg_lower
        print(f"\r{epsilon=}, interval ({avg_lower}, {avg_upper})", end="")

        if epsilon <= precision:
            break

    print(
        f"\nProperty can be estimated to a precision of {epsilon} (by interval [{avg_lower}, {avg_upper}])"
    )
    return epsilon, avg_lower, avg_upper, sampling_time, num_runs


def print_delta(model):
    """Print the delta of the first probabilistic transition found in ``model``.

    Scans ``model.transitions`` in iteration order for the first state-action
    pair with more than one successor. It prints the entry of
    ``model.transition_deltas`` for that pair's first successor, then stops.
    Prints nothing if every action has at most one successor.
    """
    for state, per_action in model.transitions.items():
        for action, outgoing in per_action.items():
            if len(outgoing) <= 1:
                continue
            first_successor = next(iter(outgoing)).successor
            per_transition_delta = model.transition_deltas[state][action][first_successor]
            print(f"delta per transition is {per_transition_delta}")
            return


def print_model_parameters(model):
    """Print a five-line size summary of ``model``.

    The summary reports:

    * the number of states, with the initial and absorbing (sink) counts;
    * the number of state-action pairs;
    * the pairs whose action has more than one successor ("probabilistic");
    * the total number of transition triples;
    * the triples belonging to probabilistic actions.
    """
    pair_count = 0
    prob_pair_count = 0
    for state, available in model.actions.items():
        pair_count += len(available)
        prob_pair_count += sum(
            1 for act in available if len(model.transitions[state][act]) > 1
        )

    triple_count = 0
    prob_triple_count = 0
    for per_action in model.transitions.values():
        for successors in per_action.values():
            width = len(successors)
            triple_count += width
            if width > 1:
                prob_triple_count += width

    summary = [
        f"Model has {len(model.states)} states ({len(model.initial_states)} initial, {len(model.sink_states)} absorbing)",
        f"and {pair_count} state-action pairs",
        f"and {prob_pair_count} probabilistic state-action pairs",
        f"and {triple_count} transition triples",
        f"and {prob_triple_count} probabilistic transition triples",
    ]
    for line in summary:
        print(line)


if __name__ == "__main__":
    # Command-line entry point: parse options and run the ablation study.
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "environment",
        help='Either implemented instance of MDP model (e.g. Racetrack.GREY_BARTO_SMALL) or "MDP". '
        'If value is "MDP", the --mdpfile flag must specify the path to a .prism MDP file',
    )
    parser.add_argument(
        "--mdpfile", help="path to the MDP file if MDP environment was chosen"
    )
    parser.add_argument("--delta", help="error tolerance", type=float, default=0.1)
    parser.add_argument("--epsilon", help="precision", type=float, default=0.1)
    # NOTE(review): the two options below declare nargs=2 but have a scalar
    # default, and their values (emin/emax) are not passed to do_ablation.
    # They are left unchanged to keep the CLI stable.
    parser.add_argument(
        "--epsilon_min",
        help="minimum precision per transition used in theoretical analysis",
        type=float,
        nargs=2,
        default=0.0001,
    )
    parser.add_argument(
        "--epsilon_max",
        help="maximum precision per transition used in theoretical analysis",
        type=float,
        nargs=2,
        default=0.1,
    )
    parser.add_argument("--episodes", help="number of episodes", type=int, default=10)
    parser.add_argument(
        "--logfile", help="path to logfile", type=str, default="learning.dat"
    )
    parser.add_argument(
        "--full",
        help="analyze impact of all improvements separately",
        action="store_true",
    )
    parser.add_argument(
        "--property",
        help="property name in .props file. only used for building MEC quotient w.r.t. "
        "property",
        type=str,
        default="goal",
    )
    parser.add_argument(
        "--minimization", help="minimize reachability in property", action="store_true"
    )
    args = parser.parse_args()

    # The "MDP" environment is built from --mdpfile (see the help text above).
    # Fail early with a usage message instead of crashing later on a None path.
    if args.environment == "MDP" and args.mdpfile is None:
        parser.error('--mdpfile is required when environment is "MDP"')

    env = args.environment
    mdp_file = args.mdpfile
    et = args.delta
    eps = args.epsilon
    emin = args.epsilon_min
    emax = args.epsilon_max
    ep = args.episodes
    fl = args.full
    prop = args.property
    minim = args.minimization
    # Ablation runs (--full) and regular result runs go to separate directories.
    log_path = "logs/logs/" + ("ablation/" if fl else "results/") + args.logfile
    pathlib.Path(log_path).parent.mkdir(parents=True, exist_ok=True)

    do_ablation(
        environment=env,
        property_name=prop,
        is_min_property=minim,
        path_to_mdp_file=mdp_file,
        path_to_logfile=log_path,
        delta=et,
        precision=eps,
        full=fl,
    )
    # do_property_smc_analysis(environment=env,
    #             property_name=prop,
    #             is_min_property=minim,
    #             path_to_mdp_file=mdp_file,
    #             path_to_logfile=log_path,
    #             delta=et,
    #             precision=eps)
    # do_learning(environment=env,
    #             property_name=prop,
    #             is_min_property=minim,
    #             path_to_mdp_file=mdp_file,
    #             path_to_logfile=log_path,
    #             delta=et,
    #             num_episodes=ep)
