import sys
sys.path.append("../../../")

import experiment_plots
from . import rl_constants
from .rl_constants import *
from downward.experiment import FastDownwardExperiment
from .rl_fast_downward_experiment import RLFastDownwardExperiment
from .rl_robustness_experiment import RLRobustnessExperiment

from tools import parsing as apt

from lab.environments import BaselSlurmEnvironment, LocalEnvironment
from lab import tools
from downward.reports.absolute import AbsoluteReport
# from downward.reports import Attribute, geometric_mean


from collections import defaultdict
import glob
import itertools
import json
import os
import platform
import re
import shutil

if sys.version_info < (3,):
    import subprocess32 as subprocess
else:
    import subprocess


# Key of the run property that holds the list of unexplained error messages
# (read and rewritten by filter_some_fetcher_errors).
KEY_UNEXPLAINED_ERRORS = "unexplained_errors"

# Regular expressions for known log lines that should not count as errors.
# filter_some_fetcher_errors deletes every match of these patterns from the
# unexplained error messages.

# TensorFlow cpu_feature_guard info line (optionally prefixed by "run.err:").
PATTERN_TENSORFLOW_WARNING = re.compile(r"(run\.err:\s*)?\d\d\d\d-\d\d-\d\d \d\d:\d\d:\d\d\.\d+: I tensorflow/core/platform/cpu_feature_guard.cc:137] Your CPU supports instructions that this TensorFlow binary was not compiled to use:( (SSE4.1|SSE4.2|AVX|AVX2|AVX512F|FMA))+\n")
# "planner finished and wrote ... KiB to run.log (soft limit: ...)" log line
# (optionally prefixed by "driver.err:").
PATTERN_SOFT_LIMIT_LOG = re.compile(r"(driver\.err:\s*)?\d\d\d\d-\d\d-\d\d \d\d:\d\d:\d\d,\d+ ERROR\s+planner finished and wrote \d+ KiB to run.log \(soft limit: \d+ KiB\)\n")
# The "Using TensorFlow backend." banner line.
PATTERN_USING_TENSORFLOW = re.compile(r"Using TensorFlow backend.\n")
# FutureWarning emitted from h5py/__init__.py of the kerascpu environment
# (two lines: the warning itself and the quoted source line).
PATTERN_HDF5_CONVERSION_WARNING = re.compile(
    r"/infai/ferber/bin/kerascpu/lib/python2\.7/site-packages/h5py/__init__\.py:"
    r"36: FutureWarning: Conversion of the second argument of issubdtype from "
    r"`float` to `np\.floating` is deprecated\. In future, it will be treated as "
    r"`np\.float64 == np\.dtype\(float\)\.type`\.\n  from \._conv import register_"
    r"converters as _register_converters\n", re.MULTILINE)
# UserWarning "loadtxt: Empty input file" raised from fast-deepcube.py.
PATTERN_SAMPLE_FILE_EMPTY_WARNING = re.compile(
    r"../../code-[^/]+/fast-deepcube.py:\d+: UserWarning: loadtxt: Empty input "
    r"file: \"[^\"]+\"\s*delimiter=\";\"\)\n")

# Disabled pattern for an IndexError traceback from training_parser.py; it is
# not part of PATTERNS_ERRORS_TO_IGNORE.
#PATTERN_MY_MISTAKE1 = re.compile(
#    r"""Traceback \(most recent call last\):\n  File "../../training_parser.py", line 112, in <module>\n    main\(\)\n  File "../../training_parser.py", line 108, in main\n    parser.parse\(\)\n  File "/infai/ferber/bin/lab/lab/parser.py", line 226, in parse\n    file_parser.apply_functions\(self.props\)\n  File "/infai/ferber/bin/lab/lab/parser.py", line 128, in apply_functions\n    function\(self.content, props\)\n  File "../../training_parser.py", line 44, in find_all_occurrences\n    props\["final_%s" % name] = props\[name]\[-1]\nIndexError:\s*list\s+index out of range\n"""#\n"""
#    ,re.MULTILINE
#)
# All patterns applied (in this order) by filter_some_fetcher_errors.
PATTERNS_ERRORS_TO_IGNORE = [
    PATTERN_TENSORFLOW_WARNING,
    PATTERN_SOFT_LIMIT_LOG,
    PATTERN_USING_TENSORFLOW,
    PATTERN_HDF5_CONVERSION_WARNING,
    PATTERN_SAMPLE_FILE_EMPTY_WARNING,
#    PATTERN_MY_MISTAKE1
]


