#!/usr/bin/env python
"""
Example call
keras_mlp(tparams=ktparams(epochs=99999999,epoch_verbosity=1,rl_batch_generator=replay,batch=250,callbacks=[keras_model_saver(threshold=0.1,min_delay=50,step=train_end)],optimizer=adam),y_fields=0,x_fields=1,load=model,store=model,hidden_layer_size=[250,250],learner_formats=[pb,h5],output_units=-1,batch_normalization=1,l2=0,dropout=0) /home/ferber/repositories/benchmarks/blocks/probBLOCKS-8-0.pddl --fast-downward-build debug64dynamic --lookahead 2 --max-scrambles 200 --maximum-training-time 21600s --add-final-evaluation --increase-scrambling keras_condition_counter(flags=inc_scrambling) 1 5 --add-initial-evaluation default 300 --add-intermediate-evaluations default 600 180
keras_mlp(tparams=ktparams(epochs=99999999,epoch_verbosity=1,rl_batch_generator=replay,batch=250,callbacks=[keras_model_saver(threshold=0.1,min_delay=50,step=train_end,flags=[model_saved]),keras_model_saver(threshold=0.1,min_delay=50,step=train_end,min_time=3600,add_model_indices=True)],optimizer=adam),y_fields=0,x_fields=1,load=model,store=model,learner_formats=[pb,h5],hidden_layer_size=[250,250,keras_residual_block(hidden_layer_count=2,hidden_layer_size=250)],output_units=-1,batch_normalization=1,l2=0,dropout=0) /home/ferber/repositories/benchmarks/depot/p01.pddl --fast-downward-build debug64dynamic --lookahead 2 --max-scrambles 200 --maximum-training-time 21600s --add-final-evaluation --increase-scrambling keras_condition_counter(flags=inc_scrambling) 1 5 --working-directory tmp --add-intermediate-evaluations default 10 5 intermediatelogs.zip --add-initial-evaluation default 4 init.log --cold-start-evaluator ff keras_condition_counter(min_time=10) 0.10
'keras_mlp(tparams=ktparams(epochs=99999999,epoch_verbosity=50,rl_batch_generator=replay,batch=250,callbacks=[keras_model_saver(threshold=0.1,min_delay=50,step=train_end,flags=[model_saved]),keras_model_saver(threshold=0.1,min_delay=50,step=train_end,min_time=3600,add_model_indices=True)],optimizer=adam),y_fields=0,x_fields=1,load=model,store=model,learner_formats=[pb,h5],hidden_layer_size=[250,250,keras_residual_block(hidden_layer_count=2,hidden_layer_size=250)],output_units=-1,batch_normalization=1,l2=0,dropout=0) problem.pddl --domain-pddl domain.pddl --fast-downward-build release64dynamic --lookahead 2 --maximum-training-time 21600s --add-final-evaluation --reinitialize-after-time 0.75h --max-scrambles 200 --cold-start-evaluator ff keras_condition_counter(min_time=720) 0.1 1.0 0.0 --add-initial-evaluation default 300 --add-intermediate-evaluations default 600 180'

"""

from __future__ import print_function, division

from src.training.learners.keras_networks import KerasMLP, StopTraining, KerasNetwork
from src.training.learners.keras_networks.keras_callbacks import BaseKerasConditionExecutor
from tools import misc as tm
from tools import paths as tpa
from tools import parsing as apt

import argparse
from collections import defaultdict
import enum
import json
import math
import matplotlib as mpl
mpl.use('agg')
import matplotlib.pyplot as plt
from multiprocessing import Process, Queue, RLock
import numpy as np
import os
import psutil
import random
import re
import shlex
import signal
import sys
import time
import traceback
import zipfile

# Python 2/3 compatibility shim: both branches provide the names ``queue``,
# ``subprocess``, ``decoder`` and ``StringIO``.
if sys.version_info < (3,):
    import Queue as queue
    import subprocess32 as subprocess

    # Python 2 has no FileNotFoundError; make the name available as IOError.
    FileNotFoundError = IOError

    def decoder(s):
        """Return ``s.decode()`` (Python 2 variant)."""
        return s.decode()
    from StringIO import StringIO
else:
    import queue
    import subprocess

    def decoder(s):
        """Return ``s`` unchanged (Python 3 variant)."""
        return s
    from io import StringIO

class SyncToken(enum.Enum):
    """Tokens sent through the synchronization queues of this script.

    The value of every member is identical to its name.
    """
    ModelCreated = "ModelCreated"
    # Externally done. Restarts Fast Downward sampling with new parameters
    IncreaseScrambling = "IncreaseScrambling"
    # Internally done during the Fast Downward sampling run.
    InternalScramblingIncrement = "InternalScramblingIncrement"
    Terminate = "Terminate"
    SolvedAndTerminate = "SolvedAndTerminate"
    Solved = "Solved"
    Unsolved = "Unsolved"
    Error = "Error"
    DecreaseColdStartWeight = "DecreaseColdStartWeight"

# Matches a "#Solves tasks generated by sampling technique <n>: ... max=<m>"
# line; group 1 is the technique number, group 2 the value after "max=".
REGEX_SAMPLING_SEARCH_FEEDBACK = re.compile(r"#Solves tasks generated by sampling technique (\d+):.*max=(\d+).*,")
# NOTE(review): presumably the number of {seedN} placeholders that are
# generated for the configurations -- confirm against the callers.
NB_SEEDS = 10
MAX_SAMPLE_QUEUE_SIZE = 20

# Sleep durations; TIME_SLEEP_DISK_DELAY and TIME_SLEEP_SHUTDOWN are passed to
# time.sleep() and Process.join() in shutdown(), i.e. they are seconds.
TIME_SLEEP_SAMPLING = 2
TIME_SLEEP_LOADING = 1
TIME_SLEEP_DISK_DELAY = 1
# One second longer than the longest of the sampling/loading sleeps.
TIME_SLEEP_SHUTDOWN = max(TIME_SLEEP_SAMPLING, TIME_SLEEP_LOADING) + 1


# Indices into the value list of --cold-start-evaluator, in the order
# [evaluator] [callback] [decrease step] (initial weight) (minimum weight).
IDX_COLD_START_EVALUATOR = 0
IDX_COLD_START_CALLBACK = 1
IDX_COLD_START_STEP = 2
IDX_COLD_START_INIT = 3
IDX_COLD_START_MIN = 4

SAMPLE_FILE_PREFIX = "samples_"
# Name of the placeholder ("{max_scrambles}") inserted into the sampling
# technique templates.
KEY_MAX_SCRAMBLES = "max_scrambles"


# Templates of Fast Downward sampling technique definitions. The "%s" is
# filled with KEY_MAX_SCRAMBLES right here, which leaves a "{max_scrambles}"
# placeholder in the string; all other {...} keys stay as placeholders.
TEMPLATE_IFORWARD_NONE_TECHNIQUE = (
    "iforward_none({task_count},distribution=uniform_int_dist(0, {%s},"
    "random_seed={seed}),random_seed={seed},"
    "deprioritize_undoing_steps=true,"
    "bias={bias},bias_probabilistic={bias_mode_probabilistic},"
    "bias_reload_frequency={bias_reload_frequency},bias_adapt={bias_adapt}"
    ")" % KEY_MAX_SCRAMBLES
)

# NOTE(review): the "{key|0}" placeholders presumably carry a default value
# after the "|" -- confirm with the code that fills the templates.
TEMPLATE_GBACKWARD_NONE_TECHNIQUE = (
    "gbackward_none({task_count},distribution=uniform_int_dist("
    "0,{%s},upgrade_max={upgrade_max|0},random_seed={seed}),random_seed={seed},"
    "deprioritize_undoing_steps=true,"
    "wrap_partial_assignment={wrap_partial_assignment},"
    "is_valid_walk=true,"
    "bias={bias},bias_probabilistic={bias_mode_probabilistic},"
    "bias_reload_frequency={bias_reload_frequency},bias_adapt={bias_adapt},"
    "max_upgrades={nb_upgrades|0})" % KEY_MAX_SCRAMBLES
)

# Same as the gbackward_none template, but with the bias evaluator, the
# probabilistic bias mode and the bias reload frequency hard-wired to
# "<none>", "false" and "-1". {bias_adapt} remains a placeholder.
TEMPLATE_GBACKWARD_NONE_TECHNIQUE_NO_BIAS = (
    TEMPLATE_GBACKWARD_NONE_TECHNIQUE.
        replace("{bias}", "<none>").
        replace("{bias_mode_probabilistic}", "false").
        replace("{bias_reload_frequency}", "-1"))

TEMPLATE_UNIFORM_NONE_TECHNIQUE = (
    "uniform_none({task_count},evals=[ff(transform=sampling_transform)],"
    "random_seed={seed})"
)

# Template of the "sampling_q" engine configuration. It is registered under
# the preset key "V" and is the default of --sampling-engine.
TEMPLATE_SAMPLING_Q_CONFIGURATION = (
    "sampling_q({evaluator},"
    "lookahead={lookahead},"
    "evaluator_reload_frequency={cache_size},"
    "task_reload_frequency={cache_size},"
    "techniques=[{sampling_techniques}],"
    "sample_cache_size={cache_size},"
    "iterate_sample_files=true,"
    "index_sample_files={cache_index},"
    "max_sample_files={max_sample_files},"
    "sample_format=csv,state_format=fdr,field_separator=;,state_separator=;,"
    "add_goal=false,"
    "skip_undefined_facts={skip_undefined_facts},shuffle_techniques=true,"
    "expand_goal={expand_goal},reload_expanded_goals={reload_expanded_goals},"
    "expand_goal_state_limit={expand_goal_state_limit},"
    "random_seed={seed})"
)

# Base template of the "sampling_search" engine configuration. The
# placeholders {search_engine}, {evaluator}, {max_time},
# {add_unsolved_samples}, {store_expansions}, {store_solution_trajectory} and
# {store_expansions_unsolved} are expanded by the instantiate_* functions
# below; the remaining placeholders are left in the string.
BASE_TEMPLATE_SAMPLING_SEARCH_CONFIGURATION = (
    "sampling_search({search_engine}({evaluator},transform={TASK_TRANSFORMATION},"
    "max_time={max_time}),"
    "techniques=[{sampling_techniques}],"
    "sample_cache_size={cache_size},"
    "network_reload_frequency={cache_size},"
    "iterate_sample_files=true,"
    "index_sample_files={cache_index},"
    "max_sample_files={max_sample_files},"
    "sample_format=csv,state_format=fdr,field_separator=;,state_separator=;,"
    "store_solution_trajectory={store_solution_trajectory},"
    "store_other_trajectories=false,"
    "store_all_states=false,"
    "store_initial_state=false,store_intermediate_state=false,"
    "skip_goal_field=true,skip_second_state_field=true,skip_action_field=true,"
    "add_unsolved_samples={add_unsolved_samples},"
    "store_expansions={store_expansions},"
    "store_expansions_unsolved={store_expansions_unsolved},"
    "skip_undefined_facts={skip_undefined_facts},shuffle_techniques=true,"
    "random_seed={seed})"
)


def instantiate_template(templates, func):
    """Expand all entries of *templates* with *func* into a new dictionary.

    :param templates: mapping of template key to configuration.
    :param func: callable ``func(result, key, configuration)`` which stores
        its expansions of one entry in ``result``.
    :return: the new dictionary filled by the calls of *func*.
    """
    expanded = {}
    for entry in templates.items():
        func(expanded, *entry)
    return expanded


def instantiate_search_engines(templates, key, configuration):
    """Add one variant of *configuration* per search engine to *templates*.

    ``{engine}`` in *key* becomes "GBFS" or "ASTAR" and ``{search_engine}``
    in *configuration* becomes "eager_greedy" or "astar". For eager_greedy
    the ``{evaluator}`` placeholder is wrapped into a list
    (``[{evaluator}]``); for astar it is kept as it is.
    """
    engines = (
        ("GBFS", "eager_greedy", "[{evaluator}]"),
        ("ASTAR", "astar", "{evaluator}"),
    )
    for label, engine, evaluator in engines:
        variant = configuration.replace("{search_engine}", engine)
        variant = variant.replace("{evaluator}", evaluator)
        templates[key.replace("{engine}", label)] = variant


def instantiate_timeout(templates, key, configuration):
    """Add one variant of *configuration* per time limit to *templates*.

    ``{timeout}`` in *key* becomes "", "1" or "10" and ``{max_time}`` in
    *configuration* becomes "infinity", "1s" or "10s" respectively.
    """
    limits = (("", "infinity"), ("1", "1s"), ("10", "10s"))
    for label, limit in limits:
        variant_key = key.replace("{timeout}", label)
        templates[variant_key] = configuration.replace("{max_time}", limit)


def instantiate_unsolved(templates, key, configuration):
    """Add the variants without/with unsolved samples to *templates*.

    ``{add_unsolved}`` in *key* becomes "" or "U" and
    ``{add_unsolved_samples}`` in *configuration* becomes "false" or "true"
    respectively.
    """
    for label, flag in (("", "false"), ("U", "true")):
        variant_key = key.replace("{add_unsolved}", label)
        templates[variant_key] = configuration.replace(
            "{add_unsolved_samples}", flag)


