#! /usr/bin/env python
# -*- coding: utf-8 -*-
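"""
Example invocations of this training script (the quoted first argument is the
network definition, the remaining tokens are training options):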
"keras_dep_mlp(tparams=ktparams(epochs=2,loss=mean_squared_error,batch=100,balance=False,optimizer=adam,callbacks=[keras_model_checkpoint(val_loss,./checkpoint.ckp),keras_progress_checking(acc,100,0.001,False),keras_early_stopping(val_loss,0.01,100),keras_restart(-1,stop_successful=True),keras_stoptimer(max_time=86400,per_training=False,prevent_reinit=True,timeout_as_failure=True)]),hidden=3,dependency=pre_post,dense_layers=3,output_units=-1,dropout=0,x_fields=[current_state,goals],y_fields=[hplan],formats=[hdf5,protobuf],graphdef=graphdef.txt,count_samples=True)" --prefix tmp_ -d ../DeePDown/data/FixedWorlds/opt/transport_var_roads/c10_t2_p2/ --input "gzip(suffix=.uniform.data.gz)" -o -n model --fields goals hplan current_state --skip --skip-if-running --skip-if-flag --skip-if-trained --maximum-data-memory 0.05GB -dp --format NonStatic_A_01
"keras_adp_mlp(tparams=ktparams(epochs=10,loss=mean_squared_error,batch=100,balance=False,optimizer=adam,callbacks=[keras_model_checkpoint(val_loss,./checkpoint.ckp),keras_progress_checking(val_loss,100,2,False,True),keras_early_stopping(val_loss,0.01,100),keras_restart(-1,stop_successful=True),keras_stoptimer(max_time=86400,per_training=False,prevent_reinit=True,timeout_as_failure=True)]),hidden=3,output_units=-2, ordinal_classification=true,bin_size=2,dropout=0,x_fields=[current_state,goals],y_fields=[hplan],formats=[hdf5,protobuf],graphdef=graphdef.txt,count_samples=True)" --prefix tmp_ -d ../DeePDown/data/FixedWorlds/opt/transport_var_roads/c10_t2_p2/ --input "gzip(suffix=.uniform.data.gz)" -o -n model --fields goals hplan current_state --skip --skip-if-running --maximum-data-memory 0.1GB -dp --format NonStatic_A_01
"keras_adp_mlp(tparams=ktparams(epochs=10,loss=mean_squared_error,batch=100,balance=False,optimizer=adam,callbacks=[keras_model_checkpoint(val_loss,./checkpoint.ckp),keras_progress_checking(val_loss,100,2,False,True),keras_early_stopping(val_loss,0.01,100),keras_stoptimer(max_time=86400,per_training=False,prevent_reinit=True,timeout_as_failure=True)]),hidden=2,residual_layers=[keras_residual_block(hidden_layer_count=2),keras_residual_block(hidden_layer_count=2)],output_units=-2,batch_normalization=1, ordinal_classification=true,bin_size=1,dropout=0,x_fields=[current_state,goals],y_fields=[hplan],learner_formats=[hdf5,protobuf],graphdef=graphdef.txt,count_samples=True)" --prefix tmp_3_fold_ -d ../DeePDown/data/FixedWorlds/opt/depot_fix_goals/depot_p05/ --input "gzip(suffix=.generator.plan.sat.data.gz)" -o -n model --fields goals hplan current_state --skip --skip-if-running --maximum-data-memory 0.1GB -dp --format NonStatic_A_01 --samples-total-training 1000

keras_mlp_encoder(encoder_hidden_layer_size=[-0.9], latent_space_size=-0.8, decoder_hidden_layer_size=[-0.9])
"keras_adp_mlp(tparams=ktparams(epochs=10,loss=mean_squared_error,batch=100,balance=True,optimizer=adam,callbacks=[keras_model_checkpoint(val_loss,./checkpoint.ckp),keras_progress_checking(val_loss,100,2,False,True),keras_early_stopping(val_loss,0.01,100),keras_restart(-1,stop_successful=True),keras_stoptimer(max_time=86400,per_training=False,prevent_reinit=True,timeout_as_failure=True)]),hidden=3,output_units=-2, ordinal_classification=true,bin_size=1,dropout=0,x_fields=[current_state,goals],y_fields=[hplan],formats=[hdf5,protobuf],graphdef=graphdef.txt,count_samples=True)" --prefix tmp_3_fold_ -d ../DeePDown/data/FixedWorlds/opt/depot_fix_goals/depot_p05/ --input "gzip(suffix=.generator.plan.sat.data.gz)" -o -n model --fields goals hplan current_state --skip --skip-if-running --maximum-data-memory 0.1GB -dp --format NonStatic_A_01 --samples-total-training 1000
"""
from __future__ import print_function

import disable_external_dependencies
SUPPRESS_LIBRARY_WARNINGS = False
stderr = disable_external_dependencies.suppress_library_warnings(
    SUPPRESS_LIBRARY_WARNINGS)

import tools
from tools import constants as tc
from tools import misc as tm
from tools import parsing as apt

from src.training.bridges import StateFormat, LoadSampleBridge
from src.training.bridges.sampling_bridges import MetaFields
from src.training.misc import DomainProperties
from src.training.misc import StreamContext
from src.training.learners import LearnerFormat
from src.training.samplers import DirectorySampler

import argparse
import collections
import datetime
from enum import Enum
import json
import matplotlib as mpl
mpl.use('agg')
import matplotlib.pyplot as plt
import numpy as np
import os
import psutil
import random
import re
import shlex
import shutil
import sys
if sys.version_info >= (3,):
    import subprocess

    def decoder(s):
        """Python 3 variant: return ``s`` unchanged."""
        return s
else:
    # subprocess32 is the Python 2 backport of the Python 3 subprocess module
    import subprocess32 as subprocess

    def decoder(s):
        """Python 2 variant: return ``s.decode()``."""
        return s.decode()
import time

disable_external_dependencies.unsuppress_library_warnings(stderr)


"""-------------------- Constants -------------------------------------------"""

# Captures the job id from sbatch output such as "Submitted batch job 123".
# NOTE(review): (\d*) also matches an empty id; consumers presumably expect
# digits — confirm.
REGEX_SUBMITTED_BATCH_JOB = re.compile(r"Submitted batch job (\d*)")
# Captures the fold index from an output prefix such as "tmp_3_fold_"
# (used by load_data on options.prefix).
REGEX_FOLD = re.compile(r".*_(\d+)_fold_.*")

# Token separating the argument blocks of consecutive trainings.
ARG_NEW_TRAINING = "--new-training"
# NOTE(review): not used in the visible part of this file; presumably a key
# belonging to the '--slurm-summarize' step — confirm.
SLURM_SUMMARIZE_KEY = "slurm_summarize"


class PruningTypes(Enum):
    """Pruning strategies selectable via '--pruning'; the values are the
    spellings accepted on the command line."""
    # no pruning at all, the data sets are loaded as is
    Off = "off"
    # prune entries between the data sets (priority: test, validation, train)
    Inter = "inter"
    # prune entries within a data set
    Intra = "intra"
    # first prune within each data set, then between the data sets
    IntraInter = "intra_inter"


class SampleRequirementException(Exception):
    """Exception type for unmet sample requirements.

    NOTE(review): not raised within the visible part of this file; presumably
    signals that a loaded data set violates the configured sample limits
    (e.g. '--minimum-samples-training') — confirm at the raise sites.
    """
    pass


"""------------------------- Parsing Stuff ------------------------------"""


def _init_join_samples_types():
    """
    Describe every known sample type for the '--sample-type' help text.

    :return: the entries "name:<TAB>description" of all types returned by
        tc.SampleTypes.types(), joined by newline + tab
    """
    descriptions = []
    for sample_type in tc.SampleTypes.types():
        descriptions.append(
            "%s:\t%s" % (sample_type.name, sample_type.description))
    return "\n\t".join(descriptions)


def _init_get_pruning_types_list():
    """Return all PruningTypes members in definition order (used as the
    '--pruning' choices)."""
    return list(PruningTypes)


# Shared value check for '--test-split' and '--validation-split' (both use it
# in their restricted_type). The three callables presumably are: initial
# buffer value (0.0), accumulation of each parsed fraction, and a final check
# that fails via apt.raise_value_error with "Too much to split off" once the
# accumulated fraction exceeds 1.0 — confirm against tools.parsing.
# NOTE(review): this checker is module-level; if parse_training_args runs
# once per '--new-training' block, verify that apt.check_buffer resets its
# state between parses, otherwise fractions of earlier blocks accumulate.
restrict_total_split_data = apt.check_buffer(
    lambda: 0.0,
    lambda c, a, _: c + a,
    lambda x: apt.raise_value_error(x > 1.0, "Too much to split off")
)


ptrain = argparse.ArgumentParser(
    description="Train network on previously sampled data. If no test data "
                "is given, then the validation data is used for the final "
                "evaluation. If no validation data is given either, the "
                "performance is evaluated on the training data. You can "
                "define additional trainings by adding %s followed by the "
                "options for the next training." % ARG_NEW_TRAINING)
# '-e/--execute' vs. '--slurm'
ptrain_mutex_execute = ptrain.add_mutually_exclusive_group()
# '-dp/--domain-properties' vs. '-dpns/--domain-properties-no-statics'
ptrain_mutex_dp = ptrain.add_mutually_exclusive_group()
# '--cross-validation' vs. '--repetitions'
ptrain_mutex_multiple_iterations = ptrain.add_mutually_exclusive_group()

ptrain.add_argument("network", type=apt.learner,
                    help="Definition of the network.")
ptrain.add_argument("-a", "--args", type=str, default=None,
                    help="Single string describing a set of arguments to add "
                         "in front of all arguments if calling another script "
                         "for training execution (see '--execute').")
ptrain_mutex_multiple_iterations.add_argument(
    "--cross-validation", default=None,
    type=apt.int_positive,
    help="Works only together with '--execute' (todo make this "
         "work without '--execute'). Adds '--array=0-(N-1)' as "
         "first argument into '--args'. Sorts the problems "
         "and splits them into N folds of (close to equal) size. "
         "Provides after the arguments of '--args' N arguments "
         "representing regular expressions identifying the "
         "problems for each of the N folds.")
ptrain.add_argument("-d", "--directory", type=apt.absdir,
                    nargs="+", action="append", default=[],
                    help="Path to a list of directories from which to load the"
                         " training data. This argument can be given multiple"
                         " times. The execution of this script then equals "
                         "calling this script with the same arguments for "
                         "each directory group (if not "
                         "--sub-directory-training, then the domain file is "
                         "required in the first given directory).")
ptrain.add_argument("-df", "--directory-filter", type=re.compile,
                    action="append", default=[],
                    help="A subdirectory name has to match the regex otherwise "
                         "it is not traversed. By default no regex matches are "
                         "required. This argument can be given any number of "
                         "times to add additional filters (the directory name "
                         "has to match ALL regexes)")
ptrain_mutex_dp.add_argument("-dp", "--domain-properties", action="store_true",
                             help="If set and the network supports it, then "
                                  "the network is provided with an analysis "
                                  "of properties of the problem's domain")
ptrain_mutex_dp.add_argument("-dpns", "--domain-properties-no-statics",
                             action="store_true",
                             help="If set and the network supports it, then "
                                  "the network is provided with an analysis "
                                  "of properties of the problem's domain. "
                                  "This domain properties object does not "
                                  "analyse the static groundings (and "
                                  "everything depending on them).")
ptrain.add_argument("--dry",
                    action="store_true",
                    help="Tells only which trainings it would perform, but "
                         "does not perform the training step.")
ptrain_mutex_execute.add_argument(
    "-e", "--execute", type=apt.isfile, default=None,
    help="Path to script to execute for the training runs. If none is given, "
         "then this script is used, otherwise, it calls an external script in "
         "a subprocess and passes its parameters")
ptrain.add_argument("--fields", type=str, nargs="+", default=[],
                    help=("List all fields of the data which shall be "
                          "loaded in the order they shall appear (if the "
                          "order is relevant for you)"))
ptrain.add_argument("-fin", "--finalize", nargs="+", default=[],
                    type=apt.split_type("="),
                    help="List some key=value pairs which are passed "
                         "as key=value to the network's finalize method.")
ptrain.add_argument("--forget", type=apt.float_interval(0.0, 1.0),
                    default=0.0,
                    help=("Probability of skipping to load entries of the "
                          "validation data"))
ptrain.add_argument("-f", "--format", choices=StateFormat.get_formats(),
                    default=None, type=StateFormat.get,
                    help=("State format name into which the loaded data shall "
                          "be converted (if not given, the preferred format "
                          "of the network is chosen)"))
ptrain.add_argument("-init", "--initialize", nargs="+", default=[],
                    type=apt.split_type("="),
                    help="List some key=value pairs which are passed "
                         "as key=value to the network's initialize method.")
ptrain.add_argument("-i", "--input", type=apt.stream_definition,
                    action="append", default=[], required=True,
                    help="Define an input stream for the loading of samples "
                         "(use this option multiple times for multiple). The "
                         "available streams can be checked in "
                         "training.misc.stream_contexts.py "
                         "(the way this is done is for every problem file of "
                         "which data shall be loaded the stream is asked, "
                         "where would you store data for this file and then "
                         "the data at the location is loaded).")
ptrain.add_argument("-l", "--load", type=str,
                    default=None,
                    help="Overrides the network load location defined in the "
                         "network definition by "
                         "'{network.path_out}/{--load}'")
ptrain.add_argument("--max-depth", default=None,
                    type=apt.restricted_type(
                        apt.named_type(apt.int_zero_positive,
                                       "max_depth"),
                        apt.check_min_max_restriction("min_depth", "max_depth")),
                    help="Maximum depth from the root which is traversed ("
                         "default has no maximum, 0 means traversing no "
                         "sub-folders, only the content of the root)")
ptrain.add_argument("--min-depth", default=None,
                    type=apt.restricted_type(
                        apt.named_type(apt.int_zero_positive,
                                       "min_depth"),
                        apt.check_min_max_restriction("min_depth", "max_depth")),
                    help="Minimum depth from the root which has to be traversed"
                         " before problem files are registered (default has "
                         "no minimum)")
ptrain.add_argument("--maximum-data-memory",
                    type=apt.memory,
                    default=None,
                    help="Maximum memory to use for the data. Once this limit "
                         "is exhausted, no more data is loaded. Memory limit "
                         "is defined in KB unless defined otherwise via "
                         "suffixes: KB, MB, GB")
ptrain.add_argument("--merge", type=re.compile, default=None,
                    help="Regex describing all stored evaluations which shall "
                         "be combined. Most other options become useless, as "
                         "no training is performed afterwards. Network has to "
                         "be defined")
# Appended (via "%s") to the help of '--minimum-samples-training' and the
# '--samples-total-*' options to describe the accepted value syntax.
# NOTE(review): those options declare type=apt.int_positive; if that parser
# only accepts plain integers, the non-integer forms listed below are
# rejected by argparse — confirm against tools.parsing.
SAMPLE_RESTRICTION_HINT = (
    "\nThe value can be: value="
    "{multiply(:value)+} [MULTIPLIES THE GIVEN VALUES]|"
    "{job_stats:FILE:{train|valid|test}:(KEY=VALUE)+}|"
    "{restriction_file(:key)+} [LOADS THE sample_restrictions.json DICT OF "
    "THE MAIN PROBLEM DOMAIN AND RETURNS THE VALUE BEHIND THE GIVEN KEYS]|"
    "int [BASE VALUE]\n")

ptrain.add_argument("--global-minimum-samples-per-set", default=100,
                    type=apt.int_positive,
                    help="If any set has fewer than this number of samples, "
                         "then the training will be aborted, because of too "
                         "few samples")
ptrain.add_argument("--minimum-samples-training", default=None,
                    type=apt.int_positive,
                    help="Minimum amount of samples to load for training, "
                         "otherwise the network will not be trained. This "
                         "should be safe on pruning, but is not sufficiently "
                         "tested; this check happens before a final pruning "
                         "run (which should not change anything anymore, but "
                         "I did not debug sufficiently to be sure that I can "
                         "delete it) (this assumes that your data set "
                         "contains exactly the following data: for each "
                         "problem solved the whole solution trajectory, "
                         "nothing more). %s" %
                         SAMPLE_RESTRICTION_HINT)
ptrain.add_argument("-n", "--name", type=str, default=None,
                    help="Sets the network store path to "
                         "'{network.path_out}/{--prefix}{--name}."
                         "{file suffix}'. See additionally '--output'")
ptrain.add_argument("-o", "--output", action="store_true",
                    help="Overwrites the network.path_out directory specified "
                         "in the network definition with the first root "
                         "directory of the training data.")
ptrain.add_argument("--only-evaluate", action="store_true",
                    help="Does NOT train a network, therefore, a model to load "
                         "has to exist. Evaluates the network on the test "
                         "data")
ptrain.add_argument("--plot-data-distribution", action="store_true",
                    help="Plot the distribution of the loaded data.")
ptrain.add_argument("-p", "--prefix", type=str, default="",
                    help="Prefix to add in front of analysis outputs and "
                         "stored model file name.")
ptrain.add_argument("--pruning", type=PruningTypes,
                    choices=_init_get_pruning_types_list(),
                    default=PruningTypes.IntraInter,
                    help="Pruning to apply to the training/validation/test "
                         "data. Options are:\n"
                         "\toff: no pruning at all. The sets are loaded as is\n"
                         "\tinter: entries are pruned between the data sets ("
                         "priority: test set, validation set, training set)\n"
                         "\tintra: entries are pruned within a data set\n"
                         "\tintra_inter: entries first pruned within data set, "
                         "then inter data sets.")
ptrain.add_argument("-pf", "--problem-filter", type=re.compile,
                    action="append", default=[],
                    help="A problem file name has to match the regex otherwise "
                         "it is not registered. By default no regex matches are "
                         "required. This argument can be given any number of "
                         "times to add additional filters (the file name has "
                         "to match ALL regexes)")
ptrain_mutex_multiple_iterations.add_argument(
    "--repetitions", type=apt.int_positive, default=1,
    help="Works only together with '--execute' (todo make this "
         "work without '--execute'). Adds '--array=0-(N-1)' as "
         "first argument into '--args'. Provides "
         "after the arguments of '--args' N arguments "
         "representing regular expressions not matching any "
         "problem (as they are all '-') (This is done to reuse "
         "the same mechanism as --cross-validation)")
ptrain.add_argument("--sample-type", default=tc.SampleTypes.all,
                    choices=tc.SampleTypes.name2type.values(),
                    type=tc.SampleTypes.get,
                    help="Loads only samples of this given type. Available "
                         "types are: \n\t %s" % _init_join_samples_types())
ptrain.add_argument("--samples-per-problem", default=None,
                    type=apt.int_positive,
                    help="How many samples to load at most per sampled "
                         "problem. If not specified all samples belonging to a "
                         "problem are loaded. This requires that all samples "
                         "have the meta fields %s and %s set. ASSUMPTION: "
                         "SAMPLES FROM THE SAME PAIR ARE IN CONSECUTIVE ORDER "
                         "IN THE DATA FILES!" %
                         (MetaFields.PROBLEM_HASH, MetaFields.MODIFICATION_HASH)
                    )
ptrain.add_argument("--samples-total-testing", default=None,
                    type=apt.int_positive,
                    help="Limit the total number of samples to load for the "
                         "test data. This does not limit the data for the "
                         "training or verification data. %s" %
                         SAMPLE_RESTRICTION_HINT)
ptrain.add_argument("--samples-total-training", default=None,
                    type=apt.int_positive,
                    help="Limit the total number of samples to load for the "
                         "training data. This does not limit the data for the "
                         "test or verification data. Due to pruning "
                         "duplicates from verification & test data in training "
                         "data, the final number of samples in the training "
                         "data set can be smaller! Todo: fix that %s" %
                         SAMPLE_RESTRICTION_HINT)
ptrain.add_argument("--samples-total-verifying", default=None,
                    type=apt.int_positive,
                    help="Limit the total number of samples to load for the "
                         "verification data. This does not limit the data for "
                         "the train or test data. Due to pruning "
                         "duplicates of the test data from the verification "
                         "data, the final number of samples in the "
                         "verification data set can be smaller! Todo: fix "
                         "that %s" %
                         SAMPLE_RESTRICTION_HINT)

ptrain.add_argument("--seed", default=None, type=int,
                    help="Use a specific random seed")
ptrain.add_argument("--skip", action="store_true",
                    help=("If set, then missing sample files are skipped, "
                          "otherwise every problem file is expected to have a "
                          "sample file."))
ptrain.add_argument("--skip-if-trained", action="store_true",
                    help="Skip training (works only without '--execute') if "
                         "the requested network model files exist already.")
ptrain.add_argument("--skip-if-flag", action="store_true",
                    help="Skip training (works only without '--execute') if a "
                         "skip flag exists (after training the skip flag file "
                         "is created).")
ptrain.add_argument("--skip-if-running", action="store_true",
                    help="Skip training (works only without '--execute') if a "
                         "valid running flag exists (set prior to training and "
                         "deleted afterwards unless the training crashed).")
ptrain.add_argument("--skip-magic", action="store_true",
                    help=("Tries to load the sample without performing a check "
                          "that it uses the right reader for the sample file "
                          "format (use case old sample files without magic "
                          "word. USE ONLY IF YOU KNOW WHAT YOU ARE DOING)"))
ptrain_mutex_execute.add_argument(
    "--slurm", action="store_true",
    help="Executes the training via submitting it to a slurm environment "
         "(this sets '--execute' and enables all options depending on "
         "'--execute').")
ptrain.add_argument("--slurm-dependency", default=None,
                    type=apt.slurm_dependency,
                    help="Only valid in combination with --slurm. Adds the "
                         "given value to the slurm command as dependency. The "
                         "following special sequences are defined which will "
                         "be expanded:\n"
                         "{key:value(;key:value)*}: key in (u, user) selects "
                         "jobs of user, key in (p, partition) selects jobs of "
                         "partition. If multiple given, then jobs matching all "
                         "conditions are selected.\n"
                         "\nExample values: afterany:9457635:{u:myuser}")
ptrain.add_argument("--slurm-summarize", action="store_true",
                    help="Only valid in combination with --slurm. Adds an "
                         "output summarization step after the training jobs.")
ptrain.add_argument("--stop-after-initialization", action="store_true",
                    help="Loads data and initializes the network. Afterwards "
                         "stops.")
ptrain.add_argument("-sdt", "--sub-directory-training", action="store_true",
                    help="Changes training from one network on the data within "
                         "all given directories (in the directory group) "
                         "to training a single network "
                         "per directory (and subdirectory) which contains a "
                         "domain.pddl file and at least one *.data file (for "
                         "those directories selected, the data is loaded from "
                         "them and from subdirectories like before)")
ptrain.add_argument("-t", "--test", type=re.compile, default=None,
                    help="Regex for identifying data set files to use as test "
                         "data.")
ptrain.add_argument("-ts", "--test-split", default=0.0,
                    type=apt.restricted_type(
                        apt.float_interval(0., 1.),
                        restrict_total_split_data),
                    help="Fraction of the training data to split off for the "
                         "test data (this is additional to '--test')")
ptrain.add_argument("-v", "--validation", type=re.compile,
                    default=None,
                    help="Regex for identifying data set files to use as "
                         "validation data (test data sets are excluded).")
ptrain.add_argument("-vs", "--validation-split", default=0.0,
                    type=apt.restricted_type(
                        apt.float_interval(0., 1.),
                        restrict_total_split_data),
                    help="Fraction of the training data to split off for the "
                         "validation data (this is additional to "
                         "'--validation').")


def get_directory_groups(directories, directory_filters,
                         sub_directory_training):
    """
    Returns the directory groups for training. On each group a training run
    will be done.

    :param directories: base directory grouping; a list of directory lists
        (one inner list per '-d' occurrence).
    :param directory_filters: compiled regexes; a directory is kept only if
        every regex matches its path (``re.match`` semantics, i.e. anchored
        at the start of the path string).
    :param sub_directory_training: if False, every given group whose first
        directory contains a 'domain.pddl' file is kept, with the directories
        failing the filters removed. If True, all given directories are
        traversed recursively and every visited directory that contains a
        'domain.pddl' file and passes the filters becomes its own
        single-directory group. Each physical directory (after resolving
        symlinks) is visited at most once, so symlink cycles cannot cause an
        endless traversal and overlapping inputs do not yield duplicates.
    :return: the non-empty groups, ordered by their first directory via
        tools.misc.sort_nicely
    """

    def match_all(path_dir):
        # generator: stops at the first non-matching filter
        return all(directory_filter.match(path_dir)
                   for directory_filter in directory_filters)

    if not sub_directory_training:
        directory_groups = [
            [g for g in group if match_all(g)]
            for group in directories
            if os.path.isfile(os.path.join(group[0], "domain.pddl"))]
    else:
        directory_groups = []
        visited = set()
        todo = [g for group in directories for g in group]
        while todo:
            next_dir = todo.pop()
            # os.path.isdir follows symlinks, so track resolved paths to
            # avoid revisiting (and training twice on) the same directory.
            real_dir = os.path.realpath(next_dir)
            if real_dir in visited:
                continue
            visited.add(real_dir)
            if (os.path.isfile(os.path.join(next_dir, "domain.pddl")) and
                    match_all(next_dir)):
                directory_groups.append([next_dir])
            todo.extend([os.path.join(next_dir, sub)
                         for sub in os.listdir(next_dir)
                         if os.path.isdir(os.path.join(next_dir, sub))])
    return tools.misc.sort_nicely(
        [dg for dg in directory_groups if len(dg) > 0],
        sort_key=lambda d: d[0])


def parse_training_args(argv):
    """
    Parse the arguments of a single training run and post-process them.

    Post-processing: '--slurm' sets '--execute' to "sbatch";
    '--cross-validation N' and '--repetitions N' prepend '--array=0-(N-1)'
    to '--args'; the '-init'/'-fin' key=value pairs become dicts.

    :param argv: list of command line tokens for one training run
    :return: tuple (options namespace, list of directory groups)
    :raises AssertionError: for inconsistent option combinations or if no
        valid directory group is found
    """
    options = ptrain.parse_args(argv)

    def with_array_prefix(count):
        # sbatch job array covering the indices 0 .. count - 1
        prefix = "--array=0-%i" % (count - 1)
        return prefix if options.args is None else prefix + " " + options.args

    if options.slurm:
        options.execute = "sbatch"
    if options.slurm_dependency:
        assert options.slurm, \
            "Option --slurm-dependency requires option --slurm"
    if options.slurm_summarize:
        assert options.slurm, "Option --slurm-summarize requires option --slurm"

    if options.cross_validation is not None:
        assert options.execute is not None, \
            "Requires '--execute' to use '--cross-validation'"
        no_held_out_data = (options.test is None and
                            options.test_split == 0.0 and
                            options.validation is None and
                            options.validation_split == 0.0)
        assert no_held_out_data, (
            "Cannot provide test/validation splits when doing cross validation")
        options.args = with_array_prefix(options.cross_validation)

    if options.repetitions != 1:
        assert options.execute is not None, \
            "Requires '--execute' to use --repetitions != 1"
        assert options.test is None and options.validation is None, (
            "Cannot provide test/validation regexes when doing repetitions")
        options.args = with_array_prefix(options.repetitions)

    directory_groups = get_directory_groups(
        options.directory, options.directory_filter,
        options.sub_directory_training)
    assert len(directory_groups) > 0, "No valid list of directories found."

    options.initialize = {key: value for key, value in options.initialize}
    options.finalize = {key: value for key, value in options.finalize}
    return options, directory_groups


def split_training_blocks(args):
    """
    Split a flat argument list into per-training argument lists.

    ARG_NEW_TRAINING tokens act as separators and are dropped; empty blocks
    (e.g. from leading, trailing or repeated separators) are discarded.

    :param args: sequence of command line tokens
    :return: list of non-empty token lists, in their original order
    """
    blocks = []
    current = []
    for token in args:
        if token != ARG_NEW_TRAINING:
            current.append(token)
            continue
        if current:
            blocks.append(current)
        current = []
    if current:
        blocks.append(current)
    return blocks


@tm.static_var("cache", None)
def get_parser_argument_keys():
    """
    Return the set of all option strings (e.g. '-d', '--directory') declared
    on the training parser ``ptrain``.

    Computed on the first call and memoised in the function attribute
    ``cache`` (presumably initialised to None by ``tm.static_var``).
    """
    cached = get_parser_argument_keys.cache
    if cached is not None:
        return cached
    option_strings = {option
                      for action in ptrain._actions
                      for option in action.option_strings}
    get_parser_argument_keys.cache = option_strings
    return option_strings


""" -------------------------------- Misc -----------------------------------"""



""" ------------------------------- Common Training ------------------------"""

def create_directory_samplers(
        directories, directory_filter, general_task_filters,
        data_set_task_filters, min_depth, max_depth):
    """
    Create one DirectorySampler per data set (e.g. test, validation, training).

    :param directories: directories to collect the task files from
    :param directory_filter: directory filters passed to every sampler
    :param general_task_filters: list of task filters applied to every data set
    :param data_set_task_filters: one entry per data set: either None (the data
        set is skipped and its sampler is None) or a list of additional task
        filters for that data set
    :param min_depth: minimum traversal depth passed to the samplers
    :param max_depth: maximum traversal depth passed to the samplers
    :return: tuple (all_tasks, all_samplers): all_tasks accumulates the
        ``iterable`` of every created sampler in data set order, all_samplers
        holds one entry (sampler or None) per data set
    """
    all_tasks = []
    all_samplers = []
    for task_filters in data_set_task_filters:
        sampler = None
        if task_filters is not None:
            # '+' builds a new list, the caller's filter lists stay untouched
            sampler = DirectorySampler(
                None, directories, directory_filter,
                general_task_filters + task_filters,
                None, all_tasks, max_depth, min_depth,
                merge=True)
            # Each sampler receives the (shared, growing) list of tasks taken
            # by earlier data sets; presumably to exclude them — TODO confirm
            # in DirectorySampler.
            all_tasks.extend(sampler.iterable)
        all_samplers.append(sampler)

    return all_tasks, all_samplers



def load_domain_properties(directories, all_tasks, full_domain_properties):
    """
    Analyse the domain file 'domain.pddl' of the first directory.

    A JSON file next to the domain ('domain_properties.json', or
    'domain_properties_no_statics.json' when full_domain_properties is
    False) acts as cache: if it exists it is passed as ``load``, otherwise
    its path is passed as ``store`` to DomainProperties.get_property_for.

    :param directories: directory list; only directories[0] is used
    :param all_tasks: problem files passed to the analysis
    :param full_domain_properties: if False, static groundings are not
        analysed (no_gnd_static=True)
    :return: the object returned by DomainProperties.get_property_for
    """
    print("Start analysing Domain:")
    analysis_start = time.time()
    path_domain = os.path.join(directories[0], "domain.pddl")

    if full_domain_properties:
        cache_name = "domain_properties.json"
    else:
        cache_name = "domain_properties_no_statics.json"
    cache_file = os.path.join(os.path.dirname(path_domain), cache_name)

    if os.path.exists(cache_file):
        path_load, path_store = cache_file, None
    else:
        path_load, path_store = None, cache_file

    properties = DomainProperties.get_property_for(
        path_domain=path_domain,
        paths_problems=all_tasks,
        no_gnd_static=not full_domain_properties,
        load=path_load,
        store=path_store,
        verbose=1)
    tm.timing(analysis_start, "Domain analysing time: %ss")
    return properties


class DataSetContainer(object):
    """
    Minimal container keeping only the distinct entries added to it
    (duplicates collapse because they are stored in a set).

    NOTE(review): instantiated in load_data; presumably used where only the
    unique states of a data set matter — confirm at its call sites.
    """

    def __init__(self):
        # distinct entries added so far
        self.states = set()

    def add(self, data, type=None, fields=None):
        """
        Add the single entry contained in ``data``.

        :param data: sequence holding exactly one hashable entry
        :param type: ignored
        :param fields: ignored
        :return: always 1, even if the entry was already present
        """
        assert len(data) == 1
        entry = data[0]
        self.states.add(entry)
        return 1

    def empty(self):
        """Return True if the container holds no entries."""
        return not self.size()

    def size(self):
        """Return the number of distinct entries stored."""
        return len(self.states)

    def finalize(self):
        """No-op; a plain set needs no finalization."""

    def clear(self):
        """Drop all entries by rebinding ``states`` to a new empty set."""
        self.states = set()


def load_data(options, directories, state_format, reference_states, dry=False):
    """Load the distinct samples of the problems found in ``directories``.

    Problems are collected with ``create_directory_samplers``. Only the
    training slot is active (test and validation are skipped). All loaded
    samples are funnelled into one shared ``DataSetContainer``, so the
    result holds the union of the states of all active samplers.

    :param options: parsed command-line options (reads e.g. ``input``,
        ``directory_filter``, ``problem_filter``, ``min_depth``,
        ``max_depth``, ``domain_properties``,
        ``domain_properties_no_statics``, ``fields``, ``skip``,
        ``skip_magic``, ``forget``, ``maximum_data_memory``,
        ``sample_type``, ``samples_per_problem``)
    :param directories: data directories; the first one must contain
        ``domain.pddl`` when domain properties are requested
    :param state_format: state format forwarded to the LoadSampleBridge
    :param reference_states: currently ignored -- the bridge is always
        constructed with ``reference_states=None``
    :param dry: if True, only report which problems belong to which data set
        (test, validation, training) without loading any sample
    :return: with ``dry`` a list holding one task list per data set (an
        empty list for skipped data sets); otherwise the DataSetContainer
        with the loaded states
    """
    streams = options.input

    # One entry per data set in the order (test, validation, training).
    # None skips the data set; a list adds data-set specific task filters.
    data_set_task_filters = [None, None, []]
    all_tasks, all_samplers = create_directory_samplers(
        directories, options.directory_filter, options.problem_filter,
        data_set_task_filters, options.min_depth, options.max_depth
    )

    if dry:
        return [[] if sampler is None else sampler.iterable
                for sampler in all_samplers]

    # Load Domain Properties (and add to network)
    assert not (options.domain_properties and
                options.domain_properties_no_statics)
    domain_properties = None
    if options.domain_properties or options.domain_properties_no_statics:
        domain_properties = load_domain_properties(
            directories, all_tasks, options.domain_properties)

    # Actually load data
    bridge_cur_mem = 0
    nb_samples_seen = 0
    nb_references_seen = 0
    data_container = DataSetContainer()
    for sampler in all_samplers:
        if sampler is None:
            continue
        bridge = LoadSampleBridge(
            streams=StreamContext(streams=streams),
            fields=options.fields,
            format=state_format, prune=None,
            fprune=None,
            skip=options.skip, skip_magic=options.skip_magic,
            forget=options.forget,
            domain_properties=domain_properties,
            max_mem=options.maximum_data_memory,
            sample_types=options.sample_type.subtypes,
            samples_per_problem=options.samples_per_problem,
            max_container_samples=None,
            reference_states=None,  # reference_states deliberately disabled
            provide=True,
        )
        # Carry the memory usage of previous samplers over to this bridge,
        # presumably so ``max_mem`` bounds the total -- TODO confirm.
        bridge._cur_mem = bridge_cur_mem
        sampler.sbridges = [bridge]

        sampler.initialize()
        sampler.sample(do_merge=True, data_container=data_container)
        sampler.finalize()
        bridge_cur_mem = bridge.current_memory_usage
        nb_references_seen += bridge.loaded_tasks
        nb_samples_seen += bridge.loaded_samples

    print("seen references/loaded _samples: %i/%i" % (nb_references_seen, nb_samples_seen))
    return data_container



"""-------------------- Execute Branch Only ---------------------------------"""


def get_fold_regexes(nb_folds, tasks):
    """Split ``tasks`` into ``nb_folds`` contiguous folds and build a regex per fold.

    The tasks are first ordered with ``tools.misc.sort_nicely``. Every fold
    receives ``len(tasks) // nb_folds`` tasks, and the last fold additionally
    takes the remainder. If ``nb_folds`` exceeds the number of tasks, every
    fold but the last is empty.

    :param nb_folds: number of folds (must be positive)
    :param tasks: task paths/names to distribute
    :return: list with one ``tm.get_common_prefix_suffix_regex`` result per fold
    """
    tasks = tools.misc.sort_nicely(tasks)
    nb_tasks = len(tasks)
    # Exact integer division (the former ``int(a / b)`` went through a float).
    fold_size = nb_tasks // nb_folds

    fold_tasks = []
    for i in range(nb_folds):
        start = i * fold_size
        end = nb_tasks if i == nb_folds - 1 else start + fold_size
        fold_tasks.append(tasks[start:end])

    return [tm.get_common_prefix_suffix_regex(*fold) for fold in fold_tasks]


def get_execute_command(options, argv):
    """Build the command line that runs the training via ``options.execute``.

    The executable is prepended to ``argv``. If ``options.slurm_dependency``
    is set, the dependency flags are inserted directly after the executable
    and ``--slurm-dependency`` is stripped. Options that only concern this
    driver are removed. Pre-arguments from ``options.args`` are inserted
    after everything that precedes the forwarded arguments.

    :param options: parsed options (``execute``, ``slurm_dependency``, ``args``)
    :param argv: original argument list (not modified)
    :return: tuple ``(command, idx_start_arguments)``; the index points to
        the first forwarded training argument in ``command``
    """
    command = [options.execute] + list(argv)
    first_argument = 1

    if options.slurm_dependency:
        command[1:1] = ["--dependency", options.slurm_dependency,
                        "--kill-on-invalid-dep=yes"]
        command, _ = apt.extract_and_remove_arguments(
            command, ["--slurm-dependency"], get_parser_argument_keys())
        first_argument += 3

    # Options handled by this driver that must not be forwarded.
    driver_only_options = [
        ["-e", "--execute"],
        ["--slurm"],
        ["--slurm-summarize"],
        ["--cross-validation"],
        ["--repetitions"],
        ["-a", "--args"],
        ["-sdt", "--sub-directory-training"],
        ["-d", "--directory"],
    ]
    for flags in driver_only_options:
        command, _ = apt.extract_and_remove_arguments(
            command, flags, get_parser_argument_keys())

    if options.args is not None:
        pre_arguments = shlex.split(options.args)
        command[first_argument:first_argument] = pre_arguments
        first_argument += len(pre_arguments)

    return command, first_argument


""" -------------------Local Training Branch Only ---------------------------"""




""" -------------- Execute the training in another process ------------------"""


def train_execute(options, argv, directory_groups):
    """Launch the training of every directory group as an external command.

    For each group the data selection is validated with a dry
    ``load_data`` call. Then ``--directory <dirs>`` is spliced into the base
    command from ``get_execute_command``, and the command is run with
    ``subprocess.check_output`` unless ``options.dry`` is set.

    :param options: parsed options (``dry``, ``format``,
        ``slurm_summarize`` and whatever the helpers read)
    :param argv: original argument list
    :param directory_groups: iterable of directory lists
    :return: the slurm job ids parsed from the submission output when
        ``options.slurm_summarize`` is set (empty list otherwise); these
        were collected but previously discarded
    """
    slurm_job_ids = []
    base_command, base_idx_start_arguments = get_execute_command(
        options, list(argv))

    for dg in directory_groups:
        print("DIRECTORY GROUP", dg)
        next_command = list(base_command)

        # Check that call would be valid (dry run; result not needed).
        load_data(options, dg, options.format, None, dry=True)

        # Insert right after the first argument following the forwarded
        # pre-arguments (presumably the positional network definition --
        # TODO confirm).
        insert_at = base_idx_start_arguments + 1
        next_command[insert_at:insert_at] = ["--directory"] + list(dg)

        print("Call executable: ", next_command)
        if options.dry:
            continue
        sub_out = decoder(subprocess.check_output(next_command))
        if options.slurm_summarize:
            job_id = REGEX_SUBMITTED_BATCH_JOB.match(sub_out)
            assert job_id is not None, \
                "Error submitting slurm job: %s" % str(next_command)
            slurm_job_ids.append(job_id.group(1))

    return slurm_job_ids


# Problem files: names ending in "p<digits>.pddl" (e.g. "p01.pddl").
PATTERN_PROBLEM = re.compile(r"^.*p\d+\.pddl$")
# A parenthesised fact without nested parentheses, e.g. "(at truck1 loc2)";
# group 1 holds the whole fact.
PATTERN_FACT = re.compile(r"(\([^\(\)\s)]+(\s+[^\(\)\s)]+)*\s*\))")
# Substrings of converted atoms ("Atom name(args)") excluded from states.
IGNORE_ATOMS = {
    "total-cost()",
    "Atom road-length(",
}
def convert_atoms(atom):
    """Turn a PDDL fact like ``(at t1 l2)`` into the string ``Atom at(t1, l2)``.

    ``atom`` must start with ``(`` and end with ``)``; whitespace between
    tokens may be arbitrary.
    """
    assert atom[0] == "("
    assert atom[-1] == ")"
    tokens = atom[1:-1].split()
    predicate, arguments = tokens[0], tokens[1:]
    return "Atom %s(%s)" % (predicate, ", ".join(arguments))


def extract_ns_state(pddl, set_atoms_all, atoms_flexible):
    """Encode the ``:init`` section of a PDDL problem as a 0/1 tuple.

    The text between ``(:init`` and the next ``(:`` (or the end of the text
    if no further section follows) is scanned for flat facts. Each fact is
    converted with ``convert_atoms``, and facts containing an entry of
    ``IGNORE_ATOMS`` are dropped.

    :param pddl: PDDL problem text
    :param set_atoms_all: set of every known atom string; every remaining
        init fact must be in it (checked with an assert)
    :param atoms_flexible: ordered atom strings defining the tuple layout
    :return: tuple with 1 where the atom is in the init state, else 0
    """
    idx_start = pddl.find("(:init")
    assert idx_start != -1
    idx_start += len("(:init")
    idx_end = pddl.find("(:", idx_start)
    if idx_end == -1:
        # No section after :init -- take the rest of the text instead of
        # silently dropping its last character via a -1 slice bound.
        idx_end = len(pddl)
    init_section = pddl[idx_start:idx_end]

    init = set()
    for groups in PATTERN_FACT.findall(init_section):
        atom = convert_atoms(groups[0])
        if not any(ignore in atom for ignore in IGNORE_ATOMS):
            init.add(atom)

    assert all(x in set_atoms_all for x in init), ", ".join([x for x in init if x not in set_atoms_all])
    return tuple(int(atom in init) for atom in atoms_flexible)

def run_generator(path):
    """Run the executable at ``path`` and return its standard output as text.

    ``universal_newlines=True`` makes ``check_output`` return ``str`` on both
    Python 2 and 3. Without it, Python 3 returns ``bytes``, and callers that
    run ``str.find`` on the result or write it to a text-mode file fail.

    :raises subprocess.CalledProcessError: if the process exits non-zero
    """
    pddl = subprocess.check_output([path], universal_newlines=True)
    return pddl


def _backup_path_for(path):
    """Return ``path + ".bak"``, or ``path + ".bak<N>"`` if that already exists."""
    candidate = path + ".bak"
    counter = 1
    while os.path.exists(candidate):
        candidate = "%s.bak%i" % (path, counter)
        counter += 1
    return candidate


def train_local(options, directory_groups, process, start_time):
    """Replace problem files whose initial state already occurs in the data.

    For every directory group, the problems matching ``PATTERN_PROBLEM`` in
    the first directory are compared against the states returned by
    ``load_data``. Each problem whose state (see ``extract_ns_state``) is
    among them is backed up and replaced by output of ``run_generator.sh``
    whose state is not among them. The replaced paths are written to
    ``replaced_pddls_files.json`` in that directory.

    :param options: parsed options (``dry``, ``format`` and whatever
        ``load_data`` reads)
    :param directory_groups: iterable of directory lists; only the first
        directory of each group is scanned for atoms/problems/generator
    :param process: unused; kept for interface compatibility
    :param start_time: unused; kept for interface compatibility
    """
    for idx_dg, dg in enumerate(directory_groups):
        print("Processing Directory Group %i: %s" % (idx_dg, dg))
        directory = dg[0]
        file_atoms = os.path.join(directory, "atoms.json")
        file_generator = os.path.join(directory, "run_generator.sh")
        assert os.path.isfile(file_atoms)
        with open(file_atoms, "r") as f:
            atoms = json.load(f)
        set_atoms_all = set(atoms["PDDL_ATOMS"])
        atoms_flexible = atoms["PDDL_ATOMS_FLEXIBLE"]

        old_tasks = [os.path.join(directory, item)
                     for item in os.listdir(directory)
                     if PATTERN_PROBLEM.match(item)]

        if options.dry:
            continue

        # NOTE: reference states (ns_init_states.json) are currently unused.
        try:
            data = load_data(options, dg, options.format, None)
        except SampleRequirementException as e:
            print("Exit:", e)
            continue

        # Decide for every problem before any file is touched.
        used_tasks = []
        for path in old_tasks:
            with open(path, "r") as f:
                content = f.read()
            state = extract_ns_state(content, set_atoms_all, atoms_flexible)
            if state in data.states:
                used_tasks.append(path)

        replaced = []
        for path in used_tasks:
            # Generate first, so a failing generator leaves the original in
            # place. NOTE: loops until the generator yields an unseen state.
            while True:
                pddl = run_generator(file_generator)
                ns = extract_ns_state(pddl, set_atoms_all, atoms_flexible)
                if ns not in data.states:
                    break
            # Never overwrite a backup left by an earlier run.
            shutil.move(path, _backup_path_for(path))
            with open(path, "w") as f:
                f.write(pddl)
            replaced.append(path)

        with open(os.path.join(directory, "replaced_pddls_files.json"), "w") as f:
            json.dump(replaced, f)
        print("Replaced", len(replaced))

def train(argv):
    """Parse one training invocation and dispatch it.

    Without ``--execute`` the work is done in this process
    (``train_local``); otherwise it is delegated to an external command
    (``train_execute``). The dispatched function's result is returned.
    """
    print("Startup time: %s" % str(datetime.datetime.now()))
    print("Call: %s" % " ".join(argv))

    parse_start = time.time()
    options, directory_groups = parse_training_args(argv)
    start_time = tm.timing(parse_start, "Parsing time: %ss")

    process = psutil.Process(os.getpid())
    random.seed(options.seed)

    if options.execute is not None:
        return train_execute(options, argv, directory_groups)
    return train_local(options, directory_groups, process, start_time)


def run(args):
    """Split ``args`` into training blocks and run ``train`` on each, in order."""
    for training_run in split_training_blocks(args):
        train(training_run)


if __name__ == "__main__":
    # Forward every command-line argument except the program name.
    run(args=sys.argv[1:])