# Name of a directory grouping several runs, e.g. "runs-<digits>-<digits>"
# (applied with .match, so only the start of the name is checked).
REGEX_RUNS_DIRECTORY = re.compile(r"runs-\d+-\d+")
# Name of a single run directory: must start with at least one digit.
REGEX_RUN_DIRECTORY = re.compile(r"\d+")
def filter_some_fetcher_errors(run):
    """Remove known, harmless messages from a run's unexplained errors.

    Every match of a pattern in PATTERNS_ERRORS_TO_IGNORE is cut out of each
    message stored under KEY_UNEXPLAINED_ERRORS. Messages that end up blank,
    or that consist only of a bare log-file label, are discarded. If nothing
    is left, the key is removed from ``run`` altogether.

    :param run: mapping with the properties of one run (modified in place)
    :return: the same ``run`` object
    """
    if KEY_UNEXPLAINED_ERRORS not in run:
        return run

    messages = run[KEY_UNEXPLAINED_ERRORS]
    for known_noise in PATTERNS_ERRORS_TO_IGNORE:
        messages = [known_noise.sub("", message) for message in messages]

    # Labels that remain once the actual content of a log was stripped.
    bare_labels = ['run.err: ', 'output-to-slurm.err']
    remaining = []
    for message in messages:
        if not message.strip() or message in bare_labels:
            continue
        remaining.append(message)

    if remaining:
        run[KEY_UNEXPLAINED_ERRORS] = remaining
    else:
        del run[KEY_UNEXPLAINED_ERRORS]
    return run

def fetcher_add_pddl(props):
    """Make sure the ``domain`` property ends with ``.pddl``.

    :param props: mapping with a string entry ``"domain"`` (modified in place)
    :return: the same ``props`` object
    """
    domain_name = props["domain"]
    if domain_name.endswith(".pddl"):
        return props
    props["domain"] = domain_name + ".pddl"
    return props

def rename_algorithm(props):
    """Replace the ``algorithm`` property by a short, readable label.

    The long algorithm name must belong to exactly one family, which gives
    the base label ("AVI", "SS", "LAMA" or "SL"). Suffixes are appended, in
    this order, for further markers found in the long name: "+X", "+Timeout",
    "+Lazy" and "+L=<loss>".

    :param props: mapping with a string entry ``"algorithm"`` (modified in
                  place)
    :return: the same ``props`` object
    :raises AssertionError: if the name matches no family or several families
    """
    long_name = props['algorithm']

    # (label, does the long name belong to this family?)
    families = (
        ("AVI", "lookahead_2" in long_name or "BASE_AVI" in long_name),
        ("SS", "search_GBFS" in long_name),
        ("LAMA", "lama" in long_name),
        ("SL", long_name == "ocls_ns_ubal_h3_sigmoid_inter_gen_sat_drp0_Kall_pruneOff_X_fold_model.pb"),
    )
    assert sum([hit for _, hit in families]) == 1, long_name

    suffixes = []
    if re.match(r".*search_GBFS\d+X-.*", long_name) is not None:
        suffixes.append("+X")
    if re.match(r".*search_GBFS\d+XX-.*", long_name) is not None:
        suffixes.append("+Timeout")
    if "lazy" in long_name.lower():
        suffixes.append("+Lazy")
    loss_match = re.match(r".*loss_([^-]*)-.*", long_name)
    if loss_match is not None:
        suffixes.append("+L=" + loss_match.group(1))

    for label, hit in families:
        if hit:
            props["algorithm"] = label
            break
    else:
        assert False, long_name

    props["algorithm"] += "".join(suffixes)
    return props


