#!/usr/bin/env python
"""
./get_training_command.py --data-source test .uniform.data.gz
./get_training_command.py --data-source init_sat .uniform.sat,data.gz --data-format NonStatic_A_01 --maximum-data-memory 12 --hidden-layers-by-dependencies 1 --balance True --memory 15
"""
from __future__ import print_function
import sys
sys.path.append("../..")

from src.training import parser_tools
from src.training.parser import construct
from src.training.bridges import StateFormat
from src.training.networks import NetworkFormat, keras_networks

import argparse
import keras
import os
import re

# Names accepted by --data-format. Built with list() instead of a manual
# append loop so that no loop variable leaks into the module namespace.
CHOICE_STATE_FORMATS = list(StateFormat.get_format_names())

# Names accepted by --network-formats (iterating name2obj yields the names).
CHOICE_NETWORK_FORMATS = list(NetworkFormat.name2obj)


# Template of the ./fast-training.py invocation printed by this script. Every
# {UPPER_CASE} placeholder is filled in through str.format().
# The placeholders {DATA_DIRECTORY_FILTER}, {DOMAIN_PROPERTIES},
# {MAXIMUM_DATA_MEMORY} and {SKIP_FLAGS} are directly followed by the next
# flag, so their values have to be empty or end with a space.
CALL = (
    './fast-training.py '
    '"keras_dyn_mlp(tparams=ktparams'
    '(epochs={EPOCHS},loss={LOSS},batch={BATCH},balance={BALANCE},'
    'optimizer={OPTIMIZER},'
    'callbacks={CALLBACKS}),hidden={HIDDEN},output_units={OUTPUT},'
    'activation={ACTIVATION},dropout={DROPOUT},'
    'hidden_layer_size={HIDDEN_LAYER_SIZE},'
    'hidden_layers_by_dependencies={HIDDEN_LAYERS_BY_DEPENDENCIES},'
    'x_fields={X_FIELDS},y_fields={Y_FIELDS},'
    'formats={NETWORK_FORMATS},graphdef=graphdef.txt,count_samples=True)" '
    '--prefix {PREFIX} '
    '-d {DATA_DIRECTORY} '
    '-sdt '
    '{DATA_DIRECTORY_FILTER}'
    '--input "gzip(suffix={DATA_SUFFIX})" '
    '--format {DATA_FORMAT} '
    '{DOMAIN_PROPERTIES}'
    '{MAXIMUM_DATA_MEMORY}'
    '--slurm '
    '--cross-validation 10 '
    '-a "--export=ALL {TRAINING_SCRIPT}" '
    '-o -n model '
    '--fields {ALL_FIELDS} '
    '--skip '
    '{SKIP_FLAGS}'
    '--dry'
)


# Flags substituted for {SKIP_FLAGS} unless --no-skip-flags is given. The
# trailing space separates them from the flag that follows in CALL.
SKIP_FLAGS = '--skip-if-trained --skip-if-flag --skip-if-running '

# Callback definitions used when the user passes no --callbacks option.
CALLBACKS = [
    'keras_model_checkpoint(val_loss,./checkpoint.ckp)',
    'keras_progress_checking(acc,100,0.001,False)',
    'keras_early_stopping(val_loss,0.01,100)',
    'keras_restart(-1,stop_successful=True)',
    ('keras_stoptimer(max_time=86400,per_training=False,'
     'prevent_reinit=True,timeout_as_failure=True)'),
]

# Command line parser shared by all option definitions in this script.
parser = argparse.ArgumentParser()


def add_argument(name, type, default, help, action="store", choices=None,
                 nargs=None, required=False):
    """Register one option on the module-level ``parser``.

    Thin wrapper around ``parser.add_argument`` that makes ``type``,
    ``default`` and ``help`` mandatory. All values are forwarded unchanged.

    :param name: option string, e.g. ``"--epochs"``
    :param type: callable converting the command line string
    :param default: value used if the option is not given
    :param help: help text shown by ``--help``
    :param action: argparse action name
    :param choices: allowed values or None for no restriction
    :param nargs: argparse ``nargs`` specification or None
    :param required: whether the option has to be given
    """
    settings = {
        "action": action,
        "type": type,
        "default": default,
        "choices": choices,
        "nargs": nargs,
        "help": help,
        "required": required,
    }
    parser.add_argument(name, **settings)


# Training Arguments (end up in the ktparams(...) part of the command)
add_argument("--epochs", type=int, default=1000,
             help="Number of training epochs to run")
