#! /usr/bin/env python
import math
try:
    from pathlib import Path
except ImportError:
    from pathlib2 import Path
import os
import sys

import helper
from helper import common, supervised_experiment
from helper import network_parameters as nparams
from helper.network_parameters import Parameters as NParams

from downward.reports.absolute import AbsoluteReport
from lab.environments import LocalEnvironment, BaselSlurmEnvironment


# --- Cluster settings ----------------------------------------------------
# MAIL and PARTITION are only handed to BaselSlurmEnvironment (remote runs).
MAIL = "patrick.ferber@unibas.ch"
PARTITION = "infai_1"
# Forwarded to the planner driver as ``--overall-time-limit``.
OVERALL_TIME_LIMIT = "600m"


# --- Code and data locations ---------------------------------------------
# Repository passed to every ``exp.add_algorithm`` call; a KeyError is raised
# at start-up when the NDOWNWARD environment variable is not set.
REPO = os.environ['NDOWNWARD']
_BENCHMARKS_ECAI = "~/repositories/benchmarks_ecai"
BENCHMARK_DIRECTORY = Path(_BENCHMARKS_ECAI).expanduser()
# NOTE(review): networks are looked up in the same tree as the benchmarks --
# confirm this is intended.
NETWORK_DIRECTORY = Path(_BENCHMARKS_ECAI).expanduser()
# Passed as ``extra_options`` to BaselSlurmEnvironment.
EXTRA_OPTIONS = None