# Identifiers of the supported experiment types. They select the experiment
# class (_get_experiment), the parsers (_add_parser) and the reports/plots
# (_add_evals) used by get_base_experiment.
EXPERIMENT_DIVERSITY = "DIVERSITY"
EXPERIMENT_DISTANCE = "DISTANCE"
EXPERIMENT_BASELINE = "BASELINE"
EXPERIMENT_NETWORK = "NETWORK"
EXPERIMENT_ROBUSTNESS = "ROBUSTNESS"

def _extend_extra_options(extra_options, new_option):
    """Append ``new_option`` as a new line to ``extra_options``.

    :param extra_options: options collected so far, or None if there are none
    :param new_option: option line to add
    :return: ``new_option`` itself if nothing was collected before, otherwise
             the previous options followed by a newline and ``new_option``
    """
    if extra_options is None:
        return new_option
    return "%s\n%s" % (extra_options, new_option)


def _get_environment(cores=1, partition="infai_1", time_limit=None,
                     slurm_dependency=None):
    """Create the lab environment in which the experiment is executed.

    If IS_REMOTE is true, a BaselSlurmEnvironment is configured; otherwise a
    LocalEnvironment with 6 processes is returned and all arguments are
    ignored apart from parsing ``time_limit``.

    :param cores: CPUs per task; adds an ``#SBATCH --cpus-per-task`` line if
                  different from 1
    :param partition: Slurm partition to submit to
    :param time_limit: time limit in the format understood by
                       ``apt.time`` (or None); adds an ``#SBATCH --time``
                       line given as ``minutes:seconds``
    :param slurm_dependency: passed to the Slurm environment as
                             ``previous_job_id``
    :return: the environment object
    """
    sbatch_lines = []
    if cores != 1:
        sbatch_lines.append('#SBATCH --cpus-per-task=%i' % cores)
    if time_limit is not None:
        total_seconds = apt.time(time_limit)
        minutes = int(total_seconds/60)
        seconds = total_seconds - minutes * 60
        sbatch_lines.append("#SBATCH --time %i:%i" % (minutes, seconds))
    # None (not an empty string) signals "no extra options".
    extra_options = "\n".join(sbatch_lines) if sbatch_lines else None

    if not IS_REMOTE:
        return LocalEnvironment(processes=6)

    # Activate the Python environment after the default setup commands.
    setup_commands = "%s\n%s" % (
        BaselSlurmEnvironment.DEFAULT_SETUP,
        "source /infai/ferber/bin/kerascpu/bin/activate\n")
    return BaselSlurmEnvironment(
        email=EMAIL_ADDRESS,
        extra_options=extra_options,
        partition=partition,
        setup=setup_commands,
        previous_job_id=slurm_dependency
    )


def _get_experiment(environment, experiment_type):
    """Instantiate the experiment class belonging to ``experiment_type``.

    :param environment: environment object handed to the experiment
    :param experiment_type: one of the EXPERIMENT_* constants
    :return: a new experiment object
    :raises AssertionError: for an unknown experiment type
    """
    if experiment_type == EXPERIMENT_ROBUSTNESS:
        return RLRobustnessExperiment(environment=environment)
    if experiment_type == EXPERIMENT_BASELINE:
        return FastDownwardExperiment(environment=environment)
    # Network, distance and diversity experiments share one experiment class.
    if experiment_type in (EXPERIMENT_NETWORK, EXPERIMENT_DISTANCE,
                           EXPERIMENT_DIVERSITY):
        return RLFastDownwardExperiment(environment=environment)
    assert False, experiment_type