add_argument("--loss", type=str, default="mean_squared_error",
             help="Training loss to use")
add_argument("--batch", type=int, default=100, help="Training batch size")
# Kept as the strings "True"/"False": the value is inserted verbatim into the
# generated command and compared against "True" when the prefix is built.
add_argument("--balance", type=str, default="False",
             choices=["True", "False"],
             help="Balance classes during training")
add_argument("--optimizer", type=str, default="adam",
             help="Optimizer to use for training")
# May be repeated. A default of None lets parse_argv fall back to CALLBACKS.
add_argument("--callbacks", type=str, default=None, action="append",
             help="Callback to add to the network")

# Network Arguments (describe the architecture of the generated network)
add_argument("--activation", type=str, default="sigmoid",
             help="Hidden activation functions.")
add_argument("--hidden", type=int, default=5,
             help="Hidden layers of the network")
add_argument(
    "--hidden-layer-size", type=int, default=0,
    help=("Number of neurons per hidden layer. -1 = Take size of input, "
          "0 = Scale dynamically from input size to output size down,"
          ">0 = use the given value"))
add_argument(
    "--hidden-layers-by-dependencies", type=int, default=0,
    help=("0 = deactivates.\n 1 = Dependencies from Atom A to B if A in pre "
          "and B in post of an action."))
add_argument(
    "--output", type=int, default=-1,
    help=("Number of output units (or -1 for regression"
          " and -2 for classification with from data"
          " deduced number of classes"))
# NOTE(review): the default is the int 0 while a value given on the command
# line is converted to float, so the two render differently ("0" vs "0.0").
add_argument("--dropout", type=float, default=0, help="Dropout probability")
add_argument("--x-fields", type=str, default=["current_state", "goals"],
             nargs="+",
             help="Fields to use for the networks feature input")
add_argument("--y-fields", type=str, default=["hplan"], nargs="+",
             help="Fields to use for the networks label input")
add_argument("--network-formats", type=str, default=["hdf5", "protobuf"],
             nargs="+", choices=CHOICE_NETWORK_FORMATS,
             help="Formats in which to store the network")

# Input Arguments (where the training data comes from and how it is read)
add_argument(
    "--data-directory", type=str,
    default="../DeePDown/data/FixedWorlds/opt",
    help="Directory where the input data is located (possibly in subdirs).")
# May be repeated; every occurrence adds one --directory-filter to the command.
add_argument("--data-directory-filter", type=str, default=[],
             action="append",
             help="Regex for directories to train from has to match")
# The only mandatory option: exactly two values (sample type, gzip suffix).
add_argument(
    "--data-source", type=str, default=None, nargs=2, required=True,
    help=("First a name describing the sample type (e.g. inter_sat, init,..."
          "), second the suffix for the gzip input stream."))
add_argument("--data-format", type=str, default=StateFormat.All_A_01.name,
             choices=CHOICE_STATE_FORMATS,
             help="State format for the input")
add_argument(
    "--maximum-data-memory", type=int, default=None,
    help="Soft maximum amount of memory in GB into which data is loaded.")

# Selects the slurm training script used by the generated command.
add_argument("--memory", type=int, default=7, choices=[7, 15, 16],
             help="Memory to allow by slurm for training")

# Miscellaneous Arguments
# Registered directly on the parser because a store_true switch takes neither
# a type nor a default, both of which add_argument always passes on.
parser.add_argument(
    "--no-skip-flags",
    action="store_true",
    help=("Do not add flags to skip training if previously "
          "trained or currently training or flag that training "
          "failed is set."))


def check_threshold(name, errors, value, tmin=None, tmax=None,
                    allow_none=False):
    """Append a message to ``errors`` if ``value`` violates its bounds.

    :param name: human readable name placed at the start of the message
    :param errors: list collecting the error messages (modified in place)
    :param value: value to check, may be None
    :param tmin: inclusive lower bound or None for no lower bound
    :param tmax: inclusive upper bound or None for no upper bound
    :param allow_none: if False, a ``value`` of None is reported as an error
    :raises AssertionError: if neither ``tmin`` nor ``tmax`` is given
    """
    assert tmin is not None or tmax is not None

    if value is None:
        violated = not allow_none
    else:
        violated = ((tmin is not None and value < tmin)
                    or (tmax is not None and value > tmax))
    if not violated:
        return

    # Format the bounds with %s rather than %i: %i silently truncates
    # fractional bounds (0.5 would be reported as 0).
    if tmin is None:
        bounds = "at most %s" % tmax
    elif tmax is None:
        bounds = "at least %s" % tmin
    else:
        bounds = "between (including) %s and %s" % (tmin, tmax)

    none_hint = "" if allow_none else "not "
    errors.append("%s has to be %s and may %sbe None: %s"
                  % (name, bounds, none_hint, str(value)))