# --- Network configurations ----------------------------------------------
# Each prefix is parsed into a network configuration; each unary threshold
# yields a separate configuration (a threshold of 0 adds no name suffix).
PREFIXES = ["ocls_ns_ubal_h3_sigmoid_inter_gen_sat_drp0_Kall_pruneOff_"]
UNARY_THRESHOLD = [0.01]
BIN_SIZE = 1
# FOLDS (defined below) is either a list of fold integers, e.g.
#     FOLDS = list(range(10))
# or a dict mapping a configuration-name suffix to an int or a callable.
#
# TASK2FOLD maps each "<domain>/<problem>" benchmark task to a fold number.
TASK2FOLD = {
    'blocks/prob-B-18-4': 0,
    'blocks/prob-B-20-1': 0,
    'blocks/prob-B-25-2': 0,
    'blocks/prob-B-30-1': 0,
    'blocks/probBLOCKS-14-0': 0,
    'blocks/probBLOCKS-15-0': 0,
    'blocks/probBLOCKS-15-1': 9,
    'blocks/probBLOCKS-16-1': 0,
    'blocks/probBLOCKS-17-0': 7,
    'depot/p05': 0,
    'depot/p06': 6,
    'depot/p08': 0,
    'depot/p09': 1,
    'depot/p11': 0,
    'depot/p12': 1,
    'depot/p14': 0,
    'depot/p15': 2,
    'depot/p16': 0,
    'depot/p18': 0,
    'depot/p19': 1,
    'grid/prob03': 9,
    'grid/prob04': 6,
    'grid/prob04_2': 4,
    'grid/prob04_3': 9,
    'grid/prob04_4': 0,
    'grid/prob05': 1,
    'grid/prob05_1': 3,
    'grid/prob05_2': 8,
    'grid/prob05_3': 7,
    'grid/prob05_4': 1,
    'npuzzle/prob_n6_1': 0,
    'npuzzle/prob_n6_2': 0,
    'npuzzle/prob_n6_3': 0,
    'npuzzle/prob_n6_4': 0,
    'npuzzle/prob_n7_1': 0,
    'npuzzle/prob_n7_2': 0,
    'npuzzle/prob_n7_3': 0,
    'npuzzle/prob_n7_4': 0,
    'pipesworld-notankage/p19-net2-b18-g6': 0,
    'pipesworld-notankage/p21-net3-b12-g2': 0,
    'pipesworld-notankage/p22-net3-b12-g4': 0,
    'pipesworld-notankage/p24-net3-b14-g5': 0,
    'pipesworld-notankage/p25-net3-b16-g5': 1,
    'pipesworld-notankage/p26-net3-b16-g7': 1,
    'pipesworld-notankage/p27-net3-b18-g6': 0,
    'pipesworld-notankage/p28-net3-b18-g7': 8,
    'pipesworld-notankage/p29-net3-b20-g6': 7,
    'pipesworld-notankage/p30-net3-b20-g8': 4,
    'pipesworld-notankage/p31-net4-b14-g3': 2,
    'pipesworld-notankage/p32-net4-b14-g5': 1,
    'pipesworld-notankage/p33-net4-b16-g5': 2,
    'pipesworld-notankage/p34-net4-b16-g6': 2,
    'pipesworld-notankage/p35-net4-b18-g4': 1,
    'pipesworld-notankage/p36-net4-b18-g6': 1,
    'pipesworld-notankage/p37-net4-b20-g5': 5,
    'pipesworld-notankage/p38-net4-b20-g7': 3,
    'pipesworld-notankage/p39-net4-b22-g7': 5,
    'pipesworld-notankage/p40-net4-b22-g8': 4,
    'pipesworld-notankage/p41-net5-b22-g2': 8,
    'pipesworld-notankage/p45-net5-b26-g4': 1,
    'pipesworld-notankage/p49-net5-b30-g6': 4,
    'rovers/p11': 0,
    'rovers/p18': 0,
    'rovers/p19': 0,
    'rovers/p20': 0,
    'rovers/p21': 3,
    'rovers/p22': 0,
    'rovers/p23': 0,
    'rovers/p24': 3,
    'rovers/p26': 5,
    'rovers/p27': 0,
    'rovers/p28': 0,
    'rovers/p29': 0,
    'rovers/p30': 0,
    'rovers/p31': 0,
    'rovers/p32': 0,
    'rovers/p33': 0,
    'rovers/p34': 0,
    'rovers/p35': 0,
    'rovers/p36': 0,
    'rovers/p37': 0,
    'rovers/p38': 0,
    'rovers/p39': 0,
    'rovers/p40': 0,
    'scanalyzer-opt11-strips/p07': 0,
    'scanalyzer-opt11-strips/p10': 0,
    'scanalyzer-opt11-strips/p13': 0,
    'scanalyzer-opt11-strips/p15': 0,
    'scanalyzer-opt11-strips/p16': 1,
    'scanalyzer-opt11-strips/p17': 0,
    'scanalyzer-opt11-strips/p18': 3,
    'scanalyzer-opt11-strips/p19': 1,
    'storage/p18': 1,
    'transport-opt14-strips/p10': 0,
    'transport-opt14-strips/p11': 0,
    'transport-opt14-strips/p12': 0,
    'transport-opt14-strips/p16': 0,
    'transport-opt14-strips/p17': 0,
    'transport-opt14-strips/p18': 0,
    'transport-opt14-strips/p19': 4,
    'transport-opt14-strips/p20': 2,
    'transport-opt14-strips/p31': 0,
    'transport-opt14-strips/p32': 0,
    'transport-opt14-strips/p33': 0,
    'visitall-opt14-strips/p-1-12': 0,
    'visitall-opt14-strips/p-1-13': 0,
    'visitall-opt14-strips/p-1-14': 0,
    'visitall-opt14-strips/p-1-15': 0,
    'visitall-opt14-strips/p-1-16': 0,
    'visitall-opt14-strips/p-1-17': 0,
    'visitall-opt14-strips/p-1-18': 0,
}
# The result of assign_domains_2_folds(TASK2FOLD) is passed as ``fold=`` to
# every algorithm; the dict key becomes the suffix after "_fold_" in the
# configuration name.
# Disabled variant: storage/p19, p20, p21 and p23 were additionally entered
# into TASK2FOLD with a fold of None before this call.
# NOTE(review): presumably None excludes a task -- confirm in
# assign_domains_2_folds before re-enabling.
_TASK_FOLD_ASSIGNMENT = supervised_experiment.assign_domains_2_folds(
    TASK2FOLD)
FOLDS = {"best_per_state_space": _TASK_FOLD_ASSIGNMENT}


# Attributes shown in the AbsoluteReport written to report.html.
ATTRIBUTES = [
    'coverage',
    'total_time',
    'expansions',
    'error',
    'cost',
    'plan_length',
]

# --- Everything below this line is automatically set ---

# Lab environment: BaselSlurmEnvironment when helper.IS_REMOTE is set,
# otherwise a single-process LocalEnvironment.
ENV = (
    BaselSlurmEnvironment(
        email=MAIL,
        extra_options=EXTRA_OPTIONS,
        partition=PARTITION,
    )
    if helper.IS_REMOTE
    else LocalEnvironment(processes=1)
)

exp = supervised_experiment.SupervisedDownwardExperiment(environment=ENV)

# Register every benchmark suite found below BENCHMARK_DIRECTORY.
_benchmark_suites = common.find_benchmark_suites(BENCHMARK_DIRECTORY)
for benchmark_suite, tasks in _benchmark_suites.items():
    exp.add_suite(str(benchmark_suite), tasks)