def _add_suites(exp, suites, benchmark_repo=BENCHMARK_REPO):
    """Register benchmark suites with the experiment.

    :param exp: experiment providing ``add_suite``
    :param suites: iterable of suites, or the string "test" as shortcut for
                   the single suite "transport-opt08-strips". An entry is
                   either a 2-tuple, whose items are passed to
                   ``exp.add_suite`` as they are, or a value that is added
                   together with ``benchmark_repo``.
    :param benchmark_repo: repository used for non-tuple entries
    :raises AssertionError: if ``suites`` is None or a tuple entry does not
                            have exactly two items
    """
    assert suites is not None
    selected = ["transport-opt08-strips"] if suites == "test" else suites

    for entry in selected:
        if not isinstance(entry, tuple):
            exp.add_suite(benchmark_repo, entry)
            continue
        assert len(entry) == 2
        exp.add_suite(*entry)


def add_search_parsers(exp):
    """Add the parsers used by experiments that run a search.

    These are the exit code, translator, single search and planner parsers
    provided by ``exp`` itself, followed by the last-expansion parser
    referenced by ``rl_constants.PATH_PARSER_LAST_EXPANSION``.

    :param exp: experiment to add the parsers to
    """
    builtin_parsers = ("EXITCODE_PARSER", "TRANSLATOR_PARSER",
                       "SINGLE_SEARCH_PARSER", "PLANNER_PARSER")
    for attribute_name in builtin_parsers:
        exp.add_parser(getattr(exp, attribute_name))
    exp.add_parser(rl_constants.PATH_PARSER_LAST_EXPANSION)

def _add_parser(exp, experiment_type):
    """Add the parsers required for the given experiment type.

    Baseline, network and robustness experiments get the search parsers
    (network experiments additionally the training parser). Distance and
    diversity experiments get only their dedicated parser.

    :param exp: experiment to add the parsers to
    :param experiment_type: one of the EXPERIMENT_* constants
    :raises AssertionError: for an unknown experiment type
    """
    if experiment_type in (EXPERIMENT_BASELINE, EXPERIMENT_NETWORK,
                           EXPERIMENT_ROBUSTNESS):
        add_search_parsers(exp)
        if experiment_type == EXPERIMENT_NETWORK:
            exp.add_parser(rl_constants.PATH_PARSER_TRAINING)
        return

    if experiment_type == EXPERIMENT_DISTANCE:
        exp.add_parser(rl_constants.PATH_PARSER_RW_DISTANCE)
        return

    if experiment_type == EXPERIMENT_DIVERSITY:
        exp.add_parser(rl_constants.PATH_PARSER_SAMPLING_DIVERSITY)
        return

    assert False, experiment_type


def copy_training_h_evolution(exp):
    """Collect the ``evolution.pdf`` files of all runs in the eval directory.

    Every run directory below a "runs-*" directory of ``exp.path`` that
    contains an ``evolution.pdf`` has this file copied to
    ``<exp.path>-eval/evolution_<domain>_<task>.pdf``. Domain and task are
    read from the run's ``static-properties`` JSON file; the task is the
    ``problem`` entry without its ``.pddl`` extension.

    :param exp: experiment whose ``path`` and ``path + "-eval"`` directories
                already exist
    :raises AssertionError: if a directory or a ``static-properties`` file is
                            missing, or a problem name lacks ``.pddl``
    """
    experiment_dir = exp.path
    assert os.path.isdir(experiment_dir)
    eval_dir = experiment_dir + "-eval"
    assert os.path.isdir(eval_dir)

    for batch_name in os.listdir(experiment_dir):
        if not REGEX_RUNS_DIRECTORY.match(batch_name):
            continue
        batch_dir = os.path.join(experiment_dir, batch_name)

        for run_name in os.listdir(batch_dir):
            run_dir = os.path.join(batch_dir, run_name)
            if not (os.path.isdir(run_dir) and
                    REGEX_RUN_DIRECTORY.match(run_name)):
                continue

            source_pdf = os.path.join(run_dir, "evolution.pdf")
            properties_file = os.path.join(run_dir, "static-properties")
            # Runs without an evolution plot are silently skipped.
            if not os.path.isfile(source_pdf):
                continue
            assert os.path.isfile(properties_file), properties_file

            with open(properties_file, "r") as handle:
                properties = json.load(handle)
            domain = properties["domain"]
            problem = properties["problem"]
            assert problem.endswith(".pddl"), problem
            task = problem[:-len(".pddl")]

            target_pdf = os.path.join(
                eval_dir, "evolution_%s_%s.pdf" % (domain, task))
            shutil.copy(source_pdf, target_pdf)