def instantiate_expansions(templates, key, configuration):
    """Add the variants without/with stored expansions to *templates*.

    ``{learn_expansions}`` in *key* becomes "" or "X". For "" the
    configuration gets ``store_expansions=false`` and
    ``store_solution_trajectory=true``; for "X" both values are flipped.
    """
    variants = (
        ("", "false", "true"),
        ("X", "true", "false"),
    )
    for label, expansions, trajectory in variants:
        variant = configuration.replace("{store_expansions}", expansions)
        variant = variant.replace("{store_solution_trajectory}", trajectory)
        templates[key.replace("{learn_expansions}", label)] = variant


def instantiate_expansions_unsolved(templates, key, configuration):
    """Add the variants for storing expansions of unsolved runs.

    ``{add_final_expansions}`` in *key* becomes "" (with
    ``store_expansions_unsolved=false``). Only if *key* contains an "X" a
    second variant with the label "X" and ``store_expansions_unsolved=true``
    is added.
    """
    variants = [("", "false")]
    if "X" in key:
        variants.append(("X", "true"))
    for label, flag in variants:
        variant_key = key.replace("{add_final_expansions}", label)
        templates[variant_key] = configuration.replace(
            "{store_expansions_unsolved}", flag)

# Build all sampling_search presets: start with a single entry whose key
# consists only of placeholders and expand one placeholder (of the key and of
# the configuration) per pass. The final keys are concatenated labels such as
# "GBFS", "ASTAR1", "GBFS10U" or "ASTARUXX".
TEMPLATES_SAMPLING_ENGINES = {"{engine}{timeout}{add_unsolved}{learn_expansions}{add_final_expansions}": BASE_TEMPLATE_SAMPLING_SEARCH_CONFIGURATION}
for _func in [instantiate_search_engines, instantiate_timeout,
              instantiate_unsolved, instantiate_expansions,
              instantiate_expansions_unsolved]:
    TEMPLATES_SAMPLING_ENGINES = instantiate_template(TEMPLATES_SAMPLING_ENGINES, _func)
# The sampling_q configuration is registered under the key "V".
TEMPLATES_SAMPLING_ENGINES["V"] = TEMPLATE_SAMPLING_Q_CONFIGURATION


# Search configuration behind the "default" placeholder of
# --add-intermediate-evaluations.
TEMPLATE_SEARCH_CONFIGURATION = "eager_greedy([{evaluator}],cost_type=one)"

# Default value of --fd-network-definition (paired with the key "hnn").
DEFAULT_NN_DEFINITION = (
    "snet(type={OUTPUT_TYPE},unary_threshold={UNARY_THRESHOLD},"
    "state_layer=input_1,"
    "path={MODEL_FILE}.pb,transform={TASK_TRANSFORMATION},bin_size={BIN_SIZE},"
    "output_layers={MODEL_OUTPUT_LAYER},domain_max_is_undefined="
    "{domain_max_is_undefined},exponentiate_heuristic={exponentiate_heuristic},"
    "random_seed={seed})")

# Evaluator behind the "NN" placeholder and default of --v-value-evaluator.
DEFAULT_NN_EVALUATOR = (
    "nh({NETWORK_PRE_DEFINITION},transform={TASK_TRANSFORMATION})")

# Evaluator behind the "MaxNNhFF" placeholder: maximum of ff and the network
# evaluator.
MAX_NN_HFF_EVALUATOR = (
    "max([ff(transform={TASK_TRANSFORMATION}),%s])" % DEFAULT_NN_EVALUATOR)

# Search configuration behind the "default" placeholder of
# --add-initial-evaluation.
DEFAULT_INITIAL_SEARCH = "astar(blind())"

# Appended to the help texts of --v-value-evaluator and
# --cold-start-evaluator.
# NOTE(review): "sets to" and "sampling_transform()" are concatenated without
# a space, and the text after "{MODEL_OUTPUT_LAYER}:" looks truncated/garbled
# -- the intended wording cannot be derived from this file.
EVALUATOR_CONFIGURATION_HINT = (
    "If the learner is of type KerasMLP, then the evaluator can use the "
    "following keys for an automatic configuration: {OUTPUT_TYPE}: regression/"
    "classification, {UNARY_THRESHOLD}: unary threshold of the network, "
    "{MODEL_FILE}: path to the model file, {TASK_TRANSFORMATION}: sets to"
    "sampling_transform(), {BIN_SIZE}: network bin size, {MODEL_OUTPUT_LAYER}:"
    "{MODEL_OUTPUT_LAYER, <model_path>.pb")


# Main command line parser. The text has to be given as ``description``: the
# first positional parameter of ArgumentParser is ``prog``, so passing the
# text positionally would print it as the program name in the usage line and
# in every error message.
parser = argparse.ArgumentParser(
    description="things currently to have in mind: "
    "1. Describe in the learner configuration in which formats "
    "to store the model. 2. Set in the learner configuration the x/y_fields to "
    "inform about the numbers of input/outputs. 3. Adapt the evaluator "
    "configuration to look at the right location and input/output layers of "
    "the network stored.", add_help=True)


class EvaluationOutputWrapper(object):
    """Lazy view on a ``(status, message)`` pair delivered through a queue.

    As long as no status is known, every property access checks whether the
    queue holds an item and, if so, takes it as ``(status, message)``. Once a
    status other than ``None`` is stored, the queue is not read anymore.
    """

    def __init__(self, data_queue):
        self._queue = data_queue
        self._status = None
        self._message = None

    @property
    def status(self):
        """Status of the evaluation or None if nothing was received yet."""
        return self._get_status()

    @property
    def message(self):
        """Message of the evaluation or None if nothing was received yet."""
        return self._get_message()

    @property
    def finished(self):
        """True as soon as a status has been received."""
        return self._get_finished()

    def _fetch(self):
        # Only poll the queue while the result is still unknown.
        if self._status is not None:
            return
        if self._queue.empty():
            return
        self._status, self._message = self._queue.get()

    def _get_status(self):
        self._fetch()
        return self._status

    def _get_message(self):
        self._fetch()
        return self._message

    def _get_finished(self):
        self._fetch()
        return self._status is not None


# Learner configuration, task file and (optional) domain file. The example
# calls in the module docstring pass the problem file positionally and the
# domain file via --domain-pddl.
parser.add_argument("learner", type=apt.learner,
                    help="Learner configuration to train.")
parser.add_argument("task_pddl", type=apt.absfile,
                    help="Path to the task file")
parser.add_argument("--domain-pddl", type=apt.absfile,
                    help="Path to the domain file. Provide if Fast-Downward "
                         "automatic domain file detection fails.")
# Options for the translation step, the working directory and the training
# time. Adjacent string literals are joined without a separator, therefore
# every help fragment ends with an explicit space.
parser.add_argument("--preprocessor", type=str, action="append",
                    default=[],
                    help="Command to execute after task translation. You may "
                         "expect an output.sas file containing the SAS task.")
parser.add_argument("--translator-options", action="append", default=[],
                    help="Options to provide to the translation from task to "
                         "initial sas file. Each argument given to this is "
                         "later split like the shell would split it. Thus "
                         "provide multiple arguments just as a single string. "
                         "Your argument may not start with -- or -; thus, "
                         "start it with a white space followed by - or -- if "
                         "you require this.")
parser.add_argument("--working-directory",
                    type=apt.allow_environment_variable(apt.absdir),
                    default=".",
                    help="Change the working directory")
parser.add_argument("--maximum-training-time", type=apt.time, default=1800,
                    help="Maximum number of seconds to train for (just limits "
                         "training time not overhead prior and after)")
# Evaluation runs at the start of, during and after the training. The
# "default" placeholder stands for DEFAULT_INITIAL_SEARCH (initial
# evaluation) and TEMPLATE_SEARCH_CONFIGURATION (intermediate evaluations).
parser.add_argument("--add-initial-evaluation", nargs="+",
                    type=apt.arg_counts(
                        "--add-initial-evaluation",
                        apt.list_type(
                            apt.placeholder_type(
                                apt.has_seeds, {"default": DEFAULT_INITIAL_SEARCH}),
                            apt.time,
                            os.path.abspath),
                        intervals=[(2, 3)]),
                    default=False,
                    help="Starts an initial evaluation together with the "
                         "training. If the evaluation finds a plan and no path "
                         "is given, the training process is stopped else if a "
                         "plan is found and a path is given, the search output "
                         "is stored in the path and the training continues."
                         " Specify this options as : "
                         "[search configuration|default (aka astar(blind()))] "
                         "[max search time limit] (path)")
parser.add_argument("--add-intermediate-evaluations", nargs="+",
                    type=apt.arg_counts(
                        "--add-intermediate-evaluations",
                        apt.list_type(
                            apt.placeholder_type(
                                apt.has_seeds,
                                {"default": TEMPLATE_SEARCH_CONFIGURATION}),
                            apt.time,
                            apt.time,
                            os.path.abspath),
                        intervals=[(3, 4)]),
                    default=False,
                    help="Adds during training evaluation runs. If such a run "
                         "finds a plan and no path is specified, then the "
                         "training process stops else if a path is specified, "
                         "then the search output is stored at path in a zip "
                         "archive as file with \"run#\". Specify "
                         "this options as : "
                         "[search configuration|default] "
                         "[min time between runs] [max search time limit] "
                         "(path)")
parser.add_argument("--add-final-evaluation", type=apt.has_seeds, nargs="?",
                    default=False,
                    help='Adds a search run after the training process for '
                         '1800s. You can use {evaluator} to insert the '
                         'evaluator configuration used during training')
# Model (re)initialisation options, followed by the first options of the
# "Data Generator Arguments" group. Help fragments end with an explicit space
# because adjacent string literals are joined without a separator.
parser.add_argument("--reinitialize-after-time", type=apt.time,
                    default=None,
                    help="Reinitialize the NN if after a fix set of time, the "
                         "NN does not predict values greater or equal to 1.")
parser.add_argument("--load-initial-model", nargs="*", type=int, default=(-1, -1),
                    help="Loads a model when starting up (model has to exist). "
                         "Either specify no options or provide start epoch "
                         "and start time in seconds.")
parser_data = parser.add_argument_group("Data Generator Arguments")
# --buffer-size-factor is registered on the main parser and not on the
# parser_data group.
parser.add_argument("--buffer-size-factor", type=apt.int_positive,
                    default=20,
                    help="if using a replay buffer, then the size of the buffer "
                         "is batch_size * buffer_size_factor.")
parser_data.add_argument("--lookahead", type=apt.int_positive, default=1,
                         help="Look ahead for the V value calculation.")
parser_data.add_argument("--max-data-files", type=apt.int_positive,
                         default=200,
                         help="Maximum number of data files desired on disk.")
parser_data.add_argument("--max-cached-batches", type=apt.int_positive,
                         default=10,
                         help="Maximum number of sample files to have "
                              "loaded and send to the caching queue without "
                              "being used.")
parser_data.add_argument("--samples-per-data-file", type=apt.int_positive,
                         default=250,
                         help="Number of data samples to store in a data file.")
# During the script this option has the format:
# ((training definition, search definitions), predefinition key?)
parser_data.add_argument(
    "--fd-network-definition", nargs="*", type=apt.arg_counts(
        "--fd-network-definition",
        apt.list_type(apt.has_seeds, str),
        intervals=[(1, 2)]
    ),
    default=(apt.has_seeds(DEFAULT_NN_DEFINITION), "hnn"),
    help="Definition of the neural network as Fast Downward component given as "
         "[FD NEURAL NETWORK DEFINITION] (PREDEFINITION KEY). If "
         "PREDEFINITION_KEY is provided, then the network is predefined. Some "
         "options allow inserting the definition/predefinition via "
         "{NETWORK_PRE_DEFINITION}."
)
# The literals are concatenated before the %-formatting is applied, i.e. the
# trailing "%s" of the help texts below receives the configuration hint.
parser_data.add_argument(
    "--v-value-evaluator", "--evaluator", type=apt.placeholder_type(
        apt.has_seeds, {
            "NN": DEFAULT_NN_EVALUATOR,
            "MaxNNhFF": MAX_NN_HFF_EVALUATOR
        }),
    default=apt.has_seeds(DEFAULT_NN_EVALUATOR),
    help="Evaluator for Fast Downward to calculate the V values. Normally, "
         "this evaluator should use the used network. use {seed3} if some "
         "random seed shall be set. Remember: use "
         "sampling_transform() to forward the current task to the evaluator. "
         "%s" % EVALUATOR_CONFIGURATION_HINT)