# Steps are registered in this order: build, start runs, fetch, report.
exp.add_step('build', exp.build)
exp.add_step('start', exp.start_runs)
exp.add_fetcher(
    name='fetch',
    filter=[supervised_experiment.filter_some_fetcher_errors],
)
exp.add_report(AbsoluteReport(attributes=ATTRIBUTES), outfile='report.html')

# Parsers registered with the experiment, in the original order.
for _parser in (
        exp.EXITCODE_PARSER,
        exp.TRANSLATOR_PARSER,
        exp.SINGLE_SEARCH_PARSER,
        exp.PLANNER_PARSER,
):
    exp.add_parser(_parser)


def _unary_threshold_suffix(unary_threshold):
    """Return the configuration-name suffix for *unary_threshold*.

    A threshold of 0 yields "" (no suffix). For 0 < t < 1 the value is
    printed with enough decimals to keep its leading significant digit,
    e.g. 0.01 -> "_unary_threshold_0.01", 0.05 -> "_unary_threshold_0.05".
    (The previous formula truncated abs(log10(t)), printing 0.05 as "0.1"
    and 0.005 as "0.01", which mislabels and can collide configurations.)
    Powers of ten and thresholds >= 1 render exactly as before.
    A negative threshold raises ValueError from math.log10, as before.
    """
    if unary_threshold == 0:
        return ""
    exponent = math.log10(unary_threshold)
    if exponent < 0:
        # Round the digit count up so the first significant digit survives.
        decimals = int(math.ceil(-exponent))
    else:
        decimals = int(exponent)
    return "_unary_threshold_%.*f" % (decimals, unary_threshold)


# A plain list of folds uses each entry as both its name suffix and its
# value; normalise once instead of re-checking inside the innermost loop.
if isinstance(FOLDS, list):
    FOLDS = {fold: fold for fold in FOLDS}

# These strings do not depend on the network configuration: build them once.
heuristic_definition = common.TEMPLATE_NETWORK_HEURISTIC.format(
    BLIND="false", NETWORK="net", CONFIDENCE_VERBOSITY="-1")
search_definition = common.TEMPLATE_ALGORITHM_EAGER_GREEDY.format(
    HEURISTICS=heuristic_definition, MISC="")

BUILD_CONFIG = "release64dynamic"
build_options = [BUILD_CONFIG]
driver_options = [
    "--build", BUILD_CONFIG, "--overall-time-limit", OVERALL_TIME_LIMIT]

for prefix in PREFIXES:
    config = nparams.parse_parameters_from_prefix(prefix, "model.pb")
    network_type = config[NParams.NETWORK_TYPE]
    # Explicit check instead of ``assert`` so it also runs under ``python -O``.
    if (network_type == nparams.NetworkTypes.CLASSIFICATION
            and (len(UNARY_THRESHOLD) != 1 or UNARY_THRESHOLD[0] != 0)):
        raise ValueError(
            "prefix %r describes a classification network, which requires "
            "UNARY_THRESHOLD == [0] (got %r)" % (prefix, UNARY_THRESHOLD))

    # Ordinal-classification networks are passed to the planner under the
    # type name "classification"; all other types use their lower-cased
    # enum name.
    if network_type == nparams.NetworkTypes.ORDINAL_CLASSIFICATION:
        fd_network_type_name = "classification"
    else:
        fd_network_type_name = network_type.name.lower()

    config_base_name = config[NParams.PREFIX].rstrip("_")
    for unary_threshold in UNARY_THRESHOLD:
        config_name = config_base_name + _unary_threshold_suffix(
            unary_threshold)

        network_definition = common.TEMPLATE_NETWORK.format(
            TYPE=fd_network_type_name,
            STATE_LAYER=config[NParams.INPUT_STATE],
            GOAL_LAYER=config[NParams.INPUT_GOAL],
            OUTPUT_LAYER=config[NParams.OUTPUT_LAYER],
            ATOMS="{" + config[NParams.ATOMS] + "}",
            VALUES="{" + config[NParams.INITS] + "}",
            UNARY_THRESHOLD=unary_threshold,
            BIN_SIZE=BIN_SIZE,
            MODEL_FILE="model.pb")

        component_options = [
            "--network", "net=" + network_definition,
            "--search", search_definition,
        ]

        for fold_name, fold_value in FOLDS.items():
            exp.add_algorithm(
                config_name + "_fold_" + str(fold_name), REPO, "DeePDown",
                network_directory=NETWORK_DIRECTORY,
                model_template=config_base_name + "_{FOLD}_fold_model.pb",
                fold=fold_value,
                # Fresh list objects per algorithm, so that a callee that
                # mutates its option lists cannot leak changes into sibling
                # algorithms.
                component_options=list(component_options),
                build_options=list(build_options),
                driver_options=list(driver_options),
            )

exp.run_steps()