def _add_evals(exp, experiment_type, add_attributes):
    """Add the reports and plot steps matching the experiment type.

    * baseline / robustness: an AbsoluteReport named "report" over the search
      attributes plus ``add_attributes``
    * network: an AbsoluteReport named "report" over the search, network,
      initial-search and intermediate-search attributes plus
      ``add_attributes``, a plot step for the expansions, a plot step for
      the plan lengths, and a step that copies the training evolution plots
    * distance: a density plot step
    * diversity: nothing

    :param exp: experiment to add the reports and steps to
    :param experiment_type: one of the EXPERIMENT_* constants
    :param add_attributes: list of additional attributes for the reports
    :raises AssertionError: for an unknown experiment type
    """
    def add_evolution_plot(step_name, file_name, attributes, ylabel):
        # Shared set-up of the plots over consecutive searches.
        exp.add_step(step_name,
                     experiment_plots.plot_attribute_evolution,
                     file=file_name,
                     experiment=exp,
                     attribute=attributes,
                     xlabel="consecutive searches",
                     ylabel=ylabel
                     )

    if experiment_type in (EXPERIMENT_BASELINE, EXPERIMENT_ROBUSTNESS):
        exp.add_report(AbsoluteReport(attributes=ATTRIBUTES_SEARCH + add_attributes),
                       name="report")

    elif experiment_type == EXPERIMENT_NETWORK:
        exp.add_report(AbsoluteReport(attributes=(
                ATTRIBUTES_SEARCH + ATTRIBUTES_NETWORK + add_attributes +
                ATTRIBUTES_INIT_SEARCH + ATTRIBUTES_INTER_SEARCHES)),
            name="report")
        add_evolution_plot(
            "plot_expansions", "plot_expansions.pdf",
            ["init_expansions", "inter_expansions", "expansions"],
            "expansions")
        # The y axis shows plan lengths here; it used to be mislabelled as
        # "expansions" (copied from the expansion plot).
        add_evolution_plot(
            "plot_plan_length", "plot_plan_length.pdf",
            ["init_plan_length", "inter_plan_length", "plan_length"],
            "plan length")
        exp.add_step("Copy_H_Evolution_during_Training",
                     copy_training_h_evolution,
                     exp)

    elif experiment_type == EXPERIMENT_DISTANCE:
        exp.add_step("plot_rw_distances", experiment_plots.plot_density,
                     file="plot_rw_%s.pdf",
                     experiment=exp,
                     )

    elif experiment_type == EXPERIMENT_DIVERSITY:
        # No evaluation is defined for diversity experiments.
        pass
    else:
        assert False, experiment_type


def call_other_experiment(path):
    """Run the experiment script ``path`` with the argument ``--all``.

    The script is executed with the interpreter running this process. The
    call waits for the script to finish; its exit code is not evaluated.

    :param path: path of the experiment script to execute
    """
    command = [sys.executable, path, "--all"]
    subprocess.call(command)


