#!/usr/bin/env python
from __future__ import print_function

import sbatch_tools as stools

import argparse
import os
import sys


# File name of the training launcher that must exist in every training
# directory; it is executed as "./run_training.sh" from inside that directory.
SCRIPT_TRAINING = "run_training.sh"


def type_isdir(arg):
    """argparse ``type`` callback: accept an existing directory.

    :param arg: path given on the command line.
    :returns: the absolute form of ``arg``.
    :raises ValueError: if ``arg`` is not an existing directory.
    """
    if os.path.isdir(arg):
        return os.path.abspath(arg)
    raise ValueError()


def type_is_training_dir(arg):
    """argparse ``type`` callback: accept a directory holding the launcher.

    :param arg: path given on the command line.
    :returns: the absolute path of the directory.
    :raises ValueError: if ``arg`` is not a directory, or if it does not
        contain a regular file named ``SCRIPT_TRAINING``.
    """
    directory = type_isdir(arg)
    launcher = os.path.join(directory, SCRIPT_TRAINING)
    if os.path.isfile(launcher):
        return directory
    raise ValueError()


def type_positive_or_zero_int(arg):
    """argparse ``type`` callback: accept a non-negative whole number.

    :param arg: value to convert (normally the command-line string).
    :returns: ``int(arg)``.
    :raises ValueError: if ``arg`` cannot be converted by ``int``, is
        negative, or does not equal its float value (e.g. ``2.5``).
    """
    value = int(arg)
    # float(arg) is only evaluated for non-negative values, matching the
    # short-circuit order of the original check.
    if value >= 0 and value == float(arg):
        return value
    raise ValueError()


def type_positive_int(arg):
    """argparse ``type`` callback: accept a strictly positive whole number.

    :param arg: value to convert (normally the command-line string).
    :returns: the converted integer, which is at least 1.
    :raises ValueError: if ``arg`` is not a non-negative whole number, or
        if it is zero.
    """
    value = type_positive_or_zero_int(arg)
    if value != 0:
        return value
    raise ValueError()


# Command-line interface.
# Every positional directory is checked for the SCRIPT_TRAINING launcher while
# parsing. An unusable directory is then rejected before any job of the chain
# has been submitted, instead of failing halfway through run().
parser = argparse.ArgumentParser()
parser.add_argument("training_directories", type=type_is_training_dir,
                    nargs="+",
                    help="List of directories containing the {SCRIPT_TRAINING} "
                         "scripts. The scripts will be chained in the order "
                         "the directories are given in."
                         .format(SCRIPT_TRAINING=SCRIPT_TRAINING))
parser.add_argument("--slurm-dependency", default=None,
                    help="Slurm dependency for the first job of each chain. "
                         "All other jobs will depend on the previous job in "
                         "their chain.")
parser.add_argument("--parallel-chains", type=type_positive_int, default=1,
                    help="Number of parallel chains. If N directories are "
                         "given and N parallel chains, then there is no "
                         "chaining done at all. All chain heads will wait for "
                         "'--slurm-dependency'.")
parser.add_argument("--dry", action="store_true",
                    help="starts the job scripts in dry mode")


def run(options):
    """Submit the training scripts of all given directories as chained jobs.

    Directories are distributed round-robin over ``options.parallel_chains``
    chains: directory number ``no`` goes to chain ``no % parallel_chains``.
    The first script of each chain receives ``options.slurm_dependency`` as
    its ``--slurm-dependency`` argument, if one was given. Every later script
    in a chain receives a dependency on the job id(s) returned for the
    previous script of the same chain.

    :param options: namespace with the attributes ``training_directories``,
        ``slurm_dependency``, ``parallel_chains`` and ``dry``, as produced by
        ``parser``.
    """
    chain_count = options.parallel_chains
    # Strings are immutable, so one shared value per chain head is safe.
    previous_job_id = [options.slurm_dependency] * chain_count
    # Each launcher is started from inside its own directory. Every directory
    # is entered from the caller's cwd, and that cwd is restored afterwards,
    # even if the submission raises. Relative paths therefore resolve
    # consistently, and run() leaves no lasting cwd side effect.
    original_cwd = os.getcwd()
    for no, training_directory in enumerate(options.training_directories):
        no_chain = no % chain_count

        command = ["./%s" % SCRIPT_TRAINING]
        if previous_job_id[no_chain] is not None:
            command.extend(["--slurm-dependency", previous_job_id[no_chain]])
        if options.dry:
            command.append("--dry")

        os.chdir(training_directory)
        try:
            job_ids = stools.call_subprocess(command, multijobs=True)
        finally:
            os.chdir(original_cwd)

        job_ids = [str(x) for x in job_ids]
        if job_ids:
            # NOTE(review): Slurm's "after" type releases a job once the
            # listed jobs have *started*; "afterok"/"afterany" would wait for
            # them to finish. Confirm "after" is the intended chaining type.
            previous_job_id[no_chain] = "after:" + ":".join(job_ids)
        # If no job ids were returned, the chain keeps its previous
        # dependency. Building "after:" with an empty id list would not be a
        # valid Slurm dependency.


if __name__ == "__main__":
    # Parse the command-line arguments (program name excluded) and submit
    # the job chain(s).
    cli_options = parser.parse_args(sys.argv[1:])
    run(cli_options)
