#!/usr/bin/env python
from __future__ import print_function

import argparse
import os
import re
import shutil
import subprocess
import sys

from srce import rl_constants
from srce import rl_experiment_parsing

# Python 2 compatibility: make input() behave like Python 3's input(), i.e.
# return the typed line as a string (Python 2's own input() evaluates it).
if sys.version_info < (3,):
    input = raw_input


# Command line interface of this script.
parser = argparse.ArgumentParser()
# --create and --recreate exclude each other: both generate the experiment
# scripts, they only differ in how already existing scripts are treated.
mxg_create = parser.add_mutually_exclusive_group()
mxg_create.add_argument("--create", action="store_true",
                        help="Create all experiment scripts. If some already "
                             "exist, then the user is asked what to do.")
mxg_create.add_argument("--recreate", action="store_true",
                        help="Create all experiment scripts and overwrite old "
                             "scripts with the same name.")
# Three states: option absent -> -1, bare "--start" -> None (no const is
# given), "--start N" -> N.
parser.add_argument("--start", type=int, nargs="?", default=-1,
                    help="Start the first (if no argument is given) or the "
                         "specified experiment (this will also create the "
                         "experiment files)")

# Experiment kinds. Every iteration consists of one step per kind, in the
# order given by KINDS.
KIND_TRAINING = "T"
KIND_ROBUSTNESS = "R"
KINDS = [KIND_TRAINING, KIND_ROBUSTNESS]


# Values substituted into the experiment templates when the scripts are
# generated (the threshold goes into the training template, the others into
# the robustness template).
ROBUSTNESS_TIME_LIMIT = "30m"
ROBUSTNESS_PARTITION = "infai_1"
ROBUSTNESS_MUTEX = "[]"
ROBUSTNESS_THRESHOLD = 0.8

ITERATIONS = 5  # Number of training/robustness iterations
# Required naming schema of this script:
# four digits, two digits, two digits, one non-blank character, the letter M
# and a suffix starting with "-", all separated by "-".
REGEX_NAME = re.compile(r"(\d\d\d\d-\d\d-\d\d-\S-)M(-.*)")
# Match against the base name only (as done for SEARCH_ENGINE below):
# __file__ may contain a directory part, and for the main module it is an
# absolute path since Python 3.9, which can never match the anchored pattern.
___EXP_NAME_MATCH = REGEX_NAME.match(os.path.basename(__file__))
if ___EXP_NAME_MATCH is None:
    print("Invalid script name")
    sys.exit(1)
# The generated script names are PREFIX + kind + iteration + SUFFIX.
# NOTE(review): PREFIX carries no directory, so these names are resolved
# relative to the current working directory - presumably the script is meant
# to be run from its own directory; confirm.
PREFIX, SUFFIX = ___EXP_NAME_MATCH.groups()
SEARCH_ENGINE = rl_experiment_parsing.get_sampling_engine(os.path.basename(__file__))
# Either no engine information or exactly two components are expected.
assert len(SEARCH_ENGINE) in [0, 2]
# True if the second engine component ends with "X"; selects the network
# predefinition used in the robustness experiments.
PREDICT_EXPANSIONS = len(SEARCH_ENGINE) == 2 and SEARCH_ENGINE[1].endswith("X")


def _read_template(path_template):
    """Return the complete text of the template file at path_template."""
    with open(path_template, "r") as handle:
        return handle.read()


# Raw template texts; their placeholders are filled via str.format when the
# experiment scripts are created.
TEMPLATE_TRAINING = _read_template(rl_constants.PATH_TEMPLATE_EXP_TRAINING)
TEMPLATE_ROBUSTNESS = _read_template(rl_constants.PATH_TEMPLATE_EXP_ROBUSTNESS)


def get_dir(path_exp, suffix=""):
    """Return the data directory that belongs to an experiment script.

    The result is "<directory of this file>/data/<name><suffix>" where
    <name> is path_exp without its file extension.

    :param path_exp: file name of the experiment script (e.g. "x.py")
    :param suffix: text appended to the extension-less script name
    :return: path of the directory as string
    """
    # Strip the extension with splitext (as the other functions of this file
    # do) instead of cutting off a fixed number of characters.
    name = os.path.splitext(path_exp)[0]
    return os.path.join(os.path.dirname(__file__), "data", name + suffix)


def get_path_experiment_script(kind, iteration):
    """Return the script name for the given kind and iteration.

    The name is assembled as PREFIX + kind + iteration + SUFFIX.
    """
    step_tag = "{}{}".format(kind, iteration)
    return PREFIX + step_tag + SUFFIX


