import sys
import os
sys.path.append(os.path.join(os.environ["NDOWNWARD"]))
from tools import parsing as apt

from downward.reports.absolute import AbsoluteReport

import enum
from . import rl_constants
from . import rl_experiment_parsing
from . import rl_experiment
from . import rl_ecai_experiment
from . import rl_base_configurations

class Suites(enum.Enum):
    Test = ["p16.pddl:uniform-valid_True-0.pddl",
           "p16.pddl:uniform-valid_True-5.pddl",
           "p16.pddl:uniform-valid_True-6.pddl"]
    Test2 = [
        "p16.pddl-200:p1.pddl",
        "p16.pddl-200:p2.pddl",
        "p16.pddl-200:p2.pddl"
    ]
    P16ToP30 = ['p{}.pddl'.format(i) for i in range(16, 31)]


TEMPLATE_REGRESSION_SAS_STATE_NETWORK = [
        '--network',
        'hnn=snet(type=regression,unary_threshold=0,state_layer=input_1,'
        'path=model.pb,bin_size=1,output_layers={{MODEL_OUTPUT_LAYER,'
        'model.pb}},allow_unused_atoms=false,atoms={network_input_atoms}, '
        'defaults={network_input_defaults},'
        'exponentiate_heuristic={exponentiate_heuristic},'
        'domain_max_is_undefined=false)']

class Predefinitions(enum.Enum):
    Regression_SAS_State_Network = [
        x.replace("{exponentiate_heuristic}", "false")
        for x in TEMPLATE_REGRESSION_SAS_STATE_NETWORK]
    Regression_SAS_State_Network_Expansions = [
        x.replace("{exponentiate_heuristic}", "true")
        for x in TEMPLATE_REGRESSION_SAS_STATE_NETWORK]

class Search(enum.Enum):
    EagerGreedy = 'eager_greedy([nh(hnn)],cost_type=one)'
    LazyGreedy = 'lazy_greedy([nh(hnn)],cost_type=one)'
    EagerEpsilon = 'eager(epsilon_greedy(nh(hnn),epsilon={epsilon}),cost_type=one)'


BUILD = "%s64dynamic" % ("release" if rl_constants.IS_REMOTE else "debug")


class RobustnessExperimentFactory(object):
    @staticmethod
    def create_robustness_experiment(
            configuration_name,
            source_experiment,
            benchmark_repo,
            suites,
            rln_predefinition,
            search,
            build=BUILD,
            overall_time_limit="30m",
            mutex_options=None,
            epsilon=0.0,
            partition="infai_1",
            max_nb_models=None,
            skip_first_model=False,
            sort_by_coverage_ratio=None,
            slurm_dependency=None,
            next_experiment=None
    ):
        exp = rl_experiment.get_base_experiment(
            experiment_type=rl_experiment.EXPERIMENT_ROBUSTNESS,
            suites=suites, cores=1,
            benchmark_repo=benchmark_repo,
            partition=partition,
            slurm_dependency=slurm_dependency,
            next_experiment=next_experiment
        )
        search = search.format(
            epsilon=epsilon
        )
        COMPONENT_OPTIONS = rln_predefinition + ['--search'] + [search]
        DRIVER_OPTIONS = ["--overall-time-limit", overall_time_limit,
                          "--build", build]
        if mutex_options is not None:
            DRIVER_OPTIONS += mutex_options

        exp.add_algorithm(configuration_name, rl_constants.REPO,
                          "DeePDown",
                          source_experiment, COMPONENT_OPTIONS, [build],
                          DRIVER_OPTIONS,
                          max_nb_models=max_nb_models,
                          skip_first_model=skip_first_model,
                          sort_by_coverage_ratio=sort_by_coverage_ratio,)
        return exp