def check_keras_func(name, errors, func, value):
    """Record an error if the lookup ``func(value)`` raises ``ValueError``.

    The result of ``func`` is discarded; only whether it raises matters.
    Exceptions other than ``ValueError`` are not handled here.

    :param name: human readable name placed at the start of the message
    :param errors: list collecting the error messages (modified in place)
    :param func: callable taking ``value`` as its only argument
    :param value: value to look up
    """
    try:
        func(value)
    except ValueError:
        problem = "%s unknown to keras: %s" % (name, value)
        errors.append(problem)


def check_construct(name, errors, clazz, *definitions):
    """Record an error for every definition that cannot be constructed.

    Each definition is handed to ``construct`` together with a fresh
    ``ItemCache`` and the register looked up for ``clazz``. The constructed
    object is discarded. A ``parser_tools.ArgumentException`` becomes an
    entry in ``errors``; other exceptions are not handled here.

    :param name: human readable name placed at the start of the message
    :param errors: list collecting the error messages (modified in place)
    :param clazz: class whose register is used for the construction
    :param definitions: definition strings to check
    """
    for spec in definitions:
        try:
            cache = parser_tools.ItemCache()
            register = parser_tools.main_register.get_register(clazz)
            construct(cache, register, spec)
        except parser_tools.ArgumentException:
            errors.append("%s cannot be constructed by: %s" % (name, spec))


def parse_argv(argv):
    """Parse and validate the command line arguments.

    All problems found are collected first. If there are any, they are
    printed to stderr (one per line) and the process exits with status 1.

    :param argv: list of command line arguments without the program name
    :return: the argparse namespace; ``callbacks`` is replaced by the default
             ``CALLBACKS`` if no callback was given
    """
    options = parser.parse_args(argv)
    errors = []

    check_threshold("Epochs", errors, options.epochs, tmin=1)
    check_keras_func("Loss", errors, keras.losses.get, options.loss)
    check_threshold("Batch", errors, options.batch, tmin=1)
    check_keras_func("Optimizer", errors, keras.optimizers.get,
                     options.optimizer)
    options.callbacks = CALLBACKS if options.callbacks is None else options.callbacks
    check_construct("Callback", errors, keras_networks.BaseKerasCallback,
                    *options.callbacks)
    check_keras_func("Activation", errors, keras.activations.get, options.activation)
    check_threshold("Hidden layers", errors, options.hidden, tmin=0)
    check_threshold("Hidden layer size", errors, options.hidden_layer_size, tmin=-1)
    check_threshold("Output units", errors, options.output, tmin=-2)
    check_threshold("Dropout", errors, options.dropout, tmin=0, tmax=1)
    for nformat in options.network_formats:
        check_keras_func("Network format", errors, NetworkFormat.by_name,
                         nformat)
    if not os.path.isdir(options.data_directory):
        errors.append("Data directory does not exist or is not a directory.")
    for regex in options.data_directory_filter:
        try:
            re.compile(regex)
        except re.error:
            # Record the problem like every other check does. The previous
            # code called the list itself, which raised a TypeError.
            errors.append("Unable to compile regular expression: %s" % regex)
    check_keras_func("Data format", errors, StateFormat.get,
                     options.data_format)

    if (options.hidden_layers_by_dependencies < 0 or
            options.hidden_layers_by_dependencies > 1):
        errors.append("Invalid value for \"--hidden-layers-by-dependencies\".")

    # Both options determine the hidden layer layout, so only one may be set.
    if options.hidden_layer_size != 0 and options.hidden_layers_by_dependencies != 0:
        errors.append("Conflicting hidden layer options. Use either "
                      "\"--hidden-layer-size INT\" or "
                      "\"--hidden-layers-by-dependencies INT\".")

    if (options.maximum_data_memory is not None
            and options.maximum_data_memory <= 0):
        errors.append("--maximum-data-memory has to be unlimited (= argument "
                      "not specified) or positive.")

    if errors:
        print("\n".join(errors), file=sys.stderr)
        sys.exit(1)
    return options