# The order of the values matches the IDX_COLD_START_* constants.
parser_data.add_argument(
    "--cold-start-evaluator", nargs="+", type=apt.arg_counts(
        "--cold-start-evaluator",
        apt.list_type(
            apt.placeholder_type(
                apt.has_seeds, {"ff": "ff(transform={TASK_TRANSFORMATION})"}),
            apt.callback,
            apt.float_interval(0, 1),
            apt.float_interval(0, 1),
            apt.float_interval(0, 1)
        ),
        intervals=[(3, 5)]
    ),
    default=None,
    help="Only for sampling! Gradually move weight from cold-start-evaluator to "
         "'--v-value-evaluator'. Define as: [COLD START EVALUATOR] [KERAS "
         "CONDITION TO DECREASE WEIGHT] [DECREASE STEP] (INITIAL WEIGHT ON COLD "
         "START) (MIN WEIGHT ON COLD START EVAL.) Weights can (and most times "
         "should) be floats. Remember: use "
         "sampling_transform() to forward the current task to the evaluator. "
         "%s" % EVALUATOR_CONFIGURATION_HINT)

# Sampling engine and sampling techniques. The preset names in the help text
# are the keys of TEMPLATES_SAMPLING_ENGINES: "V" plus the labels generated by
# the instantiate_* functions (there is no preset called "S").
parser_data.add_argument(
    "--sampling-engine", type=apt.placeholder_type(
        apt.has_seeds, TEMPLATES_SAMPLING_ENGINES),
    default=apt.has_seeds(TEMPLATE_SAMPLING_Q_CONFIGURATION),
    help="The sampling engine to generate the training data (e.g. a v value "
         "calculating engine or a search). Use 'V' to select the default V "
         "value producing engine or a search preset named "
         "(GBFS|ASTAR)(1|10)?U?(X|XX)? where 1/10 limit the search time to "
         "1s/10s, U adds unsolved samples, X stores the expansions instead "
         "of the solution trajectory and XX additionally stores the "
         "expansions of unsolved searches.")
parser_data.add_argument(
    "--sampling-technique", nargs="+", action="append", default=[],
    help="Define Sampling techniques to use. Multiple techniques can be "
         "defined. Arguments are: <TODO>")
parser_data.add_argument("--expand-goal", nargs=3, default=["0", "true", "-1"],
                         type=apt.list_type(
                             int,
                             apt.choice_type(["true", "false"]),
                             int
                         ),
                         help="sets the {expand_goal}, {reload_expanded_goals} "
                              "and {expand_goal_state_limit} "
                              "keys. not all search techniques support this."
                         )
# Fast Downward location/build and options on how samples are generated and
# interpreted. Help fragments end with an explicit space because adjacent
# string literals are joined without a separator.
parser_data.add_argument("--fast-downward-directory", type=apt.absdir,
                         default=os.path.dirname(__file__),
                         help="Path to the directory containing Fast-Downward.")
parser_data.add_argument("--fast-downward-build", default="debug64dynamic",
                         choices=["debug64dynamic", "release64dynamic"],
                         help="Fast Downward build to use.")
parser_data.add_argument("--random-seeds", action="store_true",
                         help="Assign the sampling algorithm random seeds ("
                              "otherwise the seeds are deterministically given, "
                              "meaning the same sample file content will be "
                              "generated, but the file names can differ due to "
                              "multiprocessing.).")
parser_data.add_argument("--max-generator-processes", type=apt.int_positive,
                         default=4,
                         help="Maximum number of data generating processes to "
                              "run in parallel.")
parser_data.add_argument("--wrap-partial-assignment", type=apt.int_positive,
                         default=None,
                         help="Specifies that the sampling produces and stores "
                              "partial assignments which will be completed "
                              "during training. This options sets the "
                              "following flags to true (which are otherwise "
                              "false): in all sampling techniques"
                              " {wrap_partial_assignment};"
                              " {skip_undefined_facts} in all sampling engines;"
                              " and {domain_max_is_undefined} in the evaluator "
                              "given to FD.")
parser_data.add_argument("--samples-have-timeout-information",
                         action="store_true",
                         help="If set, then one column of the sample csv file "
                              "is interpreted as whether the sample was solved "
                              "(1) or not (0). If we detect "
                              "'add_unsolved_samples=true', we set this "
                              "variable automatically to true.")
parser_data.add_argument(
    "--transform-label", default=[None], nargs="+",
    type=apt.list_type(apt.transform_numeric, apt.split_type(":"),
                       mode=apt.ListModes.Last),
    help="Transform labels for training. Additionally to the transformation "
         "name (enter an invalid transformation and we show all possible "
         "transformations) you may add any number of 'key:value' pairs. Those "
         "keys will be replaced in the network definition by the specified "
         "value (e.g., use this to inform the network to transform back the "
         "predicted values)")
# Parser for a single sampling technique definition. Its options use "+" as
# prefix character.
parser_samptech = argparse.ArgumentParser(prefix_chars="+")
parser_samptech.add_argument(
    "sampling_technique",
    type=apt.placeholder_type(apt.has_seeds, {
        "iforward_none": TEMPLATE_IFORWARD_NONE_TECHNIQUE,
        "gbackward_none": TEMPLATE_GBACKWARD_NONE_TECHNIQUE,
        "gbackward_none_nb": TEMPLATE_GBACKWARD_NONE_TECHNIQUE_NO_BIAS,
        "uniform_none": TEMPLATE_UNIFORM_NONE_TECHNIQUE
    }),
    help="Technique to use for generating the samples. Sampling techniques can "
         "be fully defined or contain the following formatters: {task_count}: "
         "how many tasks they shall generate, {%s}: how often they shall "
         "scramble the state, {seed} random seeds." % KEY_MAX_SCRAMBLES)
parser_samptech.add_argument(
    "++weight", type=apt.int_positive, default=1,
    help="If multiple techniques are used and the {task_count} is a formatter,"
         " then this determines the ratio between the techniques "
         "(weight/sum of weights)")
parser_samptech.add_argument(
    "++bias", nargs=4, default=["<none>", "false", "-1", "-1"],
    type=apt.list_type(
        apt.has_seeds,
        apt.choice_type(["true", "false"]),
        apt.int_interval(-1, None),
        float
    ),
    help="Replaces the following keywords in the sampling technique definition:"
         " {bias}, {bias_mode_probabilistic}, {bias_reload_frequency},"
         " {bias_adapt}. Usage: "
         "[BIAS EVALUATOR] [true|false] [RELOAD FREQUENCY] [ADAPT_BIAS]")
# The "%s" has to be resolved here: argparse %-formats every help text with
# its parameter dictionary, so a leftover "%s" would be replaced by the
# representation of that whole dictionary.
parser_samptech.add_argument(
    "++max-scrambles", type=apt.int_zero_positive, default=200,
    help="Sets the {%s} formatter in the sampling techniques."
         % KEY_MAX_SCRAMBLES)

parser_samptech.add_argument(
    "++active", type=apt.callback, default=None,
    help="Condition when to use this sampling technique. If not given, the "
         "sampling technique is always active.")
parser_samptech.add_argument(
    "++increase-scrambling", nargs=2, default=None,
    type=apt.list_type(
        apt.callback,
        apt.float_modification
    ),
    help="The scramble formatter starts with the value from ++max-scrambles. "
         "Everytime the condition given is true, the scrambling value is "
         "modified by the given modificator. Use as modificator "
         "'(\\+|-|\\*|/)?(\\d+)'"
)

parser_samptech.add_argument(
    "++upgrade", nargs="+", default=None,
    help="[Maximum nb of upgrades allowed] [KEY=VALUE]*. Key is the formatting "
         "key used in the sampling technique template, e.g. {upgrade_KEY}"
)
""" -------------------------- General Methods ------------------------------"""


def translate(file_sas, file_task, file_domain, script_fast_downward, build,
              translator_options, preprocessors):
    """Translate the task via ``tm.translate`` and run the preprocessors.

    :param file_sas: path at which the SAS file has to exist after the
        translation and after every preprocessor; must not exist beforehand.
    :param file_task: task file forwarded to ``tm.translate``.
    :param file_domain: domain file forwarded to ``tm.translate``.
    :param script_fast_downward: forwarded to ``tm.translate``.
    :param build: forwarded to ``tm.translate`` as ``build``.
    :param translator_options: forwarded to ``tm.translate``.
    :param preprocessors: commands which are executed one after another with
        ``subprocess.call``; their return codes are not checked.
    :raises AssertionError: if *file_sas* exists before the translation or is
        missing after the translation or after a preprocessor.
    """
    # A left-over file could be mistaken for the result of the translation.
    assert not os.path.exists(file_sas), file_sas
    # NOTE(review): file_sas is not passed on, so tm.translate presumably
    # writes to this location by default -- confirm in tools.misc.
    tm.translate(file_task, file_domain, script_fast_downward, build=build,
                 translator_options=translator_options)
    assert os.path.isfile(file_sas), file_sas
    print("Run Preprocessors")
    for preprocessor in preprocessors:
        print(">>>", preprocessor)
        # NOTE(review): called without shell=True; a plain string is then
        # treated as the program name on POSIX -- verify how the callers
        # prepare the entries.
        subprocess.call(preprocessor)
        # Every preprocessor has to leave a SAS file behind.
        assert os.path.isfile(file_sas), file_sas


def extract_used_facts(file_sas, file_used_facts):
    """Write the value names of the variables in a SAS file to a JSON file.

    Only ``begin_variable`` ... ``end_variable`` sections whose third line is
    ``-1`` are considered. All non-empty lines between the line following the
    ``-1`` and ``end_variable`` are collected (stripped, in file order) and
    dumped as a JSON list.

    :param file_sas: path of the SAS file to read.
    :param file_used_facts: path of the JSON file to write.
    """
    variable_section = re.compile(
        "begin_variable\n"
        "var\\d+\n"
        "-1\n"
        "\\d+\n"
        "(([^\\n]+\\n)+?)end_variable", re.MULTILINE)

    with open(file_sas, "r") as sas_handle:
        content = sas_handle.read()

    atoms = []
    for section in variable_section.finditer(content):
        # Group 1 holds all value lines of the variable.
        for line in section.group(1).split("\n"):
            fact = line.strip()
            if fact != "":
                atoms.append(fact)

    with open(file_used_facts, "w") as facts_handle:
        json.dump(atoms, facts_handle)


def setup(working_directory, script_fast_downward, file_domain, file_task,
          file_sas, file_used_facts,
          dir_data, build, translator_options, preprocessors,
          initial_evaluation, intermediate_evaluations):
    """Prepare the working directory for a training run.

    Changes into *working_directory* (if given), removes the outputs of
    previous runs, translates the task, extracts the used facts and creates
    the data directory.

    :param working_directory: directory to change into or None.
    :param initial_evaluation: False or the value list of
        --add-initial-evaluation; an optional third entry is the output path
        which is removed.
    :param intermediate_evaluations: False or the value list of
        --add-intermediate-evaluations; an optional fourth entry is the output
        path which is removed.
    :param dir_data: data directory; it is deleted and recreated. The sibling
        directory "err_dir" is deleted as well.
    :raises OSError: if the data directory cannot be created.
    :raises AssertionError: if the translation leaves no SAS file or the data
        directory does not exist afterwards.

    The remaining parameters are forwarded to ``translate`` and
    ``extract_used_facts``.
    """
    # Local import: the symbolic name replaces the magic errno value 17.
    import errno

    if working_directory is not None:
        os.chdir(working_directory)

    # Remove the evaluation outputs of previous runs (only if a path is set).
    if initial_evaluation is not False and len(initial_evaluation) > 2:
        tpa.remove(initial_evaluation[2], missing_ok=True)
    if (intermediate_evaluations is not False and
            len(intermediate_evaluations) > 3):
        tpa.remove(intermediate_evaluations[3], missing_ok=True)

    tpa.remove(file_sas, missing_ok=True)
    tpa.remove(file_used_facts, missing_ok=True)
    tpa.remove_dir(dir_data, missing_ok=True)
    tpa.remove_dir(os.path.join(os.path.dirname(dir_data), "err_dir"),
                   missing_ok=True)

    translate(file_sas, file_task, file_domain, script_fast_downward,
              build, translator_options, preprocessors)
    extract_used_facts(file_sas, file_used_facts)
    try:
        os.makedirs(dir_data)
    except OSError as e:
        # An already existing directory is fine, everything else is not.
        if e.errno != errno.EEXIST:
            raise
    assert os.path.isdir(dir_data), dir_data