class RLExperimentFactory(object):
    @staticmethod
    def _get_sampling_techniques_prior_2020_04_09(filename):
        SAMPLING_TECHNIQUES = rl_experiment_parsing.get_sampling_technique_prior_2020_04_09(filename)
        SAMPLING_TECHNIQUES_DISTRIBUTION = \
            rl_experiment_parsing.get_sampling_technique_distribustion(
                filename)
        SCRAMBLE = rl_experiment_parsing.get_scrambling(filename)
        BIAS = rl_experiment_parsing.get_bias(filename)
        assert len(SAMPLING_TECHNIQUES_DISTRIBUTION) == 0 or len(
            SAMPLING_TECHNIQUES) == len(SAMPLING_TECHNIQUES_DISTRIBUTION)
        assert len(SCRAMBLE) in [0, 1] or len(SCRAMBLE) == len(
            SAMPLING_TECHNIQUES)
        assert len(BIAS) in [0, 3]

        arguments = []
        for no, st in enumerate(SAMPLING_TECHNIQUES):
            arguments += ["--sampling-technique", st]
            if len(SAMPLING_TECHNIQUES_DISTRIBUTION) > 0:
                argumens += ["++weight", SAMPLING_TECHNIQUES_DISTRIBUTION[no]]
            if len(SCRAMBLE) > 0:
                arguments += ["++max-scrambles", SCRAMBLE[0] if len(SCRAMBLE) == 1 else SCRAMBLE[no]]
            if len(BIAS) > 0:
                arguments += ["++bias"] + BIAS
        return arguments


    @staticmethod
    def _get_sampling_techniques_since_2020_04_09(filename):
        assert len(rl_experiment_parsing.get_bias(filename)) == 0, "Deprecated"
        assert len(rl_experiment_parsing.get_sampling_technique_distribustion(filename)) == 0, "Deprecated"
        assert len(rl_experiment_parsing.get_scrambling(filename)) == 0, "Deprecated"
        sampling_techniques = rl_experiment_parsing.get_sampling_technique_since_2020_04_09(filename)
        arguments = []
        for st in sampling_techniques:
            technique_name = st.split("+")[0].strip()
            assert technique_name != ""
            assert technique_name in rl_experiment_parsing.SAMPLING_TECHNIQUE_ABBREVIATIONS
            arguments += ["--sampling-technique",
                          rl_experiment_parsing.SAMPLING_TECHNIQUE_ABBREVIATIONS[technique_name]]

            weight = rl_experiment_parsing.get_sampling_technique_weight(st)
            if weight is not None:
                arguments += ["++weight", weight]

            upgrade = rl_experiment_parsing.get_sampling_technique_upgrade(st)
            if upgrade is not None:
                arguments.extend(["++upgrade"] + upgrade)

            bias = rl_experiment_parsing.get_bias(st, anker="B_", split_char="+")
            if len(bias) != 0:
                arguments += ["++bias"] + bias

            scrambles_and_increment = rl_experiment_parsing.get_sampling_technique_scrambles(st)
            if scrambles_and_increment is not None:
                arguments += scrambles_and_increment

            active = rl_experiment_parsing.get_sampling_technique_active(st)
            if active is not None:
                arguments += ["++active", active]
        return arguments

    @staticmethod
    def create_experiment_from_filename(
            filename, previous_robustness=None, add_robustness=True,
            slurm_dependency=None,
            next_experiment=None,
    ):
        filename = rl_base_configurations.replace_base(filename)
        algorithm_name = filename[11:]
        date = tuple([int(x) for x in filename[:10].split("-")])

        build = "%s64dynamic" % (
            "release" if rl_constants.IS_REMOTE else "debug")
        MAX_TRAINING_TIME = rl_experiment_parsing.get_time(filename)
        TIME_SEARCH = 1800
        TIME_OUT_BUFFER = 1800
        TOTAL_TIME_LIMIT = MAX_TRAINING_TIME + TIME_SEARCH + TIME_OUT_BUFFER

        LOOKAHEAD = rl_experiment_parsing.get_lookahead(filename)
        SEARCH_ENGINE = rl_experiment_parsing.get_sampling_engine(filename)
        assert len(SEARCH_ENGINE) in [0, 2]
        PREDICT_EXPANSIONS = len(SEARCH_ENGINE) == 2 and SEARCH_ENGINE[1].endswith("X")
        if PREDICT_EXPANSIONS:
            SEARCH_ENGINE += ["--transform-label", "ln",
                              "exponentiate_heuristic:true"]
        REPLAY = rl_experiment_parsing.get_replay(filename)

        INSTANCIATE_PARTIALS = rl_experiment_parsing.get_partial(filename)
        L2 = rl_experiment_parsing.get_l2(filename)
        OPTIMIZER = rl_experiment_parsing.get_opt(filename)
        LOSS = rl_experiment_parsing.get_loss(filename)
        DROPOUT = rl_experiment_parsing.get_dropout(filename)
        LR_DECAY = rl_experiment_parsing.get_lr_decay(filename)
        BUFFER_FACTOR = rl_experiment_parsing.get_buffer_factor(filename)

        MUTEX_OPTIONS = rl_experiment_parsing.get_mutex_options(filename)
        assert filename.find("pre_h2") == -1
        EXPAND_GOAL = rl_experiment_parsing.get_expand_goal(filename)

        TRAINING_REPETITIONS = rl_experiment_parsing.get_training_repetitions(filename)


        # NETWORK = rl_experiment_parsing.get_network(filename, "min_time=75").format(
        NETWORK = rl_experiment_parsing.get_network(filename, lr_decay=LR_DECAY).format(
            replay=REPLAY,
            l2=L2,
            optimizer=OPTIMIZER,
            dropout=DROPOUT,
            loss=LOSS
        )

        if date < (2020, 4, 9):
            SAMPLING_TECHNIQUES = RLExperimentFactory._get_sampling_techniques_prior_2020_04_09(filename)
        else:
            SAMPLING_TECHNIQUES = RLExperimentFactory._get_sampling_techniques_since_2020_04_09(filename)

        # "--increase-scrambling", "keras_condition_counter(threshold=0.05)", "5", "20",
        COMPONENT_OPTIONS = (
                [
                    "--fast-downward-build", build,
                    "--maximum-training-time", "%is" % MAX_TRAINING_TIME,
                    "--add-final-evaluation",
                    "--reinitialize-after-time", "0.75h",
                    "--working-directory", "$TMPDIR",
                ] + LOOKAHEAD + BUFFER_FACTOR + MUTEX_OPTIONS
                + SAMPLING_TECHNIQUES + EXPAND_GOAL
                + SEARCH_ENGINE)

        # init => perform initial blind search
        # inter => perform intermediate searches with current NN
        # Uppercase => stop training if solution was found on search, else continue, but
        # do not run anymore intermediate/initial searches and store log in the given
        # filename name
        if filename.find("-init-") > -1:
            COMPONENT_OPTIONS += [
                "--add-initial-evaluation", "default", 300]
        elif filename.find("-INIT-") > -1:
            COMPONENT_OPTIONS += [
                "--add-initial-evaluation", "default", 300, "init.log"]
        if filename.find("-inter-") > -1:
            COMPONENT_OPTIONS += [
                "--add-intermediate-evaluations", "default", "600", "180"]
        elif filename.find("-INTER-") > -1:
            COMPONENT_OPTIONS += [
                "--add-intermediate-evaluations", "default", "600", "180",
                "inter.zip"]

        if INSTANCIATE_PARTIALS > 0:
            COMPONENT_OPTIONS += ["--wrap-partial-assignment",
                                  INSTANCIATE_PARTIALS]

        ARG_SUITE = os.path.splitext(filename.split("-")[-1])[0]
        SUITES = rl_experiment_parsing.get_tasks(ARG_SUITE)

        if add_robustness:
            if ARG_SUITE.find("ecai") > -1 or ARG_SUITE.find("further1") > -1:
                ROBUSTNESS_BENCHMARKS = rl_constants.DIR_BENCHMARKS_ECAI
            elif ARG_SUITE.find("regr") > -1:
                ROBUSTNESS_BENCHMARKS = rl_constants.DIR_BENCHMARKS_REGRESSION
            else:
                ROBUSTNESS_BENCHMARKS = rl_constants.DIR_BENCHMARKS_ECAI  #DIR_BENCHMARKS_ALT_INIT
        else:
            ROBUSTNESS_BENCHMARKS = None

        exp = rl_experiment.get_base_experiment(
            suites=SUITES,
            cores=4,
            time_limit="%is" % TOTAL_TIME_LIMIT,
            robustness_benchmarks=ROBUSTNESS_BENCHMARKS,
            robustness_time_limit="600m",
            robustness_suffix="-600m",
            robustness_mutex_options=MUTEX_OPTIONS,
            robustness_partition="infai_1",
            robustness_network_predefinitions=(
                "rl_experiment_factory.Predefinitions.Regression_SAS_State_Network.value"
                if not PREDICT_EXPANSIONS else
                "rl_experiment_factory.Predefinitions.Regression_SAS_State_Network_Expansions.value"
            ),
            slurm_dependency=slurm_dependency,
            next_experiment=next_experiment,
        )
        rl_experiment.add_algorithm(
            exp, algorithm_name, NETWORK,
            build_options=[build],
            component_options=COMPONENT_OPTIONS,
            repetitions=TRAINING_REPETITIONS,
            filter_previous_robustness=previous_robustness
        )

        print(COMPONENT_OPTIONS)
        print(NETWORK)
        print(SUITES)
        print(len(SUITES))

        return exp

    @staticmethod
    def create_experiment(
            algorithm_name,
            sampling_techniques,
            network,  # small
            experience_replay,  # replay
            optimizer,  # adam
            lr_decay,  # ""
            l2,  # 0
            dropout,  # 0
            buffer_size_factor,  # None or > 0

            training_time,  # 28h
            reinitialize_time,  # None or 0.75h
            search_time,  # 1800
            buffer_time,  # 1800

            search_engine,  # None or V or ASTAR10
            lookahead,  # None or 2
            mutexes,  # None or translator or h2
            partial_instantiation,  # None or > 0
            goal_expansions,  # None or look it up

            add_final_evaluation,  # True
            add_initial_evaluation,  # None or ["default", 300, "init.log"]
            add_intermediate_evaluation,  # None or ["default", "600", "180", "inter.zip"]
            working_directory,  # None or $TMPDIR
    ):
        assert False, "SUITE selection missing"
        build = "%s64dynamic" % ("release" if rl_constants.IS_REMOTE
                                 else "debug")
        # Create network options
        assert optimizer in ["adam", "sgd", "sgd(momentum=1.0)"]
        assert l2 is None or l2 >= 0
        assert dropout is None or 0 <= dropout <= 1
        lr_decay = rl_experiment_parsing.get_lr_decay(lr_decay)
        network = rl_experiment_parsing.get_network_attribute(network, lr_decay=lr_decay).format(
            replay=experience_replay,
            l2=l2,
            optimizer=optimizer,
            dropout=dropout
        )

        # Calculate time limit
        training_time = apt.time(training_time)
        search_time = apt.time(search_time)
        buffer_time = apt.time(buffer_time)
        total_time_limit = training_time + search_time + buffer_time

        # Options for fast-deepcube.py
        component_options = ["--fast-downward-build", build,
                             "--maximum-training-time", training_time]
        if add_final_evaluation:
            component_options += ["--add-final-evaluation"]
        if add_initial_evaluation is not None:
            component_options += ["--add-initial-evaluation"] + add_initial_evaluation
        if add_intermediate_evaluation is not None:
            component_options += ["--add-intermediate-evaluations"] + add_intermediate_evaluation
        if reinitialize_time is not None:
            reinitialize_time = apt.time(reinitialize_time)
            component_options += ["--reinitialize-after-time", reinitialize_time]
        if working_directory is not None:
            if not os.path.isdir(working_directory) and not working_directory in os.environ:
                print("Warning> Working directory does not exist on this machine")
            component_options += ["--working-directory", working_directory]


        if lookahead is not None:
            lookahead = apt.int_zero_positive(lookahead)
            component_options += ["--lookahead", lookahead]

        if search_engine is not None:
            component_options += ["--sampling-engine", search_engine]

        for st in sampling_techniques:
            assert len(st) >= 1
            component_options += ["--sampling-technique"] + st
        if mutexes is not None:
            component_options += rl_experiment_parsing.get_mutex_options_arguments(mutexes)

        if partial_instantiation is not None:
            assert partial_instantiation > 0
            component_options += ["--wrap-partial-assignment", partial_instantiation]

        if buffer_size_factor is not None:
            assert buffer_size_factor > 0
            component_options += ["--buffer-size-factor", buffer_size_factor]

        if goal_expansions is not None:
            assert len(goal_expansions) == 3
            assert goal_expansions[0] >= 0
            assert goal_expansions[1] in ["true", "false"]
            assert goal_expansions[2] == -1 or goal_expansions[2] > 0
            component_options += ["--expand-goal"] + goal_expansions


        SUITES = [rl_experiment_parsing.get_domain_suite(domain)[15:]
                  for domain in [rl_constants.DOMAIN_STORAGE]]


        exp = rl_experiment.get_base_experiment(
            suites=SUITES,
            cores=4,
            time_limit="%is" % total_time_limit,
        )
        rl_experiment.add_algorithm(
            exp, algorithm_name, network,
            build_options=[build],
            component_options=component_options)

        print("COMPONENT OPTIONS", component_options)
        return exp