def get_prefix(params):
    """Build the prefix string from the upper-case command parameters.

    ``params`` itself is left untouched. ``BALANCED`` is derived from
    ``params["BALANCE"]``: "bal" for the string "True", otherwise "ubal".

    :param params: mapping providing BALANCE, NETWORK_TYPE,
                   INPUT_FORMAT_TYPE, HIDDEN, ACTIVATION, HIDDEN_LAYER_TYPE,
                   SAMPLE_TYPE and DROPOUT
    :return: the formatted prefix, ending with an underscore
    """
    balance_tag = "bal" if params["BALANCE"] == "True" else "ubal"
    fields = dict(params, BALANCED=balance_tag)
    template = ("{NETWORK_TYPE}_{INPUT_FORMAT_TYPE}_{BALANCED}_h{HIDDEN}_"
                "{ACTIVATION}_{HIDDEN_LAYER_TYPE}{SAMPLE_TYPE}_drp{DROPOUT}_")
    return template.format(**fields)

def construct_command(options):
    """Fill the ``CALL`` template with the validated options.

    Every option is first made available under its upper-case name. The
    entries that need a different textual form (lists, optional flags,
    derived names) are then overwritten or added.

    :param options: namespace as returned by ``parse_argv``
    :return: the complete command line as a string
    """
    params = {key.upper(): value for key, value in vars(options).items()}
    add_domain_properties = False

    # Make list strings
    params["CALLBACKS"] = "[%s]" % ",".join(options.callbacks)
    params["X_FIELDS"] = "[%s]" % ",".join(options.x_fields)
    params["Y_FIELDS"] = "[%s]" % ",".join(options.y_fields)
    params["NETWORK_FORMATS"] = "[%s]" % ",".join(options.network_formats)

    # Remove duplicates but keep the given order, so the generated command is
    # the same on every run (iterating a set of strings is not).
    all_fields = []
    for field in options.x_fields + options.y_fields:
        if field not in all_fields:
            all_fields.append(field)
    params["ALL_FIELDS"] = " ".join(all_fields)

    if options.maximum_data_memory is None:
        params["MAXIMUM_DATA_MEMORY"] = ""
    else:
        params["MAXIMUM_DATA_MEMORY"] = ("--maximum-data-memory %iGB "
                                         % options.maximum_data_memory)

    # NOTE(review): the filters are inserted unquoted; a regex containing
    # shell metacharacters has to be quoted by hand before running the command.
    params["DATA_DIRECTORY_FILTER"] = "".join(
        ["--directory-filter %s " % f for f in options.data_directory_filter])

    if options.data_format.startswith("NonStatic"):
        params["INPUT_FORMAT_TYPE"] = "ns"
        add_domain_properties = True
    else:
        params["INPUT_FORMAT_TYPE"] = "full"

    if options.memory == 7:
        params["TRAINING_SCRIPT"] = "./misc/slurm/slurm-training.sh"
    else:
        params["TRAINING_SCRIPT"] = (
            "./misc/slurm/slurm-training-%igb.sh" % options.memory)

    if options.hidden_layers_by_dependencies == 1:
        add_domain_properties = True
    if options.hidden_layer_size != 0:
        params["HIDDEN_LAYER_TYPE"] = "hls%i_" % options.hidden_layer_size
    elif options.hidden_layers_by_dependencies != 0:
        params["HIDDEN_LAYER_TYPE"] = "hld%i_" % options.hidden_layers_by_dependencies
    else:
        params["HIDDEN_LAYER_TYPE"] = ""

    params["SKIP_FLAGS"] = "" if options.no_skip_flags else SKIP_FLAGS
    params["SAMPLE_TYPE"] = options.data_source[0]
    params["DATA_SUFFIX"] = options.data_source[1]
    params["NETWORK_TYPE"] = "reg" if params["OUTPUT"] == -1 else "clas"
    params["PREFIX"] = get_prefix(params)

    # Always define the placeholder: CALL references {DOMAIN_PROPERTIES}
    # unconditionally, so leaving it unset made CALL.format() raise a KeyError
    # whenever the flag was not needed.
    params["DOMAIN_PROPERTIES"] = (
        "--domain-properties " if add_domain_properties else "")

    return CALL.format(**params)


def run(argv):
    """Validate ``argv`` and print the resulting training command.

    :param argv: list of command line arguments without the program name
    """
    print(construct_command(parse_argv(argv)))


# Script entry point: pass on the command line without the program name.
if __name__ == "__main__":
    run(sys.argv[1:])