# All steps in execution order as (iteration, kind, script name) tuples:
# for every iteration one entry per kind, in the order of KINDS.
STEPS = [
    (_iteration, _kind, get_path_experiment_script(_kind, _iteration))
    for _iteration in range(ITERATIONS)
    for _kind in KINDS
]


def print_step(n, k, p, endl="\n", shift=""):
    """Print one step as tab separated "number, kind, path".

    :param n: number of the step
    :param k: kind of the step
    :param p: path of the step's experiment script
    :param endl: text printed after the line instead of a newline
    :param shift: text printed in front of the line
    """
    line = "%s%s\t%s\t%s" % (shift, n, k, p)
    print(line, end=endl)


def print_steps():
    """Print a numbered, tab indented overview of all entries of STEPS."""
    print("Steps:")
    for index, step in enumerate(STEPS):
        kind, path = step[1], step[2]
        print_step(index, kind, path, shift="\t")


def check_exist_script(path_exp_script, options):
    """Decide whether an experiment script shall be written.

    If the script does not exist or options.recreate is set, the script is
    written without asking. Otherwise the user is asked repeatedly until a
    valid answer is given.

    :param path_exp_script: path of the experiment script
    :param options: parsed command line options (only "recreate" is read)
    :return: check_exist_script.WRITE or check_exist_script.SKIP
    :raises SystemExit: if the user chooses to exit
    """
    if not os.path.exists(path_exp_script) or options.recreate:
        return check_exist_script.WRITE
    prompt = (
        "The experiment script ({}) exists already. What shall we do? "
        "(e)xit/(r)e-create/(s)kip: ".format(path_exp_script)
    )
    while True:
        answer = input(prompt).lower().strip()
        if answer in ("e", "exit"):
            print("Exit")
            sys.exit(0)
        # "re-create" is the spelling shown in the prompt, so it has to be
        # accepted next to the historic "recreate".
        if answer in ("r", "recreate", "re-create"):
            return check_exist_script.WRITE
        if answer in ("s", "skip"):
            return check_exist_script.SKIP
        # Anything else: ask again.


# Possible results of check_exist_script.
check_exist_script.WRITE = 1
check_exist_script.SKIP = 2


def check_previously_started(steps):
    """Handle output directories left over from earlier runs of the steps.

    For every step the directories "data/<script name><suffix>" (suffix "",
    "-grid-steps" and "-eval") next to the step's script are inspected. If
    any exists, the user chooses to abort, to remove the directories or to
    shift, i.e. to begin after this step. Answering with a capital R or S
    applies the choice to all following steps without asking again.

    :param steps: list of (iteration, kind, script path) tuples
    :return: the steps that remain after shifting (possibly empty)
    :raises SystemExit: if the user aborts
    """
    shift_always = False
    shift = 0
    remove_always = False
    for no, (_, _, path) in enumerate(steps):
        data_dir = os.path.join(os.path.dirname(path), "data")
        base = os.path.basename(os.path.splitext(path)[0])
        out_dirs = [os.path.join(data_dir, base + suffix)
                    for suffix in ("", "-grid-steps", "-eval")]
        out_dirs = [d for d in out_dirs if os.path.exists(d)]
        if not out_dirs:
            continue
        if shift_always:
            shift = no + 1
            continue
        if not remove_always:
            # Ask until a valid answer is given; an unknown answer must not
            # silently leave the remnants in place.
            while True:
                answer = input(
                    "Remnants of {} exist. (a)bort/(r)emove/(R)emove/(s)hift/"
                    "(S)hift: ".format(path)).strip()
                choice = answer.lower()
                if choice in ("a", "abort"):
                    print("abort")
                    sys.exit(0)
                if choice in ("r", "remove", "s", "shift"):
                    break
            # A capital first letter makes the choice permanent.
            always = answer[0].isupper()
            if choice in ("s", "shift"):
                shift = no + 1
                shift_always = always
                continue
            remove_always = always
        for d in out_dirs:
            shutil.rmtree(d)
    return steps[shift:]