TEMPLATE_ECAI_NETWORK = (
    'sgnet(path={MODEL_FILE},type={TYPE},bin_size={BIN_SIZE},'
    'unary_threshold={UNARY_THRESHOLD},state_layer={STATE_LAYER},'
    'goal_layer={GOAL_LAYER},output_layers={OUTPUT_LAYER},'
    'atoms={ATOMS},defaults={VALUES})')
TEMPLATE_ECAI_SEARCH = 'eager_greedy([nh(blind=false,network=hnn)])'


class EcaiExperimentFactory(object):
    @staticmethod
    def create_ecai_experiment(
            search,
            model_template,
            model_template_values,
            model_directory,
            benchmark_repo,
            suites,
            build=BUILD,
            overall_time_limit="30m",
            partition="infai_1",
            cores=1,
            unary_threshold=-1,
            bin_size=1,
    ):
        print(">>>>>>>>>>>>A")
        env = rl_experiment._get_environment(
            cores=cores, partition=partition)
        print(">>>>>>>>>>>>B")
        exp = rl_ecai_experiment.RLEcaiExperiment(environment=env)
        print(">>>>>>>>>>>>C")
        rl_experiment._add_suites(exp, suites, benchmark_repo=benchmark_repo)
        print(">>>>>>>>>>>>D")
        exp.add_step("build", exp.build)
        exp.add_step("start", exp.start_runs)
        some_fetcher_filters = [rl_experiment.filter_some_fetcher_errors]
        exp.add_fetcher(name="fetch", filter=some_fetcher_filters)

        print("AAAAAAAA")
        rl_experiment.add_search_parsers(exp)
        print("BBBBBBBB")
        exp.add_report(
            AbsoluteReport(attributes=rl_constants.ATTRIBUTES_SEARCH),
            name="report")



        DRIVER_OPTIONS = ["--overall-time-limit", overall_time_limit,
                          "--build", build]
        assert any(model_template.startswith(x) for x in ["cls_", "ocls_", "reg_"])
        assert any(model_template.find(x) > -1 for x in ["_ns_", "_full_"])

        COMPONENT_OPTIONS = ["--network", "hnn=" + TEMPLATE_ECAI_NETWORK.format(
            BLIND="false",
            TYPE=("regression" if model_template.startswith("reg")
                  else "classification"),
            STATE_LAYER="input_1_1",
            GOAL_LAYER="input_2_1",
            OUTPUT_LAYER="{{MODEL_OUTPUT_LAYER,model.pb}}",
            ATOMS="{PDDL_ATOMS%s}" % ("_FLEXIBLE" if model_template.find("_ns_") > -1 else ""),
            VALUES="{PDDL_INITS%s}" % ("_FLEXIBLE" if model_template.find("_ns_") > -1 else ""),
            UNARY_THRESHOLD=unary_threshold,
            BIN_SIZE=bin_size,
            MODEL_FILE="model.pb"),
                             "--search", search]

        exp.add_algorithm(
            model_template.format("X"),
            rl_constants.REPO,
            "DeePDown",
            model_directory, model_template, model_template_values,
            COMPONENT_OPTIONS, [build],
            DRIVER_OPTIONS)
        return exp