def add_step_check_robustness(
        exp, step_name="check_robustness",
        robustness_suffix="", benchmark_directory=None,
        robustness_mutex_options=None,
        overall_time_limit="30m",
        robustness_partition="infai_1",
        robustness_network_predefinitions=None):
    """Add a step which runs a generated robustness experiment script.

    The script is placed next to the currently running script
    (``sys.argv[0]``) and is called
    ``<script>-robustness<robustness_suffix><extension>``. If it does not
    exist yet, it is created by filling in the template file
    ``rl_constants.PATH_TEMPLATE_EXP_ROBUSTNESS``. An existing script is
    reused without modification.

    :param exp: experiment to add the step to
    :param step_name: name of the step (the suffix is appended to it)
    :param robustness_suffix: suffix for the script name and the step name
    :param benchmark_directory: value for the template field
                                BENCHMARK_DIRECTORY (required)
    :param robustness_mutex_options: value for MUTEX_OPTIONS; None means an
                                     empty list
    :param overall_time_limit: value for OVERALL_TIME_LIMIT
    :param robustness_partition: value for PARTITION
    :param robustness_network_predefinitions: value for
                                              NETWORK_PREDEFINITIONS
                                              (required)
    :raises AssertionError: if a required argument is None
    """
    assert benchmark_directory is not None
    assert robustness_network_predefinitions is not None
    mutex_options = ([] if robustness_mutex_options is None
                     else robustness_mutex_options)

    script = sys.argv[0]
    script_root, script_ext = os.path.splitext(script)
    robustness_script = "%s-robustness%s%s" % (
        script_root, robustness_suffix, script_ext)

    if not os.path.exists(robustness_script):
        script_dir, script_file = os.path.split(script)
        # Data directory of this experiment: <script dir>/data/<script name>
        data_dir = os.path.join(
            script_dir, "data", os.path.splitext(script_file)[0])
        data_dir = os.path.normpath(os.path.abspath(data_dir))

        with open(rl_constants.PATH_TEMPLATE_EXP_ROBUSTNESS, "r") as source:
            template = source.read()
        content = template.format(
            BENCHMARK_DIRECTORY=benchmark_directory,
            EXPERIMENT_DIRECTORY='["%s"]' % data_dir,
            CONFIGURATION_NAME=os.path.basename(script_root),
            OVERALL_TIME_LIMIT=overall_time_limit,
            MUTEX_OPTIONS=str(mutex_options),
            PARTITION=robustness_partition,
            NETWORK_PREDEFINITIONS=robustness_network_predefinitions,
            NEXT_EXPERIMENT="None",
        )
        with open(robustness_script, "w") as target:
            target.write(content)

    exp.add_step(step_name + robustness_suffix, call_other_experiment,
                 robustness_script)