def create_experiment_scripts(options):
    """Generate the experiment scripts of all entries of STEPS.

    For every step check_exist_script decides whether the script is written
    or an existing one is kept. Training scripts are rendered from
    TEMPLATE_TRAINING, robustness scripts from TEMPLATE_ROBUSTNESS. Each
    script receives the name of the following step's script (or None for the
    last step) so that the experiments can be chained.

    :param options: parsed command line options, passed on to
                    check_exist_script
    """
    print("Creating Experiments:")
    # Data directory of the most recent training step; input of the
    # robustness step that follows it.
    runs_training = None
    # Properties file of the most recent robustness step; input of the
    # training step that follows it (None for the very first training).
    props_robustness = None

    for no, (i, kind, path) in enumerate(STEPS):
        print_step(no, kind, path, endl="...")
        write = check_exist_script(path, options) == check_exist_script.WRITE
        # Quoted script name of the following step or the literal None.
        next_experiment = ("None" if len(STEPS) == no + 1
                           else '"{}"'.format(STEPS[no + 1][2]))

        # The paths below are tracked even if the script itself is skipped,
        # because the scripts of later steps refer to them.
        if kind == KIND_TRAINING:
            path_training = get_path_experiment_script(KIND_TRAINING, i)
            runs_training = get_dir(path_training)
            if write:
                with open(path_training, "w") as f:
                    f.write(TEMPLATE_TRAINING.format(
                        PREVIOUS_ROBUSTNESS_EXPERIMENT='(%s, %.1f)' % (
                            "None" if props_robustness is None
                            else '"%s"' % props_robustness,
                            ROBUSTNESS_THRESHOLD),
                        ADD_ROBUSTNESS_STEP=(i == 0),
                        NEXT_EXPERIMENT=next_experiment
                    ))
        elif kind == KIND_ROBUSTNESS:
            path_robustness = get_path_experiment_script(KIND_ROBUSTNESS, i)
            props_robustness = os.path.join(
                get_dir(path_robustness, suffix="-eval"), "properties")
            # Honour the user's decision here as well: an existing script
            # that shall be skipped must not be overwritten.
            if write:
                with open(path_robustness, "w") as f:
                    f.write(TEMPLATE_ROBUSTNESS.format(
                        BENCHMARK_DIRECTORY=rl_constants.DIR_VALIDATION_ECAI,
                        EXPERIMENT_DIRECTORY='["%s"]' % runs_training,
                        CONFIGURATION_NAME=os.path.basename(
                            os.path.splitext(path_robustness)[0]),
                        OVERALL_TIME_LIMIT=ROBUSTNESS_TIME_LIMIT,
                        MUTEX_OPTIONS=ROBUSTNESS_MUTEX,
                        PARTITION=ROBUSTNESS_PARTITION,
                        NETWORK_PREDEFINITIONS=(
                            "rl_experiment_factory.Predefinitions"
                            ".Regression_SAS_State_Network.value"
                            if not PREDICT_EXPANSIONS else
                            "rl_experiment_factory.Predefinitions"
                            ".Regression_SAS_State_Network_Expansions.value"
                        ),
                        NEXT_EXPERIMENT=next_experiment
                    ))
        else:
            assert False
        print("Written." if write else "Existed.")


def start_experiment(options):
    """Run the experiment script of the first step that shall be executed.

    The steps from index options.start onwards are considered. All of their
    scripts have to exist. Remnants of earlier runs are handled by
    check_previously_started, which may move the starting point further
    back. The chosen script is executed with the current interpreter and the
    argument "--all".

    :param options: parsed command line options (only "start" is read)
    :return: exit status of the started script, or None if no step is left
    :raises SystemExit: if the start index is beyond the last step or if
                        experiment scripts are missing
    """
    assert options.start >= 0
    steps = STEPS[options.start:]
    if not steps:
        # Without this check the script would fail with an IndexError below.
        print("No experiment step with index {} exists.".format(
            options.start))
        sys.exit(1)
    missing_exp_scripts = [p for (_, _, p) in steps
                           if not os.path.isfile(p)]
    if missing_exp_scripts:
        print("Some experiment scripts have to be created: {}".format(
                "\n".join(missing_exp_scripts)))
        sys.exit(0)

    steps = check_previously_started(steps)
    if not steps:
        # The user shifted past every remaining step.
        print("No experiment step is left to start.")
        return None
    return subprocess.call([sys.executable, steps[0][2], "--all"])


def run(options):
    """Dispatch the requested actions.

    A "--start" given without a value (None) selects the first step. If
    neither the creation of scripts nor a start is requested, the steps are
    listed and the program exits. Otherwise the scripts are created first
    (if requested) and then the experiment is started (if requested).

    :param options: parsed command line options; "start" is normalised in
                    place
    """
    if options.start is None:
        options.start = 0
    else:
        options.start = options.start

    wants_creation = options.create or options.recreate
    wants_start = options.start >= 0

    if not (wants_creation or wants_start):
        print_steps()
        sys.exit(0)

    if wants_creation:
        create_experiment_scripts(options)
    if wants_start:
        start_experiment(options)


# Script entry point: parse the command line arguments and dispatch.
if __name__ == "__main__":
    cli_arguments = sys.argv[1:]
    cli_options = parser.parse_args(cli_arguments)
    run(cli_options)