def shutdown(sync_to_sampling, sync_sampling_process_ids,
             proc_sampling, proc_loading_supervisor, procs_loading,
             proc_intermediate_search, proc_initial_search,
             working_directory, dir_data):
    """Stop all helper processes and delete the temporary data.

    :param sync_to_sampling: queue on which ``SyncToken.Terminate`` is put.
    :param sync_sampling_process_ids: queue of process ids; all ids in it are
        killed (including children) if the sampling process does not stop in
        time.
    :param proc_sampling: sampling process; joined for up to
        ``TIME_SLEEP_SHUTDOWN`` seconds and terminated afterwards if needed.
    :param proc_loading_supervisor: process which is terminated if alive.
    :param procs_loading: list of processes; killed with their children.
    :param proc_intermediate_search: process or None; killed with children.
    :param proc_initial_search: process or None; killed with children.
    :param working_directory: directory in which "model.graphdef" is removed.
    :param dir_data: data directory which is removed.
    """
    def print_shutdown(msg):
        print("Shutdown>", msg)
    print_shutdown("Start...")
    sys.stdout.flush()
    sys.stderr.flush()
    # Ask the sampling process to stop on its own before force is used.
    sync_to_sampling.put(SyncToken.Terminate)
    if proc_loading_supervisor.is_alive():
        proc_loading_supervisor.terminate()

    # The search processes may be None, the loading processes are a list.
    for proc in [proc_initial_search, proc_intermediate_search] + procs_loading:
        if proc is not None and proc.is_alive():
            tm.kill_process_and_children(proc.pid)

    # Give the sampling process some time to react on the Terminate token.
    if proc_sampling.is_alive():
        proc_sampling.join(TIME_SLEEP_SHUTDOWN)
    if proc_sampling.is_alive():
        print_shutdown("Sampling process still alive.")
        # Kill the processes whose ids were reported through the queue
        # before terminating the sampling process itself.
        while not sync_sampling_process_ids.empty():
            tm.kill_process_and_children(sync_sampling_process_ids.get())
        proc_sampling.terminate()
        print_shutdown("Sampling process killed.")

    # NOTE(review): presumably waits for pending file operations of the
    # stopped processes before deleting -- confirm.
    time.sleep(TIME_SLEEP_DISK_DELAY)
    print_shutdown("Delete temporary data..")
    tpa.remove_dir(dir_data, missing_ok=True)
    tpa.remove(os.path.join(working_directory, "model.graphdef"),
               missing_ok=True)
    print_shutdown("Done.")
    sys.stdout.flush()
    sys.stderr.flush()


def await_token(sync_queue, *tokens, **kwargs):
    """Fetch a single item from *sync_queue* and check that it is expected.

    :param sync_queue: queue to read from.
    :param tokens: the accepted tokens.
    :param kwargs: ``block`` (default True) and ``timeout`` (default None)
        are forwarded to ``sync_queue.get``; other keys are ignored.
    :return: the received item as a tuple (an item which is not a tuple is
        wrapped into a 1-tuple); its first element is the token.
    :raises AssertionError: if the received token is not in *tokens*.
    """
    received = sync_queue.get(block=kwargs.get("block", True),
                              timeout=kwargs.get("timeout"))
    if not isinstance(received, tuple):
        received = (received,)
    assert received[0] in tokens, \
        "recv: %s, expected: %s" % (received[0], str(tokens))
    return received


def receive_tokens(sync_queue, *tokens):
    """Drain *sync_queue* without blocking and return everything received.

    Every item is fetched with ``await_token``; an unexpected token therefore
    raises an AssertionError.

    :return: list of the received token tuples in the order of arrival.
    """
    received = []
    try:
        while True:
            received.append(await_token(sync_queue, *tokens, block=False))
    except queue.Empty:
        # The queue is drained.
        pass
    return received

""" ------------------------- Sampling Methods-------------------------------"""


def get_sample_files(dir_data, prefix_sample_file):
    """List the entries of *dir_data* whose name starts with the prefix.

    :return: list of the matching entries, each joined with *dir_data*, in
        the order given by ``os.listdir``.
    """
    sample_files = []
    for entry in os.listdir(dir_data):
        if os.path.basename(entry).startswith(prefix_sample_file):
            sample_files.append(os.path.join(dir_data, entry))
    return sample_files