def get_base_experiment(suites=None, cores=1, partition="infai_1",
                        experiment_type=EXPERIMENT_NETWORK, time_limit=None,
                        benchmark_repo=BENCHMARK_REPO, add_attributes=None,
                        robustness_benchmarks=None,
                        robustness_time_limit="30m",
                        robustness_suffix="",
                        robustness_mutex_options=None,
                        robustness_partition=None,
                        robustness_network_predefinitions=None,
                        slurm_dependency=None,
                        next_experiment=None):
    """Build an experiment with suites, steps, parsers and reports set up.

    The experiment gets the steps "build" and "start", a fetcher "fetch"
    which filters known harmless error messages, and the parsers and
    evaluations belonging to ``experiment_type``. Network experiments with
    ``robustness_benchmarks`` additionally get a robustness check step.
    Finally, one "LAUNCH_<script>" step is added per follow-up script.

    :param suites: benchmark suites (see ``_add_suites``)
    :param cores: CPUs per task
    :param partition: Slurm partition
    :param experiment_type: one of the EXPERIMENT_* constants
    :param time_limit: time limit for the environment, or None
    :param benchmark_repo: repository for suites given without one
    :param add_attributes: additional report attributes (None: no extras)
    :param robustness_benchmarks: benchmark directory of the robustness
                                  check; None disables the check
    :param robustness_time_limit: overall time limit of the robustness check
    :param robustness_suffix: suffix for the robustness script and step
    :param robustness_mutex_options: mutex options of the robustness check
    :param robustness_partition: partition of the robustness check
                                 (None: same as ``partition``)
    :param robustness_network_predefinitions: network predefinitions of the
                                              robustness check
    :param slurm_dependency: job the Slurm environment depends on
    :param next_experiment: script path or list of script paths to launch
                            as final steps
    :return: the configured experiment
    :raises AssertionError: for an unknown experiment type
    """
    assert experiment_type in (EXPERIMENT_BASELINE, EXPERIMENT_NETWORK,
                               EXPERIMENT_DISTANCE, EXPERIMENT_DIVERSITY,
                               EXPERIMENT_ROBUSTNESS)
    if robustness_partition is None:
        robustness_partition = partition
    if add_attributes is None:
        add_attributes = []
    # Accept a single script as well as a list of scripts.
    if next_experiment is None:
        follow_up_scripts = []
    elif isinstance(next_experiment, str):
        follow_up_scripts = [next_experiment]
    else:
        follow_up_scripts = next_experiment

    environment = _get_environment(
        cores, partition, time_limit, slurm_dependency)
    exp = _get_experiment(environment, experiment_type)
    _add_suites(exp, suites, benchmark_repo=benchmark_repo)

    exp.add_step("build", exp.build)
    exp.add_step("start", exp.start_runs)
    exp.add_fetcher(name="fetch", filter=[filter_some_fetcher_errors])

    _add_parser(exp, experiment_type)
    _add_evals(exp, experiment_type, add_attributes)

    check_robustness = (experiment_type == EXPERIMENT_NETWORK and
                        robustness_benchmarks is not None)
    if check_robustness:
        add_step_check_robustness(
            exp, benchmark_directory=robustness_benchmarks,
            overall_time_limit=robustness_time_limit,
            robustness_suffix=robustness_suffix,
            robustness_mutex_options=robustness_mutex_options,
            robustness_partition=robustness_partition,
            robustness_network_predefinitions=robustness_network_predefinitions,
        )

    for script in follow_up_scripts:
        exp.add_step("LAUNCH_{}".format(script), call_other_experiment,
                     script)

    return exp


def add_algorithm(exp, name, network,
                  component_options=None, build_options=None,
                  repository=REPO, revision="DeePDown", *args, **kwargs):
    """Add an algorithm after checking that build and component agree.

    The build name following ``--fast-downward-build`` in
    ``component_options`` has to be the same as the build name given in
    ``build_options``. If changing the build version, therefore also provide
    the build version argument for the component.

    :param exp: experiment to add the algorithm to
    :param name: name of the algorithm
    :param network: passed to ``exp.add_algorithm`` after the revision
    :param component_options: list of component options; must contain
                              ``--fast-downward-build <build>`` exactly once
    :param build_options: list of build options; exactly one entry must
                          match REGEX_BUILD_NAME
    :param repository: repository containing the code
    :param revision: revision of the repository to use
    :param args: further positional arguments for ``exp.add_algorithm``
    :param kwargs: further keyword arguments for ``exp.add_algorithm``
    :raises AssertionError: if a build is missing, given several times, or
                            differs between component and build options
    """
    component_options = [] if component_options is None else component_options
    build_options = ([] if build_options is None else build_options)

    # Validate usage of the same build configuration.
    # Each option is paired with its successor, so a trailing
    # "--fast-downward-build" without a value yields no build and is reported
    # by the assertion below instead of raising a bare IndexError.
    cbuild = [value
              for flag, value in zip(component_options, component_options[1:])
              if flag == "--fast-downward-build"]
    assert len(cbuild) == 1, \
        "specify exactly once the build to use in fast-deepcube.py"
    bbuild = [x for x in build_options if REGEX_BUILD_NAME.match(x)]
    assert len(bbuild) == 1, \
        "specify exactly once the build options for building DeePDown"
    assert cbuild == bbuild, "%s, %s" % (str(cbuild), str(bbuild))
    exp.add_algorithm(name, repository, revision, network,
                      component_options, build_options, *args, **kwargs)