def calculate_files_per_new_process(missing_sample_files, free_process_slots):
    """Distribute the missing sample files over the processes to start.

    As many processes as possible are used (at most one per missing file and
    at most ``free_process_slots``). The files are spread as evenly as
    possible; process ``idx`` receives
    ``floor((idx + 1) * m / k) - floor(idx * m / k)`` files, i.e. surplus
    files are pushed towards the later processes.

    The computation uses integer arithmetic only. The previous implementation
    carried fractional remainders as floats from one entry to the next, which
    is susceptible to rounding errors (a file could get lost in the last
    entry) and silently dropped all remainders under integer division.

    :param missing_sample_files: number of sample files still to produce.
    :param free_process_slots: number of processes which may be started.
    :return: list with the number of files per new process; its sum equals
             ``missing_sample_files``. Empty if no file is missing or no
             process slot is free.
    """
    nb_new_processes = min(missing_sample_files, free_process_slots)
    if nb_new_processes <= 0:
        # Nothing to distribute (previously a ZeroDivisionError/IndexError).
        return []
    files_per_new_process = [
        ((idx + 1) * missing_sample_files // nb_new_processes -
         idx * missing_sample_files // nb_new_processes)
        for idx in range(nb_new_processes)]
    assert all(x > 0 for x in files_per_new_process), files_per_new_process
    return files_per_new_process


def output_filtered_stderr(processes, only_stopped=True):
    """Print what the given processes wrote to stderr, minus known noise.

    All matches of the TensorFlow warning, "using TensorFlow" and HDF5
    conversion warning patterns are cut out of the text. Whatever remains
    after stripping surrounding whitespace is printed, if it is not empty.

    :param processes: iterable of process objects providing ``poll()`` and a
                      readable ``stderr`` which yields text.
    :param only_stopped: if true, processes for which ``poll()`` still
                         returns ``None`` are skipped.
    """
    for process in processes:
        if only_stopped and process.poll() is None:
            continue
        text = process.stderr.read()
        for noise in (apt.PATTERN_TENSORFLOW_WARNING,
                      apt.PATTERN_USING_TENSORFLOW,
                      apt.PATTERN_HDF5_CONVERSION_WARNING):
            text = noise.sub("", text)
        text = text.strip()
        if text:
            print(text)


def get_seed_dictionary(nb_seeds, basis=None):
    """Create a mapping from the seed names ``seed0, seed1, ...`` to values.

    :param nb_seeds: number of seeds to generate.
    :param basis: if given, seed ``i`` has the value ``basis + i``; otherwise
                  every seed is drawn uniformly from ``[0, 2147483647]``.
    :return: dictionary ``{"seed<i>": value}`` with ``nb_seeds`` entries.
    """
    seeds = {}
    for i in range(nb_seeds):
        if basis is not None:
            value = basis + i
        else:
            value = random.randint(0, 2147483647)
        seeds["seed%i" % i] = value
    return seeds

def replace_seeds(x, seeds):
    """Substitute all seed placeholders in the string ``x``.

    :param x: string which may contain placeholders of the form ``{name}``.
    :param seeds: dictionary mapping placeholder names to their values.
    :return: ``x`` with every ``{name}`` replaced by ``str(seeds[name])``.
    """
    replaced = x
    for name in seeds:
        placeholder = "{%s}" % name
        replaced = replaced.replace(placeholder, str(seeds[name]))
    return replaced


def construct_sampling_evaluator(
        main_evaluator, cold_start_evaluator, cold_start_weight, seeds):
    """Build the evaluator string which is used during sampling.

    * no weight given or weight 0: only the main evaluator is used,
    * weight 1: only the cold start evaluator is used,
    * otherwise: weighted sum of both, where the main evaluator gets the
      weight ``1 - cold_start_weight``.

    :param main_evaluator: evaluator definition string.
    :param cold_start_evaluator: definition string of the cold start
                                 evaluator.
    :param cold_start_weight: weight of the cold start evaluator or ``None``.
    :param seeds: seed dictionary whose placeholders are replaced in the
                  result.
    :return: evaluator string with the seed placeholders substituted.
    """
    if cold_start_weight is None or cold_start_weight == 0:
        evaluator = main_evaluator
    elif cold_start_weight == 1:
        evaluator = cold_start_evaluator
    else:
        template = ("sum([weight({main_evaluator}, {main_weight}),"
                    "weight({cold_start_evaluator},{cold_start_weight})])")
        evaluator = template.format(
            main_evaluator=main_evaluator,
            main_weight=1 - cold_start_weight,
            cold_start_evaluator=cold_start_evaluator,
            cold_start_weight=cold_start_weight)
    return replace_seeds(evaluator, seeds)


def sample(recv_queue, send_queue, sync_err,
           script_fast_downward, fast_downward_build,
           file_sas, dir_data, consecutive_seeds, network_predefinition,
           evaluator, cold_start,
           template_sampling_engine, sampling_techniques,
           expand_goal,
           wrap_partial_assignment,
           lookahead, sample_cache_size, maximum_data_files,
           max_parallel_processes):
    """Main loop of the sampling process.

    Waits for ``SyncToken.ModelCreated`` and afterwards repeatedly
      1. processes the tokens which arrived on ``recv_queue``,
      2. prints the filtered stderr of finished sampling runs,
      3. starts new sampling runs (subprocesses of ``script_fast_downward``)
         as long as fewer than ``maximum_data_files`` sample files exist in
         ``dir_data`` and fewer than ``max_parallel_processes`` runs are
         alive; otherwise it sleeps for ``TIME_SLEEP_SAMPLING``.
    The loop ends when ``SyncToken.Terminate`` is received.

    :param recv_queue: queue on which the sync tokens arrive.
    :param send_queue: receives the pid of every started subprocess.
    :param sync_err: queue on which an error tuple is put if the loop is
                     shut down because of an error token.
    :param script_fast_downward: first entry of the subprocess command.
    :param fast_downward_build: value of the ``--build`` argument.
    :param file_sas: passed as positional argument to the subprocess.
    :param dir_data: directory in which the sample files are counted and
                     which is used for the ``--plan-file`` argument.
    :param consecutive_seeds: if true, the seeds of the runs are consecutive
                              numbers starting at 0, otherwise random.
    :param network_predefinition: iterable of command line arguments in which
                                  the seed placeholders are replaced.
    :param evaluator: definition string of the main evaluator.
    :param cold_start: ``None`` or a sequence indexed with
                       ``IDX_COLD_START_EVALUATOR``, ``IDX_COLD_START_INIT``
                       and ``IDX_COLD_START_MIN``.
    :param template_sampling_engine: format string for the ``--search``
                                     argument.
    :param sampling_techniques: sequence of objects with the attributes
        ``max_scrambles``, ``upgrade``, ``active``, ``weight``, ``bias`` and
        ``sampling_technique`` (see their usage below).
    :param expand_goal: sequence of which the first three entries are
                        inserted into ``template_sampling_engine``.
    :param wrap_partial_assignment: converted via ``apt.to_cpp_bool`` and
                                    inserted into the templates.
    :param lookahead: inserted into ``template_sampling_engine``.
    :param sample_cache_size: inserted as ``cache_size`` and used to compute
                              the ``task_count`` of every technique.
    :param maximum_data_files: upper bound for the number of sample files
                               in ``dir_data``.
    :param max_parallel_processes: upper bound for the number of
                                   simultaneously running subprocesses.
    """
    def print_sampling(msg):
        print("Sampling>", msg)

    print_sampling("Started. Waiting for model creation.")
    # First seed value of the next run (only used for consecutive seeds)
    next_seed = 0

    # Passed as 'cache_index' to the next run; advanced by the number of
    # files the run was asked to produce
    next_sample_file_index = 0
    await_token(recv_queue, SyncToken.ModelCreated)
    print_sampling("Model created. Start sampling")
    cold_start_evaluator = (None if cold_start is None
                            else cold_start[IDX_COLD_START_EVALUATOR])
    cold_start_init = (None if cold_start is None
                       else cold_start[IDX_COLD_START_INIT])
    cold_start_min = (None if cold_start is None
                       else cold_start[IDX_COLD_START_MIN])
    cold_start_weight = cold_start_init
    # NOTE(review): nothing was increased yet; this prints the initial
    # max_scrambles values with the same wording as the later updates.
    print_sampling("Increased max scrambles to %s" % ", ".join(
        str(st.max_scrambles) for st in sampling_techniques))
    # stdout of the sampling runs is discarded
    fnull = open(os.devnull, "w")
    processes = []

    def kill_sampling_processes():
        # Kills all sampling runs (and their children) which are still alive
        for p in processes:
            if p.poll() is None:
                tm.kill_process_and_children(p.pid)

    valid_tokens = set([SyncToken.Terminate,
                        SyncToken.IncreaseScrambling,
                        SyncToken.InternalScramblingIncrement,
                        SyncToken.DecreaseColdStartWeight])
    while True:
        tokens = receive_tokens(
            recv_queue, *valid_tokens)
        if len(tokens) > 0:
            # Group the tokens by their kind. Up to now the tokens do not
            # need to be processed in the order of their arrival.
            sorted_tokens = defaultdict(list)
            for t in tokens:
                sorted_tokens[t[0]].append(t)

            # check if an invalid token was received
            # NOTE(review): receive_tokens fetches via await_token, which
            # already asserts that the token kind is in valid_tokens. This
            # branch (and the error branch below) looks reachable only if
            # that assertion is disabled - confirm.
            invalid_tokens = set(sorted_tokens.keys()) - valid_tokens
            if len(invalid_tokens) > 0:
                print_sampling("Error: Invalid token received in sampling "
                               "process. Terminate.")
                err_msg = ", ".join(
                    [", ".join(str(x) for x in sorted_tokens[it])
                     for it in invalid_tokens])
                sorted_tokens[SyncToken.Error].append(
                    (SyncToken.Error, err_msg))
            # sync error to main process
            if SyncToken.Error in sorted_tokens:
                print_sampling("Initiate shutdown.")
                kill_sampling_processes()
                output_filtered_stderr(processes, only_stopped=False)
                sync_err.put(
                    (os.getpid(), "sample_invalid_token",
                     KeyError("SyncToken.Error"),
                     "\n---------------\n".join(
                         x[1] for x in sorted_tokens[SyncToken.Error])))
                break
            # Terminate
            if SyncToken.Terminate in sorted_tokens:
                print_sampling("Initiate shutdown.")
                kill_sampling_processes()
                output_filtered_stderr(processes, only_stopped=False)
                break
            # Reduce Coldstart weight
            # Token format: (token, decrease). Only the largest requested
            # decrease is applied and the weight never drops below
            # cold_start_min.
            if SyncToken.DecreaseColdStartWeight in sorted_tokens:
                max_decrease = max(
                    x[1] for x in sorted_tokens[SyncToken.DecreaseColdStartWeight])

                candidate_cold_start_weight = max(
                    cold_start_weight - max_decrease, cold_start_min)
                if candidate_cold_start_weight != cold_start_weight:
                    cold_start_weight = candidate_cold_start_weight
                    # Running samplers still use the old weight
                    kill_sampling_processes()
                    print_sampling("Cold start weight decreased to %s" %
                                   cold_start_weight)
            # Increase scrambling
            # Token format: (token, index of the sampling technique,
            # (operator, operand)) with operator in "+", "-", "*", "/", "=".
            did_increase_sampling = False
            for t in (sorted_tokens[SyncToken.InternalScramblingIncrement] +
                      sorted_tokens[SyncToken.IncreaseScrambling]):
                did_increase_sampling = True
                assert 0 <= t[1] < len(sampling_techniques), (t[1], len(sampling_techniques))
                st = sampling_techniques[t[1]]
                old_max_scrambles = st.max_scrambles
                if t[2][0] == "+":
                    st.max_scrambles += t[2][1]
                elif t[2][0] == "-":
                    st.max_scrambles -= t[2][1]
                elif t[2][0] == "*":
                    st.max_scrambles *= t[2][1]
                elif t[2][0] == "/":
                    st.max_scrambles /= t[2][1]
                elif t[2][0] == "=":
                    st.max_scrambles = t[2][1]
                else:
                    assert False, t
                # An internal increment which raised max_scrambles consumes
                # one of the remaining upgrades (st.upgrade[0], floored at 0)
                if (t[0] == SyncToken.InternalScramblingIncrement and
                        st.upgrade is not None and
                        st.max_scrambles > old_max_scrambles):
                    st.upgrade[0] = max(0, st.upgrade[0] - 1)
            # Only IncreaseScrambling tokens restart the running samplers;
            # InternalScramblingIncrement tokens do not.
            if len(sorted_tokens[SyncToken.IncreaseScrambling]):
                kill_sampling_processes()
            if did_increase_sampling:
                print_sampling("Increased max scrambles to %s" % ", ".join(
                    str(st.max_scrambles) for st in sampling_techniques))


        output_filtered_stderr(processes, only_stopped=True)
        processes = [x for x in processes if x.poll() is None]  # filter alive
        missing_sample_files = maximum_data_files - len(get_sample_files(
            dir_data, SAMPLE_FILE_PREFIX))

        if missing_sample_files > 0 and len(processes) < max_parallel_processes:
            files_per_new_process = calculate_files_per_new_process(
                missing_sample_files, max_parallel_processes - len(processes))

            # Determine active sampling techniques and their weighting
            # (a technique without 'active' condition is always active; the
            # weights of the active techniques are normalized to sum up to 1)
            samptech_active = [st.active is None or st.active.check_condition()
                               for st in sampling_techniques]
            sum_sampling_technique_weights = float(
                sum(st.weight
                    for no_st, st in enumerate(sampling_techniques)
                    if samptech_active[no_st]))
            samptech_weights = [(st.weight/sum_sampling_technique_weights
                                 if samptech_active[no_st] else 0)
                                for no_st, st in enumerate(sampling_techniques)]
            # Start one sampling run per entry
            for nb_new_files in files_per_new_process:
                # Remember: The user is not forced to use any of those
                # formatters.
                seeds = get_seed_dictionary(
                    NB_SEEDS, next_seed if consecutive_seeds else None)
                next_seed += NB_SEEDS

                # Format sampling technique templates for next round
                techniques = []
                for no_st, st in enumerate(sampling_techniques):
                    # Skip inactive (weight 0) techniques
                    if samptech_weights[no_st] <= 0:
                        continue
                    additional_parameters = {
                        "task_count": (int(sample_cache_size * nb_new_files *
                                           samptech_weights[no_st])),
                        "max_scrambles":st.max_scrambles,
                        "wrap_partial_assignment":apt.to_cpp_bool(
                            wrap_partial_assignment),
                        "bias": st.bias[0],
                        "bias_mode_probabilistic": st.bias[1],
                        "bias_reload_frequency": st.bias[2],
                        "bias_adapt": st.bias[3],
                    }
                    additional_parameters.update(seeds)
                    if st.upgrade is not None:
                        additional_parameters.update(st.upgrade[1])
                        additional_parameters["nb_upgrades"] = st.upgrade[0]
                    techniques.append(tm.format_with_defaults(
                        st.sampling_technique, (),additional_parameters))

                search_configuration = template_sampling_engine.format(
                    evaluator=construct_sampling_evaluator(
                        evaluator, cold_start_evaluator, cold_start_weight,
                        seeds),
                    lookahead=lookahead,
                    cache_size=sample_cache_size,
                    sampling_techniques=",".join(techniques),
                    cache_index=next_sample_file_index,
                    max_sample_files=nb_new_files,
                    TASK_TRANSFORMATION="sampling_transform()",
                    skip_undefined_facts=apt.to_cpp_bool(wrap_partial_assignment),
                    expand_goal=expand_goal[0],
                    reload_expanded_goals=expand_goal[1],
                    expand_goal_state_limit=expand_goal[2],
                    **seeds
                )

                command = (
                        [script_fast_downward,
                         "--build={}".format(fast_downward_build),
                         "--plan-file",
                         os.path.join(dir_data, SAMPLE_FILE_PREFIX),
                         file_sas] +
                        [replace_seeds(x, seeds)
                         for x in network_predefinition] +
                        ["--search", search_configuration])

                # stderr is piped such that it can be filtered and printed
                # by output_filtered_stderr
                p = subprocess.Popen(command,
                                     stdout=fnull,
                                     stderr=subprocess.PIPE,
                                     universal_newlines=True)
                processes.append(p)
                # Publish the pid of the new sampling run
                send_queue.put(p.pid)
                next_sample_file_index += nb_new_files
        else:
            # Enough sample files available or no free process slot
            time.sleep(TIME_SLEEP_SAMPLING)
    fnull.close()
    print_sampling("Stop.")


""" ------------------------ Loading Methods --------------------------------"""


def load_samples_supervisor(queue_sample_files, dir_data):
    """Watch ``dir_data`` forever and announce every new sample file.

    Every sample file (name starts with ``SAMPLE_FILE_PREFIX``) is put once
    on ``queue_sample_files``. Files which disappeared from the directory are
    forgotten, thus, a file which reappears is announced again. If a scan
    found no new file, the supervisor sleeps for 5 seconds.

    :param queue_sample_files: queue receiving the paths of new sample files.
    :param dir_data: directory to watch.
    """
    already_queued = set()
    while True:
        present = set(get_sample_files(dir_data, SAMPLE_FILE_PREFIX))
        # Drop the files which are not present anymore
        already_queued &= present
        fresh = present - already_queued

        for path in fresh:
            queue_sample_files.put(path)
        already_queued |= fresh
        if not fresh:
            time.sleep(5)


def load_samples(recv_files, send_samples, sync_to_sampling,
                 err_dir, sync_err, file_sas=None,
                 has_unsolvable_samples=False,
                 label_transformation=None):
    """Worker loop which parses sample files and forwards their content.

    For every path received from ``recv_files`` the file is parsed as ';'
    separated float table (lines starting with '#' are comments). The first
    column is the label ``y``; if ``has_unsolvable_samples`` is set, the
    second column is split off as ``u``; the remaining columns are the
    features ``x``. The triple ``(x, y, u)`` is put on ``send_samples`` and
    the file is deleted.

    Lines of the file matching ``REGEX_SAMPLING_SEARCH_FEEDBACK`` announce new
    max scramble values. For every sampling technique index the largest
    announced value is sent to the sampling process as
    ``(SyncToken.InternalScramblingIncrement, idx, ("=", value))``.

    Files which cannot be parsed (``IndexError``/``ValueError``) are moved to
    ``err_dir``; only the 10 most recent of them are kept. If 10 failures
    happened within the last 50 processed files, the failures are reported on
    ``sync_err`` and the loop stops.

    :param recv_files: queue delivering the paths of the sample files.
    :param send_samples: queue receiving the ``(x, y, u)`` triples.
    :param sync_to_sampling: queue for the tokens to the sampling process.
    :param err_dir: directory to which unparsable files are moved.
    :param sync_err: queue receiving the tuple ``(pid, "load_samples",
                     last exception, description of the buffered failures)``.
    :param file_sas: if given, ``x`` is converted with
                     ``apt.SAS(file_sas).convert_strips_encodings_to_sas``.
    :param has_unsolvable_samples: whether the second column of the file
                                   shall be split off as ``u``.
    :param label_transformation: optional callable applied to ``y``.
    """
    sas = None if file_sas is None else apt.SAS(file_sas)
    running = True
    iteration = 0
    # Entries: (exception, formatted traceback, path of moved file, iteration)
    e_buffer = []
    e_buffer_length, e_buffer_lookback = 10, 50

    while running:
        sample_file = recv_files.get()  # Receive next file
        iteration += 1
        try:
            # Load data
            with open(sample_file, "r") as f:
                content = f.read()
            ary = np.loadtxt(
                StringIO(content), dtype=float, comments="#", delimiter=";")
            # A file with a single sample is parsed to a one dimensional array
            if len(ary.shape) == 1:
                ary = np.expand_dims(ary, 0)
            if len(ary) > 0:
                x, y, u = ary[:, 1:], ary[:, 0], None
                if has_unsolvable_samples:
                    x, u = x[:, 1:], x[:, 0]
                if sas is not None:
                    x = sas.convert_strips_encodings_to_sas(x)
                if label_transformation is not None:
                    y = label_transformation(y)
                send_samples.put((x, y, u))
            tpa.remove(sample_file, missing_ok=True)

            # Read comments if they announce an upgrade in the sampling
            # technique parameters
            sampling_technique_upgrades = [
                match for match in
                [REGEX_SAMPLING_SEARCH_FEEDBACK.match(line)
                 for line in content.splitlines()]
                if match is not None
            ]
            # Forward only the largest announced value per sampling technique
            max_sampling_technique_upgrade = defaultdict(int)
            for stu in sampling_technique_upgrades:
                idx_st, new_max_scrambles = [int(g) for g in stu.groups()]
                max_sampling_technique_upgrade[idx_st] = max(
                    new_max_scrambles, max_sampling_technique_upgrade[idx_st])
            for idx_st, new_max_scrambles in max_sampling_technique_upgrade.items():
                sync_to_sampling.put((SyncToken.InternalScramblingIncrement,
                                     idx_st, ("=", new_max_scrambles)))

        except (IndexError, ValueError) as e:
            # Keep the broken file for later inspection
            err_file = os.path.join(err_dir, os.path.basename(sample_file))
            tpa.move(sample_file, err_file, create_parents=True)
            e_buffer.append((e, traceback.format_exc(), err_file, iteration))
            if len(e_buffer) >= e_buffer_length:
                # Keep only the most recent failures (and their files)
                for _, _, old_err_file, _ in e_buffer[:-e_buffer_length]:
                    tpa.remove(old_err_file)
                e_buffer = e_buffer[-e_buffer_length:]
                # Too many failures in a short time: report and stop
                if e_buffer[0][3] >= iteration - e_buffer_lookback:
                    sync_err.put(
                        (os.getpid(), "load_samples", e_buffer[-1][0],
                         "\n---------------\n".join(
                             # Fixed: file path and traceback were swapped
                             "File: {}\nTraceback: {}".format(x[2], x[1])
                             for x in e_buffer)))
                    running = False


""" ------------------------ Search Methods ---------------------------------"""


def get_run_evaluation(script_fast_downward, fast_downward_build, start_time,
                       time_limit,
                       network_predefinition,
                       search_configuration, file_sas,
                       min_start_time_shift=0, log_file=None,
                       solve_only_once=True):
    """Create a function which starts an evaluation search in a new process.

    The returned function ``run_evaluation`` takes no arguments and returns
    nothing. A call does NOT start a search if
      1. the previously started search is still running, or
      2. ``solve_only_once`` is set and a previous search solved the task, or
      3. fewer than ``min_start_time_shift`` seconds have passed since the
         last start (resp. since the creation of ``run_evaluation``).
    The state is stored in attributes of the function:
    ``process`` (last started process or None), ``last_start_time``,
    ``outcome`` (``EvaluationOutputWrapper`` of the last started search or
    None) and ``counter`` (number of started searches).

    The search process sends ``(status, output)`` through a queue, where
    status is ``SyncToken.Unsolved``, ``SyncToken.Solved`` (solved and a
    ``log_file`` is given), ``SyncToken.SolvedAndTerminate`` (solved and no
    ``log_file`` is given) or ``SyncToken.Error`` (the search failed with a
    ``subprocess.CalledProcessError``).

    :param script_fast_downward: first entry of the search command.
    :param fast_downward_build: value of the ``--build`` argument.
    :param start_time: timestamp used to log the time since training begin.
    :param time_limit: value (seconds) of ``--overall-time-limit``.
    :param network_predefinition: iterable of command line arguments in which
                                  the seed placeholders are replaced.
    :param search_configuration: ``--search`` argument (seed placeholders are
                                 replaced).
    :param file_sas: passed as positional argument to the search.
    :param min_start_time_shift: minimum seconds between two search starts.
    :param log_file: optional file for the search output. With
                     ``solve_only_once`` it is a text file which is
                     overwritten, otherwise a zip archive to which an entry
                     ``run<counter>`` is appended.
    :param solve_only_once: do not start searches after the task was solved.
    :return: the function ``run_evaluation``.
    """
    def run_evaluation():
        # if 1. Another search is still running
        if ((run_evaluation.process is not None and
             run_evaluation.process.is_alive()) or
                # or 2. a previous search has solved the task
                (solve_only_once and run_evaluation.outcome is not None and
                 (run_evaluation.outcome.status == SyncToken.Solved or
                  run_evaluation.outcome.status ==
                  SyncToken.SolvedAndTerminate)) or
                # or 3. there has to be a larger break between two searches
                (run_evaluation.last_start_time is not None and
                 time.time() - run_evaluation.last_start_time <
                 min_start_time_shift)):
            return

        def run_process(cmd, data_out):
            # Executed in the child process: runs the search, logs its output
            # and reports the outcome via 'data_out'.
            try:
                str_start_time = time.strftime('%Y-%m-%d %H:%M:%S')
                out = tm.run_fast_downward(cmd)
                out = "Search start timestamp: {}\n{}\n" \
                      "Time since training begin: {}".format(
                    str_start_time, out, time.time() - start_time)

                if log_file is not None:
                    dir_parent = os.path.dirname(log_file)
                    if dir_parent != "":
                        tpa.make_dirs(dir_parent, exist_ok=True)

                    if solve_only_once:
                        with open(log_file, "w") as f:
                            f.write(out)
                    else:
                        with zipfile.ZipFile(log_file, "a") as f:
                            f.writestr("{}{}".format("run", run_evaluation.counter), out)

                if out.find("Solution found.") > -1:
                    # Without a log file a solution ends the training
                    status = (SyncToken.SolvedAndTerminate
                              if log_file is None else SyncToken.Solved)
                else:
                    status = SyncToken.Unsolved
                data_out.put((status, out))
            except subprocess.CalledProcessError as e:
                # CalledProcessError has no attribute 'message' in Python 3;
                # str(e) describes the command and its exit status.
                data_out.put((SyncToken.Error, str(e)))
        seeds = get_seed_dictionary(NB_SEEDS)
        cmd_eval = ([
            script_fast_downward,
            "--build={}".format(fast_downward_build),
            "--overall-time-limit", "%is" % time_limit,
            file_sas] + [replace_seeds(x, seeds)
                         for x in network_predefinition] +
            ["--search", replace_seeds(search_configuration, seeds)])

        q = Queue()
        run_evaluation.outcome = EvaluationOutputWrapper(q)
        run_evaluation.process = Process(
            target=run_process, args=(cmd_eval, q))
        run_evaluation.last_start_time = time.time()
        # Incremented before the start such that the child process sees the
        # number of its own run.
        run_evaluation.counter += 1
        run_evaluation.process.start()

    run_evaluation.process = None
    # The creation counts as a start for the minimum break between searches
    run_evaluation.last_start_time = time.time()
    run_evaluation.outcome = None
    run_evaluation.counter = 0
    return run_evaluation


""" ----------------------- Training Methods --------------------------------"""


def get_wrap_partial_assignment_postprocessor(file_sas, repetitions):
    """Create a data postprocessor which completes partial assignments.

    The returned function ``data_postprocessor(x_data, y_data,
    sample_weights)`` builds ``repetitions`` instantiations of ``x_data``.
    Each instantiation is obtained by ``sas.complete_sas_states`` followed by
    ``sas.convert_sas_encodings_to_strip``. The instantiations are stacked
    vertically; ``y_data`` and ``sample_weights`` (if not None) are repeated
    accordingly along their first axis.

    Fixed: labels and weights were repeated with ``np.tile(data, n)`` and
    reshaped afterwards. For arrays with more than one dimension ``np.tile``
    repeats along the LAST axis, thus, the reshaped rows were ordered
    ``y1, y1, ..., y2, y2, ...`` and did not match the stacked inputs
    (``x1, ..., xn, x1, ..., xn``). One dimensional arrays are unaffected.

    :param file_sas: path given to ``apt.SAS``.
    :param repetitions: number of instantiations per call.
    :return: the function ``data_postprocessor``.
    """
    sas = apt.SAS(file_sas)

    def repeat_along_first_axis(data, count):
        # Result order: whole 'data', whole 'data', ... ('count' times)
        return np.concatenate([data] * count, axis=0)

    def data_postprocessor(x_data, y_data, sample_weights):
        assert len(sas.domains) == x_data.shape[1]
        instantiations = []

        while len(instantiations) < repetitions:
            # NOTE(review): the copy suggests that complete_sas_states may
            # modify its input in place - confirm.
            x_inst = x_data if repetitions == 1 else np.copy(x_data)
            x_inst = sas.complete_sas_states(x_inst)
            x_inst = sas.convert_sas_encodings_to_strip(x_inst)
            instantiations.append(x_inst)

        y_data = repeat_along_first_axis(y_data, len(instantiations))
        if sample_weights is not None:
            sample_weights = repeat_along_first_axis(
                sample_weights, len(instantiations))
        return np.vstack(instantiations), y_data, sample_weights
    return data_postprocessor


# def get_increment_scrambling_callback(
#         sync_queue, keras_callback, consecutive_sat, increment):
#     n_buffer = 100
#     n_avg = 50
#     assert 2 * n_avg <= n_buffer
#     sufficiently_close = 0.2
#     not_forget_probability = 0.99999995
#
#     def _increment_scrambling(is_satisfied, *_args, **kwargs):
#         log = kwargs["logs"]
#         _keras_callback = kwargs["keras_callback"]
#
#         _keras_callback._flags_active = False
#
#         _increment_scrambling.last_loss.append(log["loss"])
#         if len(_increment_scrambling.last_loss) > n_buffer:
#             _increment_scrambling.last_loss = \
#                 _increment_scrambling.last_loss[-n_buffer:]
#
#         _increment_scrambling.predictions.append(
#             _increment_scrambling.last_max_prediction)
#         if len(_increment_scrambling.predictions) > n_buffer:
#             _increment_scrambling.predictions = \
#                 _increment_scrambling.predictions[-n_buffer:]
#
#         if (_increment_scrambling.last_max_prediction >
#                 _increment_scrambling.max_prediction or
#                 random.random() >= not_forget_probability **
#                 _increment_scrambling.max_prediction_age):
#             _increment_scrambling.max_prediction = \
#                 _increment_scrambling.last_max_prediction
#             _increment_scrambling.max_prediction_age = 0
#         else:
#             _increment_scrambling.max_prediction_age += 1
#
#         if len(_increment_scrambling.predictions) < n_buffer:
#             return
#
#         max_old = np.average(
#             _increment_scrambling.predictions[- 2 * n_avg:n_avg])
#         max_new = np.average(
#             _increment_scrambling.predictions[-n_avg:])
#         loss_old = np.average(
#             _increment_scrambling.last_loss[- 2 * n_avg: -n_avg])
#         loss_new = np.average(_increment_scrambling.last_loss[-n_avg:])
#         # Max is still rising
#         if max_new - sufficiently_close > max_old:
#             _increment_scrambling.consecutive_max_prediction = 0
#         elif max_new + sufficiently_close < max_old:
#             _increment_scrambling.consecutive_max_prediction = 0
#         else:
#             _increment_scrambling.consecutive_max_prediction += 1
#
#         if (max_new >= 1 and
#                 max_new >
#                 _increment_scrambling.max_prediction - sufficiently_close and
#                 _increment_scrambling.consecutive_max_prediction >= n_buffer and
#                 loss_old - loss_new < 0.1):
#             sync_queue.put((SyncToken.IncreaseScrambling, increment))
#             _keras_callback._flags_active = True
#             _increment_scrambling.consecutive_max_prediction = 0
#             _increment_scrambling.prev_scramble_max = (max_new +
#                                                        sufficiently_close)
#             _increment_scrambling.curr_scramble += increment
#
#     _increment_scrambling.predictions = []
#     _increment_scrambling.last_max_prediction = 0
#     _increment_scrambling.max_prediction = 0
#     _increment_scrambling.max_prediction_age = 0
#     _increment_scrambling.last_max_label = 0
#     _increment_scrambling.consecutive_max_prediction = 0
#
#     _increment_scrambling.last_loss = []
#     _increment_scrambling.prev_scramble_max = 1
#     _increment_scrambling.curr_scramble = increment
#
#     def callback_predictions(labels, predictions):
#         _increment_scrambling.last_max_prediction = np.max(predictions)
#         _increment_scrambling.last_max_label = np.max(labels)
#
#     keras_callback.callback = _increment_scrambling
#     return callback_predictions


def analyse_history(dir_out, history):
    """Plot the training history and store it as ``evolution.pdf``.

    Plotted are the series ``history["max_inputs"]``,
    ``history["max_predicted"]`` (left axis) and ``history["loss"]`` (right
    axis). Every occurrence of a flag in ``history["flags"]`` (dictionary
    flag name -> list of positions) is drawn as vertical line, except for
    the flag ``model_saved``.

    If the history has more than 2000 entries, the series are down-sampled
    to 2000 entries and the flag positions are rescaled. ATTENTION: this
    modifies ``history`` in place.

    Fixed: more than four plotted flag kinds caused an IndexError, because
    the color list was exhausted (now the colors wrap around), and the figure
    was never closed.

    :param dir_out: directory in which ``evolution.pdf`` is stored.
    :param history: dictionary with the keys ``max_inputs``,
                    ``max_predicted``, ``loss`` and ``flags``.
    """
    colors = ['r', 'g', 'b', 'm', 'c', 'k', 'y']
    max_samples = 2000
    nb_entries = len(history["max_inputs"])
    scale = 1
    if nb_entries > max_samples:
        scale = max_samples / float(nb_entries)
        for data_name in ["max_inputs", "max_predicted", "loss"]:
            d = history[data_name]
            history[data_name] = [d[int(nb_entries / max_samples * i)]
                                  for i in range(max_samples)]
        for flag_name, flag_occurrences in history["flags"].items():
            history["flags"][flag_name] = [x * max_samples / nb_entries
                                           for x in flag_occurrences]

    def color_of(index):
        # Reuse the colors if there are more series than colors
        return colors[index % len(colors)]

    idx_color = 0
    fig = plt.figure()
    ax = fig.add_subplot(111)
    tax = ax.twinx()
    ax.plot(history["max_inputs"], label="max sample values",
            color=color_of(idx_color))
    idx_color += 1
    ax.plot(history["max_predicted"], label="max predicted",
            color=color_of(idx_color))
    idx_color += 1
    tax.plot(history["loss"], label="loss", color=color_of(idx_color))
    idx_color += 1
    for flag_name, flag_occurrences in history["flags"].items():
        if flag_name in ["model_saved"]:
            continue
        for no, x in enumerate(flag_occurrences):
            # Only the first line of a flag gets a legend entry
            ax.axvline(x, linewidth=1, label=flag_name if no == 0 else None,
                       color=color_of(idx_color))
        idx_color += 1
    tax.legend()
    ax.legend()
    ax.set_ylabel("Max V value")
    tax.set_ylabel("loss")
    ax.set_xlabel("Epochs x {}".format(1/scale))

    fig.savefig(os.path.join(dir_out, "evolution.pdf"))
    # Release the resources of the figure
    plt.close(fig)


def get_send_increase_token(no, sync_to_sampling, increment):
    """Create a callback which requests an increase of the scrambling.

    The callback accepts ``(is_satisfied, *args, **kwargs)``. If
    ``is_satisfied`` is true, it puts the token
    ``(SyncToken.IncreaseScrambling, no, increment)`` on
    ``sync_to_sampling``; otherwise it does nothing.

    :param no: index of the sampling technique. It is stored in the
               attribute ``_no`` of the callback and read on every call.
    :param sync_to_sampling: queue to the sampling process.
    :param increment: third entry of the token.
    :return: the callback.
    """
    def _send_increase_token(is_satisfied, *_args, **_kwargs):
        if not is_satisfied:
            return
        token = (SyncToken.IncreaseScrambling,
                 _send_increase_token._no,
                 increment)
        sync_to_sampling.put(token)

    _send_increase_token._no = no
    return _send_increase_token

def train(sync_to_sampling, network, load_initial_model,
          dir_original, file_sas, dir_data, recv_samples, max_time,
          sampling_techniques, reinitialize=None, data_postprocessor=None,
          run_initial_evaluation=None,
          run_intermediate_evaluation=None,
          callback_cold_start=None,
          raise_child_error=None,
          buffer_size_factor=20,
          ):
    """Initialize the network, train it on the incoming samples and plot the
    training history.

    Steps:
      1. Initialize and store the network, then send
         ``SyncToken.ModelCreated`` to the sampling process.
      2. Set up the callbacks (cold start, increase scrambling,
         reinitialization, evaluation searches, error propagation, disk
         space logging).
      3. Start the initial evaluation (if any) and run
         ``network.train_reinforcement``.
      4. Send ``SyncToken.Terminate`` to the sampling process and plot the
         history via ``analyse_history``.

    Fixed: ``run_intermediate_evaluation`` is only registered as callback if
    it is not None (previously None was registered if only an initial
    evaluation was configured), and the disk usage message names the
    directory which is actually measured (``dir_data``).

    :param sync_to_sampling: queue to the sampling process.
    :param network: network object providing ``initialize``, ``store``,
                    ``reinitialize`` and ``train_reinforcement``.
    :param load_initial_model: None (do not load a model) or a sequence
                               whose first two entries are used as start
                               epoch and start time.
    :param dir_original: directory in which the history plot is stored.
    :param file_sas: path given to ``apt.SAS``; the sum of its domain sizes
                     is the state size of the network.
    :param dir_data: directory whose disk usage is logged.
    :param recv_samples: queue from which the training reads the samples.
    :param max_time: forwarded to ``network.train_reinforcement``.
    :param sampling_techniques: sequence of objects with the attribute
        ``increase_scrambling``, which is None or a pair
        (keras callback, increment).
    :param reinitialize: None or a duration in seconds. If the largest
        prediction of the last 10 training callbacks is below 1.0 after that
        duration, the network is reinitialized; otherwise the check is
        switched off.
    :param data_postprocessor: forwarded to ``network.train_reinforcement``.
    :param run_initial_evaluation: None or a function (see
        ``get_run_evaluation``) which is called once before the training.
    :param run_intermediate_evaluation: None or a function (see
        ``get_run_evaluation``) which is registered as 'always' callback.
    :param callback_cold_start: optional additional keras callback.
    :param raise_child_error: optional function registered as 'always'
                              callback.
    :param buffer_size_factor: forwarded as ``replay_buffer_size_factor``.
    """

    def print_training(msg):
        print("Training>", msg)

    # Create initial model (or load)
    state_size = sum(apt.SAS(file_sas).domains)
    network.initialize(None, state_size=state_size,
                       skip_loading=load_initial_model is None)
    network.store()
    print_training("Model initialized and saved. Sending token...")
    sync_to_sampling.put(SyncToken.ModelCreated)
    print_training("Token sent. Training...")

    # Setup additional callbacks
    callbacks = []  # normal keras callbacks
    callback_predictions = []  # Receives the predictions on the training data

    if callback_cold_start is not None:
        callbacks.append(callback_cold_start)

    for no, sampling_technique in enumerate(sampling_techniques):
        if sampling_technique.increase_scrambling is not None:
            sampling_technique.increase_scrambling[0].callback = \
                get_send_increase_token(
                    no, sync_to_sampling,
                    sampling_technique.increase_scrambling[1])
            callbacks.append(sampling_technique.increase_scrambling[0])
    if reinitialize is not None:
        def reinitialize_network(labels, predictions):
            if not reinitialize_network.active:
                return

            reinitialize_network.buffer.append(np.max(predictions))
            if len(reinitialize_network.buffer) >= 10:
                # Judge only on the 10 most recent maximal predictions
                reinitialize_network.buffer = reinitialize_network.buffer[-10:]
                if time.time() - reinitialize_network.time > reinitialize:
                    if np.max(reinitialize_network.buffer) < 1.0:
                        network.reinitialize(skip_loading=True)
                        network.store()
                        reinitialize_network.time = time.time()
                        reinitialize_network.buffer = []
                    else:
                        # The predictions grew; never reinitialize anymore
                        reinitialize_network.active = False
        reinitialize_network.time = time.time()
        reinitialize_network.buffer = []
        reinitialize_network.active = True
        callback_predictions.append(reinitialize_network)

    # functions that are executed after each epoch if the defined flag was
    # raised. 'always' is always raised.
    callback_flags = []
    if raise_child_error is not None:
        callback_flags.append(("always", raise_child_error))
    if (run_initial_evaluation is not None or
            run_intermediate_evaluation is not None):
        def stop_training():
            if any(r.outcome is not None and
                   r.outcome.status == SyncToken.SolvedAndTerminate
                   for r in [run_initial_evaluation,
                             run_intermediate_evaluation]
                   if r is not None):
                raise StopTraining()
        # Never register None as callback (only the initial evaluation might
        # be configured)
        if run_intermediate_evaluation is not None:
            callback_flags.append(("always", run_intermediate_evaluation))
        callback_flags.append(("always", stop_training))

    # Log disk space (on every 10001st call)
    def print_disk_space():
        print_disk_space.counter += 1
        if print_disk_space.counter > 10000:
            disk_usage = psutil.disk_usage(dir_data)
            # Index 3 is the used percentage and index 0 the total size
            print("Disk usage of %s: %.1f%% of %i GB" % (
                os.path.abspath(dir_data), disk_usage[3],
                disk_usage[0] / (1024 ** 3)))
            # Approximation: size of one (the first accessible) file times
            # the number of files
            disk_space_data = 0
            files_data = os.listdir(dir_data)
            for x in files_data:
                try:
                    disk_space_data = os.path.getsize(os.path.join(dir_data, x))
                    break
                except FileNotFoundError:
                    # The file was removed in the meantime, try the next one
                    pass
            print("Approximate disk space for samples: %.2f GB" % (disk_space_data*len(files_data)/1024**3))

            print_disk_space.counter = 0
    print_disk_space.counter = 0
    callback_flags.append(("always", print_disk_space))
    # Run initial evaluation in PARALLEL!
    if run_initial_evaluation is not None:
        run_initial_evaluation()

    history = network.train_reinforcement(
        data_queue=recv_samples,
        max_time=max_time,
        additional_callbacks=callbacks,
        callback_predictions=callback_predictions,
        callback_flags=callback_flags,
        data_postprocessor=data_postprocessor,
        replay_buffer_size_factor=buffer_size_factor,
        start_epoch=0 if load_initial_model is None else load_initial_model[0],
        start_time=0 if load_initial_model is None else load_initial_model[1]
    )
    print_training("Training ended.")
    sync_to_sampling.put(SyncToken.Terminate)
    analyse_history(dir_original, history)


def configure_options(options):
    """Expand the template strings stored on ``options`` (in place).

    The following attributes are rewritten:

    * ``fd_network_definition``: converted to a list whose first entry
      becomes a tuple ``(training definition, search definition)``.
    * ``v_value_evaluator`` and (if given) the evaluator entry of
      ``cold_start_evaluator``: expanded for the training phase.
    * ``add_final_evaluation`` (``None`` is replaced by
      ``TEMPLATE_SEARCH_CONFIGURATION``), ``add_initial_evaluation`` and
      ``add_intermediate_evaluations``: unless they are ``False``, their
      ``{evaluator}`` placeholder is filled with the search phase expansion
      of the value evaluator.

    The search phase evaluator is expanded from the *raw* value evaluator
    template. Expanding the already expanded training string a second time
    would silently keep the training phase settings and can fail on literal
    braces produced by the first expansion.

    :param options: parsed command line options (modified in place)
    :return: the same ``options`` object
    """
    learner = options.learner
    # Seed placeholders are mapped onto themselves such that they survive
    # the formatting and can be substituted later.
    format_args = {"seed%i" % i: "{seed%i}" % i for i in range(NB_SEEDS)}
    format_args.update(dict(options.transform_label[1:]))
    format_args.setdefault("exponentiate_heuristic", "false")

    # Remark: The network definition can also be in the evaluator body
    def format_network(training_phase, template_string,
                       network_definition=None):
        """Fill the placeholders of ``template_string``.

        ``training_phase`` only influences ``domain_max_is_undefined``.
        ``network_definition`` is required if the template contains
        ``NETWORK_PRE_DEFINITION``.
        """
        assert ("NETWORK_PRE_DEFINITION" not in template_string or
                network_definition is not None)
        return template_string.format(
            OUTPUT_TYPE=("regression" if learner.output_units == -1
                         else "classification"),
            UNARY_THRESHOLD=(learner.ordinal_classification_threshold
                             if learner.ordinal_classification else 0),
            MODEL_FILE=learner.path_store,
            TASK_TRANSFORMATION="sampling_transform()",
            BIN_SIZE=learner.bin_size,
            MODEL_OUTPUT_LAYER=("{MODEL_OUTPUT_LAYER,%s.pb}" %
                                learner.path_store),
            domain_max_is_undefined=(
                apt.to_cpp_bool((options.wrap_partial_assignment is not None)
                                if training_phase else False)),
            NETWORK_PRE_DEFINITION=network_definition,
            **format_args
        )

    # Format possible network definition
    options.fd_network_definition = list(options.fd_network_definition)
    fd_nn = options.fd_network_definition
    fd_nn[0] = (format_network(True, fd_nn[0]),
                format_network(False, fd_nn[0]))

    # With more than one entry the last entry is used as the network
    # reference; otherwise the formatted definition itself is used.
    has_network_key = len(fd_nn) > 1
    network_training_key = fd_nn[-1] if has_network_key else fd_nn[0][0]
    network_search_key = fd_nn[-1] if has_network_key else fd_nn[0][1]

    # Remember the raw template before it is overwritten below: the search
    # phase has to be expanded from the template, not from the training
    # phase result.
    v_value_template = options.v_value_evaluator

    if isinstance(learner, KerasMLP):
        options.v_value_evaluator = format_network(
            True, v_value_template, network_training_key)
        if options.cold_start_evaluator is not None:
            options.cold_start_evaluator[IDX_COLD_START_EVALUATOR] = (
                format_network(
                    True,
                    options.cold_start_evaluator[IDX_COLD_START_EVALUATOR],
                    network_training_key)
            )
    else:
        assert False, "Unknown learner type."

    search_evaluator = format_network(
        False, v_value_template, network_search_key)

    if options.add_final_evaluation is None:
        options.add_final_evaluation = TEMPLATE_SEARCH_CONFIGURATION
    if options.add_final_evaluation is not False:
        options.add_final_evaluation = options.add_final_evaluation.format(
            evaluator=search_evaluator)

    for arg in ["add_initial_evaluation", "add_intermediate_evaluations"]:
        template = getattr(options, arg)
        if template is not False:
            # Only the first entry is a template, the rest is kept as is.
            setattr(options, arg,
                    [template[0].format(evaluator=search_evaluator)] +
                    list(template[1:]))

    return options


def run(options):
    """Run one complete experiment described by the parsed ``options``.

    Steps, in the order they are executed below:

    1. Expand the option templates (``configure_options``) and prepare the
       working directory (``setup``).
    2. Start the sampling process, the sample loading supervisor process
       and one sample loading process.
    3. Build the optional initial/intermediate evaluation runners and the
       optional cold start callback.
    4. Register signal handlers which trigger the shutdown, then train.
    5. Shut everything down (also if the training raised) and, unless an
       initial/intermediate evaluation already solved the task, optionally
       run a final evaluation with Fast Downward.

    :param options: options as returned by ``parse``; they are modified in
                    place by ``configure_options``.
    """
    run_start_time = time.time()
    configure_options(options)
    dir_data = os.path.join(options.working_directory, "data")
    file_sas = os.path.join(options.working_directory, "output.sas")
    file_used_facts = os.path.abspath("used_atoms.json")
    script_fast_downward = os.path.join(
        options.fast_downward_directory, "fast-downward.py")

    # Queues used to communicate with the child processes.
    sync_to_sampling = Queue()
    sync_sampling_process_ids = Queue()
    sync_err_to_main = Queue()  # (pid, name, exception,  traceback string)

    def raise_child_error():
        """Re-raise in this process an exception reported by a child.

        Does nothing if no error is queued. Otherwise the first queued
        entry is printed to stderr and its exception object is raised.
        """
        if not sync_err_to_main.empty():
            pid, state, e, tb_s = sync_err_to_main.get()
            print("Exception in Child Process: {state} ({pid})".format(
                **locals()), file=sys.stderr)
            print(tb_s, file=sys.stderr)
            sys.stderr.flush()
            raise e

    # Queues created with a size argument (100 resp.
    # options.max_cached_batches).
    pass_sample_files = Queue(100)
    pass_samples = Queue(options.max_cached_batches)

    # Remember the current directory before setup(...) is called; it is
    # passed to train(...) later on.
    original_directory = os.getcwd()
    setup(options.working_directory, script_fast_downward, options.domain_pddl,
          options.task_pddl, file_sas, file_used_facts,
          dir_data, options.fast_downward_build,
          options.translator_options, options.preprocessor,
          options.add_initial_evaluation, options.add_intermediate_evaluations)

    # If the network definition has a separate key (second list entry), it
    # is predefined on the command line via "--network KEY=DEFINITION".
    # Index [0][0] is the training variant, [0][1] the search variant of the
    # definition (see configure_options).
    network_predefinition_training = (
        [] if len(options.fd_network_definition) == 1
        else ["--network", "%s=%s" % (options.fd_network_definition[1],
                                      options.fd_network_definition[0][0])])
    network_predefinition_search = ([] if len(options.fd_network_definition) == 1
         else ["--network", "%s=%s" % (options.fd_network_definition[1],
                                       options.fd_network_definition[0][1])])
    # Sampling Processes
    proc_sampling = Process(target=sample, args=(
        sync_to_sampling, sync_sampling_process_ids, sync_err_to_main,
        script_fast_downward, options.fast_downward_build,
        file_sas, dir_data,
        not options.random_seeds,
        network_predefinition_training,
        options.v_value_evaluator,
        options.cold_start_evaluator,
        options.sampling_engine, options.sampling_technique,
        options.expand_goal,
        options.wrap_partial_assignment is not None,
        options.lookahead, options.samples_per_data_file,
        options.max_data_files,
        options.max_generator_processes))
    proc_sampling.start()

    # Sampling Loading Supervisor Process
    proc_loading_supervisor = Process(target=load_samples_supervisor, args=(
        pass_sample_files, dir_data))
    proc_loading_supervisor.start()

    # Sampling Load Processes
    procs_loading = []
    # Currently exactly one loading process is started.
    for _ in range(1):
        procs_loading.append(Process(target=load_samples, args=(
            pass_sample_files, pass_samples, sync_to_sampling,
            os.path.join(os.path.dirname(dir_data), "err_dir"),
            sync_err_to_main,
            None if options.wrap_partial_assignment is None else file_sas,
            options.samples_have_timeout_information,
            options.transform_label[0]
        )))
        procs_loading[-1].start()

    # Optional intermediate evaluation. The option is a list of three
    # entries with an optional fourth entry, the log file.
    run_intermediate_evaluation = None
    if options.add_intermediate_evaluations is not False:
        aie = options.add_intermediate_evaluations
        log_file = None if len(aie) == 3 else aie[3]
        run_intermediate_evaluation = get_run_evaluation(
            script_fast_downward, options.fast_downward_build, run_start_time,
            aie[2], network_predefinition_search, aie[0],
            file_sas, aie[1], log_file=log_file,
            solve_only_once=log_file is None
        )

    # Optional initial evaluation. The option is a list of two entries with
    # an optional third entry, the log file.
    # NOTE(review): in contrast to the intermediate evaluation, an empty
    # network predefinition ([]) is passed here -- confirm this is intended.
    run_initial_evaluation = None
    if options.add_initial_evaluation is not False:
        aie = options.add_initial_evaluation
        run_initial_evaluation = get_run_evaluation(
            script_fast_downward, options.fast_downward_build, run_start_time,
            aie[1], [], aie[0],
            file_sas, log_file=None if len(aie) == 2 else aie[2],
            solve_only_once=True
        )

    callback_cold_start = None
    if options.cold_start_evaluator is not None:
        def func_callback_cold_start(is_satisfied, **kwargs):
            """Request a cold start weight decrease from the sampling process.

            The first call with a satisfied condition is ignored; every
            later satisfied call puts a DecreaseColdStartWeight token with
            the configured step/init/min values into sync_to_sampling.
            """
            if is_satisfied:
                if func_callback_cold_start.first_time:
                    func_callback_cold_start.first_time = False
                    return

                sync_to_sampling.put(
                    (SyncToken.DecreaseColdStartWeight,
                     options.cold_start_evaluator[IDX_COLD_START_STEP],
                     options.cold_start_evaluator[IDX_COLD_START_INIT],
                     options.cold_start_evaluator[IDX_COLD_START_MIN]))
        func_callback_cold_start.first_time = True
        # The callback object is part of the option; the function above is
        # attached to it.
        callback_cold_start = options.cold_start_evaluator[
            IDX_COLD_START_CALLBACK]
        callback_cold_start.callback = func_callback_cold_start
    def exec_shutdown():
        """Call shutdown(...) at most once.

        The ``executed`` flag is checked and set while holding the lock,
        hence later calls (e.g. from a signal handler and from the regular
        control flow) return without calling shutdown(...) again.
        """
        with exec_shutdown.lock:
            if exec_shutdown.executed:
                return
            exec_shutdown.executed = True
        shutdown(sync_to_sampling, sync_sampling_process_ids,
                 proc_sampling, proc_loading_supervisor, procs_loading,
                 (None if run_intermediate_evaluation is None else
                  run_intermediate_evaluation.process),
                 (None if run_initial_evaluation is None else
                  run_initial_evaluation.process),
                 options.working_directory, dir_data)
        exec_shutdown.executed = True
    exec_shutdown.executed = False
    exec_shutdown.lock = RLock()

    def shutdown_on_signal(_sig, _frame):
        """Signal handler: trigger the shutdown, ignore the arguments."""
        exec_shutdown()
    signal.signal(signal.SIGINT, shutdown_on_signal)
    signal.signal(signal.SIGTERM, shutdown_on_signal)
    signal.signal(signal.SIGXCPU, shutdown_on_signal)

    # Start Training
    try:
        data_postprocessor = None
        if options.wrap_partial_assignment is not None:
            data_postprocessor = get_wrap_partial_assignment_postprocessor(
                file_sas, options.wrap_partial_assignment)

        train(sync_to_sampling, options.learner, options.load_initial_model,
              original_directory, file_sas, dir_data, pass_samples,
              options.maximum_training_time, options.sampling_technique,
              reinitialize=options.reinitialize_after_time,
              data_postprocessor=data_postprocessor,
              run_initial_evaluation=run_initial_evaluation,
              run_intermediate_evaluation=run_intermediate_evaluation,
              callback_cold_start=callback_cold_start,
              raise_child_error=raise_child_error,
              buffer_size_factor=options.buffer_size_factor)
    except:
        # Whatever went wrong (incl. KeyboardInterrupt/SystemExit): shut
        # down first, then propagate the original exception.
        exec_shutdown()
        raise

    exec_shutdown()

    # Check whether the initial or the intermediate evaluation has already
    # solved the task; the first one found is reported.
    is_solved = False
    for name, method in [("initial", run_initial_evaluation),
                         ("intermediate", run_intermediate_evaluation)]:
        if (method is not None and method.outcome is not None and
                (method.outcome.status == SyncToken.SolvedAndTerminate)):
            print("Solved by {name} evaluation".format(**locals()))
            print(method.outcome.message)
            is_solved = True
            break

    # Final evaluation: only if requested and not solved before.
    if not is_solved and options.add_final_evaluation is not False:
        seeds = get_seed_dictionary(NB_SEEDS)
        command = ([
            script_fast_downward,
            "--build={}".format(options.fast_downward_build),
            "--overall-time-limit", "30m",
            file_sas] +
            [replace_seeds(x, seeds) for x in network_predefinition_search] +
            ["--search", replace_seeds(options.add_final_evaluation, seeds)])
        print("FINAL EVAL COMMAND", command)
        out = tm.run_fast_downward(command)
        out = "{}\nTime since training begin: {}".format(
            out, time.time() - run_start_time)
        # The search is considered successful if its output contains the
        # text "Solution found.".
        if out.find("Solution found.") > -1:
            print("Solved by final evaluation")
        print(out)

    print("Total experiment time: {0:.1f}s.".format(
        time.time() - run_start_time))
    print("Intermediate searches started: {}".format(
        0 if run_intermediate_evaluation is None else
        run_intermediate_evaluation.counter))
    print("RL Training and Evaluation finished.")

def update_callback(cb, init_epoch, init_time):
    """Shift the bookkeeping of a condition callback and re-check it.

    Only instances of ``BaseKerasConditionExecutor`` are modified; any other
    object is left untouched. For such an instance the time attributes
    (``_initial_time``, ``_last_execute_time``) are reduced by ``init_time``,
    the epoch attributes (``_iter``, ``_last_execute_epoch``) are increased
    by ``init_epoch`` and finally ``check_condition()`` is called.

    :param cb: callback object to adapt
    :param init_epoch: value added to the epoch counters
    :param init_time: value subtracted from the time stamps
    """
    if not isinstance(cb, BaseKerasConditionExecutor):
        return

    cb._initial_time -= init_time
    cb._iter += init_epoch
    cb._last_execute_time -= init_time
    cb._last_execute_epoch += init_epoch
    cb.check_condition()


def parse(argv):
    """Parse and post-process the command line arguments.

    Besides calling the argument parser this function

    * normalizes ``load_initial_model``: a missing value becomes ``(0, 0)``
      and the sentinel ``-1 -1`` becomes ``None``,
    * parses every sampling technique (adding a default one if none was
      given) and converts its ``upgrade`` option to
      ``[max_upgrades, {"upgrade_<key>": value, ...}]``,
    * shifts the state of the constructed callbacks if an initial model
      with a positive epoch/time is loaded,
    * sets the module level ``NB_SEEDS``,
    * fills up the optional entries of ``cold_start_evaluator``,
    * splits the translator options with ``shlex`` and
    * enables ``samples_have_timeout_information`` if the sampling engine
      adds unsolved samples.

    The process is terminated with exit code 1 if
    ``--wrap-partial-assignment`` is combined with a ``sampling_search``
    engine.

    :param argv: list of command line arguments (without program name)
    :return: the processed options namespace
    """
    global NB_SEEDS
    options = parser.parse_args(argv)

    if options.load_initial_model is None:
        options.load_initial_model = (0, 0)
    assert len(options.load_initial_model) == 2, "invalid --load-initial-model"
    # Compare as tuple: the parser may deliver a list, and a list never
    # compares equal to a tuple.
    if tuple(options.load_initial_model) == (-1, -1):
        options.load_initial_model = None
    assert options.load_initial_model is None or all(x >= 0 for x in options.load_initial_model), options.load_initial_model
    assert options.load_initial_model is None or options.reinitialize_after_time is None

    # Parse sampling techniques (and add default if missing)
    if len(options.sampling_technique) == 0:
        options.sampling_technique.append(
            [apt.has_seeds(TEMPLATE_GBACKWARD_NONE_TECHNIQUE)])
    for no, sampling_technique in enumerate(options.sampling_technique):
        technique = parser_samptech.parse_args(sampling_technique)
        options.sampling_technique[no] = technique
        if technique.upgrade is not None:
            max_upgrades = int(technique.upgrade[0])
            # Split only at the first '=' such that values may contain
            # further '=' characters.
            upgrade_parameters = {
                "upgrade_%s" % k: v
                for k, v in (x.split("=", 1) for x in technique.upgrade[1:])}
            technique.upgrade = [max_upgrades, upgrade_parameters]

    if options.load_initial_model is not None and any(x > 0 for x in options.load_initial_model):
        # Training is continued: adapt the epoch/time state of all
        # callbacks to the loaded model.
        for cb in apt.callback.constructed_callbacks:
            update_callback(cb, *options.load_initial_model)
        assert isinstance(options.learner, KerasNetwork)
        for cb in options.learner.training_params.callbacks:
            update_callback(cb, *options.load_initial_model)

    # Has to be read after all templates passed through apt.has_seeds
    # (including the default sampling technique above).
    NB_SEEDS = apt.has_seeds.count
    apt.check_all_arg_counts_arguments()
    if (options.wrap_partial_assignment is not None and
            "sampling_search(" in options.sampling_engine):
        print("sampling_search engine does not support the option "
              "'--wrap-partial-assignment'", file=sys.stderr)
        sys.exit(1)

    if options.cold_start_evaluator is not None:
        options.cold_start_evaluator = list(options.cold_start_evaluator)
        # Append the defaults for the optional trailing entries.
        for n, v in [(IDX_COLD_START_INIT, 1.0), (IDX_COLD_START_MIN, 0.0)]:
            if len(options.cold_start_evaluator) <= n:
                options.cold_start_evaluator.append(v)
        assert len(options.cold_start_evaluator) == 5

    options.translator_options = [y
                                  for x in options.translator_options
                                  for y in shlex.split(x)]
    if "add_unsolved_samples=true" in options.sampling_engine:
        options.samples_have_timeout_information = True
    return options


# Script entry point: parse the command line arguments (without the program
# name) and run the experiment with the resulting options.
if __name__ == "__main__":
    run(parse(sys.argv[1:]))
