#!/usr/bin/env python

import argparse
import collections
import datetime
import json
import matplotlib as mpl
mpl.use('agg')
import matplotlib.pyplot as plt
import numpy as np
import os
import platform
import re
import sys


# Preconditions under which the times computed by this script are meaningful.
# One list entry per assumption; they are printed when the script is started
# from the command line.
ASSUMPTIONS = [
    "Within a directory containing problems, generating samples is equally "
    "difficult for all problems and equally difficult within problems.",
    "Training directory is given by --directory",
    "Prefix is given by --prefix",
    "--sample-type is used",
    "--samples-per-problem may be used, but only with the value 1",
    "the input streams have a section suffix=X, and all data sets we use are "
    "detectable as problem.X",
    "none of the arguments named above is a key in the job stats",
]

# Name of the job statistics files searched for below the given directories.
JOB_STATS = "job_stats.json"
# Name of the file (inside a training directory) describing its data sets.
DATA_SET_SIZES = "data_set_sizes.json"

# Keys of the per-data-set entries inside a DATA_SET_SIZES file.
NB_SAMPLES = "#samples"
NB_PROBLEMS = "#problems"

# JOB STATS KEYS
# Keys read from the individual entries of a JOB_STATS file.
CALL_STRING = "call_string"
# Stored as a string in the file; converted to a datetime when parsed.
TIMESTAMP = "timestamp"
# Entry holding the sizes of the train/validation/test sets used by the job.
TRAINING_DATA = "data_set_sizes"
TRAINING_DATA_TRAIN = "train"
TRAINING_DATA_VALID = "validation"
TRAINING_DATA_TEST = "test"

# Entry holding the durations of the training phases; a phase may be None.
TRAINING_TIMES = "phase_times"
TRAINING_TIME_LOAD = "data_loading"
TRAINING_TIME_INIT = "network_initialization"
TRAINING_TIME_TRAIN = "network_training"
TRAINING_TIME_PARSE = "option_parsing"

# Truthy if the job finished; other jobs are skipped unless explicitly added.
FINISHED = "finished"

# Entries containing this key are ignored entirely.
SKIP_EXECUTION = "skip_execution"

# CALL STRING ARGUMENTS
# Arguments looked up in the call string. The same strings are reused as keys
# under which the extracted values are stored in the job stats.
CS_PREFIX = "--prefix"
CS_DIRECTORY = "--directory"
CS_SAMPLE_TYPE = "--sample-type"
CS_SAMPLES_PER_PROBLEM = "--samples-per-problem"
# not actually a parameter, but extracted from a parameter
# (key for the list of suffixes found in the --input arguments)
CS_STREAM_SUFFIX = "SUFFIX"

# SAMPLE TYPES
# Values the --sample-type argument is compared against.
SAMPLE_TYPE_INIT = "init"
SAMPLE_TYPE_INTER = "inter"
SAMPLE_TYPE_PLAN = "plan"
SAMPLE_TYPE_ALL = "all"

# Name of the subplot which aggregates the times of all groups.
GROUP_NAME_ALL = "all"


# Captures X from a "suffix=X)" section of an input stream description.
# Raw string: "\)" must reach the regex engine as an escaped parenthesis and
# is not a valid escape sequence of a normal string literal.
PATTERN_EXTRACT_STREAM_SUFFIX = re.compile(r"suffix=([^\)]+)\)")


def type_directory(arg):
    """
    argparse 'type' callable which accepts only paths of existing directories.

    :param arg: path given on the command line
    :return: the unchanged path
    :raises argparse.ArgumentTypeError: if 'arg' is not an existing directory.
        argparse turns this into a regular usage error; unlike an assert the
        check is also performed when Python runs with -O.
    """
    if not os.path.isdir(arg):
        raise argparse.ArgumentTypeError("not a directory: %r" % (arg,))
    return arg

def type_try_classes(*types):
    """
    Build an argparse 'type' callable which tries several converters in order.

    :param types: callables taking the argument string; the result of the
                  first one which does not raise is used
    :return: function converting a single command line argument
    """
    def convert_type(arg):
        for t in types:
            try:
                return t(arg)
            except Exception:
                # This converter does not accept the value, try the next one.
                # (Exception instead of a bare except, so that
                # KeyboardInterrupt and SystemExit are not swallowed.)
                continue
        names = ", ".join(getattr(t, "__name__", repr(t)) for t in types)
        raise argparse.ArgumentTypeError(
            "%r cannot be converted by any of: %s" % (arg, names))
    return convert_type


# Command line interface of the script.
parser = argparse.ArgumentParser()
parser.add_argument("--prefix", type=re.compile, action="append", default=[],
                    help="Regex for a prefix to analyse. Multiple regexes can "
                         "be used to analyse multiple prefixes independently.")
parser.add_argument("--directory-job-stats", type=type_directory,
                    action="append", default=[],
                    help="Path to a directory in which to look recursively "
                         "for %s." % JOB_STATS)
# type=int: the depth is compared with integers while walking the directories,
# a value given on the command line must not stay a string.
parser.add_argument("--directory-job-stats-recursive-depth", type=int,
                    default=1,
                    help="Maximum recursion depth for looking for job stats "
                         "in one of the given directories")
parser.add_argument("--regex-job-stats", type=re.compile, action="append",
                    default=[],
                    help="A job-stats file path has to match all regexes to "
                         "be considered.")
# Every value is converted to an int if possible and to a regex otherwise.
parser.add_argument("--group-problem-directories",
                    type=type_try_classes(int, re.compile), nargs=2,
                    default=None,
                    help="Only networks trained for directories matching are "
                         "considered. Directories are grouped by sorting "
                         "together regexes where the regex matched group "
                         "are equal. Specify the group index as second "
                         "parameter to this argument. (e.g. '.* 0', 'a(.*) 1')")
parser.add_argument("--data-set-size-cache",
                    default="_data_set_size_cache.json",
                    help="Path where to cache data about the used data sets")
parser.add_argument("--add-unsuccessful-trainings", action="store_true")
parser.add_argument("--skip-test-size", action="store_true",
                    help="Does not count the size of the test set for "
                         "calculating the sampling time.")
parser.add_argument("--sampling-time-per-file", type=float, default=2,
                    help="Sampling time per data set file in hours.")

parser.add_argument("--outdir", default=".",
                    help='Directory to store the outputs of this script')


def adapt_paths_from_server(path):
    """
    Translate a path stored in the job stats to the machine running the script.

    :param path: path as found in a call string
    :return: on the host "dmi-aoede" the path with every "infai" replaced by
             "home"; on hosts whose name contains "cluster.bc2.ch" the
             unchanged path
    :raises AssertionError: on any other host. Raised explicitly (instead of
        'assert False') so that the check survives 'python -O' and the
        function never silently returns None.
    """
    node = platform.node()
    if node == "dmi-aoede":
        return path.replace("infai", "home")
    if "cluster.bc2.ch" in node:
        return path
    raise AssertionError(
        "Do not know how to adapt paths on host %r" % (node,))


# cache: {(path, suffix): (#problems, #samples, #files)}
def get_from_data_set_size(cache, path, *suffixes):
    """
    Sum up the sizes of all data sets of a directory matching the suffixes.

    The numbers are read from the DATA_SET_SIZES file of the directory. The
    result for every (path, suffix) pair is remembered in 'cache' (which is
    modified in place) and reused on later calls.

    :param cache: dictionary as described in the comment above
    :param path: training directory; adapted to the current host first
    :param suffixes: a data set is counted for a suffix if its name ends with
                     the suffix
    :return: tuple (#problems, #samples, #files) summed over all suffixes
    """
    path = adapt_paths_from_server(path)
    sum_problems = 0
    sum_samples = 0
    sum_files = 0
    for suffix in suffixes:
        key = (path, suffix)
        if key not in cache:
            # Unknown pair: count the matching data sets listed in the file.
            assert os.path.isdir(path)
            file_sizes = os.path.join(path, DATA_SET_SIZES)
            assert os.path.isfile(file_sizes)
            with open(file_sizes, "r") as handle:
                sizes = json.load(handle)
            problems = samples = files = 0
            for name, stats in sizes.items():
                if not name.endswith(suffix):
                    continue
                samples += stats[NB_SAMPLES]
                problems += stats[NB_PROBLEMS]
                files += 1
            cache[key] = (problems, samples, files)
        problems, samples, files = cache[key]
        sum_problems += problems
        sum_samples += samples
        sum_files += files
    return (sum_problems, sum_samples, sum_files)


def load_data_set_size_cache(path):
    """
    Load a data set size cache from a JSON file.

    :param path: path of the cache file
    :return: the content of the file, or an empty dictionary if the file does
             not exist (previously None was returned implicitly, which cannot
             be used as a cache)
    """
    if not os.path.exists(path):
        return {}
    with open(path, "r") as f:
        return json.load(f)


def save_data_set_size_cache(cache, path):
    """
    Write 'cache' as JSON to the file 'path', replacing previous content.

    NOTE(review): json.dump accepts only str, int, float, bool or None as
    dictionary keys. A cache keyed by (path, suffix) tuples, as filled by
    get_from_data_set_size, can therefore not be written as is.

    :param cache: object to serialize
    :param path: path of the file to write
    """
    with open(path, "w") as target:
        json.dump(cache, target)



def find_job_stats_files(dir_roots, recursion_depth, filter_regex):
    """
    Yield the paths of the job stats files found below the given directories.

    :param dir_roots: directories in which the search starts
    :param recursion_depth: how many directory levels below a root directory
                            are entered at most
    :param filter_regex: compiled regexes; a path is only yielded if all of
                         them match it (from the start of the path)
    :return: generator of paths to files named JOB_STATS
    """
    for root in dir_roots:
        # Explicit stack of (directory, remaining depth) instead of recursion.
        pending = [(root, recursion_depth)]
        while pending:
            folder, remaining = pending.pop()
            for entry in os.listdir(folder):
                candidate = os.path.join(folder, entry)
                if os.path.isdir(candidate) and remaining > 0:
                    pending.append((candidate, remaining - 1))
                    continue
                if entry != JOB_STATS:
                    continue
                if all(regex.match(candidate) for regex in filter_regex):
                    yield candidate


def _extract_argument(call_string, arg, default=None, find_all=False):
    """
    Returns the first parameter provided (seperation by whitespaces is used)
    after the first occurrence of the given argument.
    :param call_string: string in which to search for the argument
    :param arg: argument (a string) to look for
    :param default: default value to return if arg was not found in call_string
    :param find_all: the first parameter after ALL occurrences of 'arg'.
                     'default' has no effect anymore

    :return:
    """
    found = []
    idx_start = 0
    while True:
        idx_start = call_string.find(arg, idx_start)
        if idx_start == -1:
            break
        idx_start += len(arg)
        while call_string[idx_start] == " ":
            idx_start += 1
            assert idx_start < len(call_string)
        idx_end = call_string.find(" ", idx_start)

        found.append(call_string[idx_start: idx_end]
                     if idx_end != -1 else
                     call_string[idx_start:])

        if not find_all or idx_end == -1:
            break
        idx_start = idx_end

    if not find_all and len(found) == 0:
        return default
    elif find_all:
        return found
    else:
        return found[0]


def parse_and_convert_job_stats(job_stats):
    """
    Extract the relevant arguments from the call string of a job.

    'job_stats' is modified in place:
      - the timestamp string is replaced by a datetime object
      - the values of --directory, --sample-type and --samples-per-problem
        (None if not given, otherwise an int) are stored under the argument
        names
      - the suffixes found in all --input arguments are stored as a list
        under CS_STREAM_SUFFIX

    Example of the handled arguments:
    "--sample-type plan --samples-per-problem 1"

    :param job_stats: dictionary of a single job
    :return: tuple (directory, prefix) as given in the call string
    :raises AssertionError: if --prefix, --directory or --sample-type is
        missing or --samples-per-problem has a value different from 1
    :raises AttributeError: if an --input argument has no "suffix=X)" section
    """
    call_string = job_stats[CALL_STRING]
    job_stats[TIMESTAMP] = datetime.datetime.strptime(
        job_stats[TIMESTAMP], "%Y-%m-%d %H:%M:%S.%f")

    prefix = _extract_argument(call_string, CS_PREFIX)
    assert prefix is not None
    directory = _extract_argument(call_string, CS_DIRECTORY)
    assert directory is not None
    job_stats[CS_DIRECTORY] = directory

    sample_type = _extract_argument(call_string, CS_SAMPLE_TYPE)
    assert sample_type is not None
    job_stats[CS_SAMPLE_TYPE] = sample_type
    samples_per_problem = _extract_argument(call_string, CS_SAMPLES_PER_PROBLEM)
    samples_per_problem = (None if samples_per_problem is None
                           else int(samples_per_problem))
    # Compare by value: 'is 1' tests object identity, which only works by
    # accident of the interpreter caching small integers.
    assert samples_per_problem is None or samples_per_problem == 1
    job_stats[CS_SAMPLES_PER_PROBLEM] = samples_per_problem

    input_streams = _extract_argument(call_string, "--input", find_all=True)
    input_streams = [PATTERN_EXTRACT_STREAM_SUFFIX.search(input_stream).group(1)
                     for input_stream in input_streams]
    job_stats[CS_STREAM_SUFFIX] = input_streams

    return directory, prefix


def find_job_stats(dir_roots, recursion_depth, prefixes, filter_regex,
                   add_unsuccessful):
    """
    Collect for every prefix regex the newest job stats per training run.

    :param dir_roots: directories in which to search for job stats files
    :param recursion_depth: maximum directory depth of the search
    :param prefixes: compiled regexes; a job is assigned to every regex which
                     matches its prefix
    :param filter_regex: compiled regexes a job stats file path has to match
    :param add_unsuccessful: if False, jobs which did not finish are skipped
    :return: list with one dictionary per entry of 'prefixes':
             [{(directory, prefix): job_stats, ...}, ...]
             For equal keys the job stats with the latest timestamp is kept.
    """
    newest_per_regex = [{} for _ in prefixes]
    for path in find_job_stats_files(dir_roots, recursion_depth, filter_regex):
        with open(path, "r") as handle:
            entries = json.load(handle)

        for entry in entries.values():
            if SKIP_EXECUTION in entry:
                continue
            if not (add_unsuccessful or entry[FINISHED]):
                continue

            directory, prefix = parse_and_convert_job_stats(entry)
            key = (directory, prefix)

            for newest, regex in zip(newest_per_regex, prefixes):
                if not regex.match(prefix):
                    continue
                known = newest.get(key)
                if known is None or known[TIMESTAMP] < entry[TIMESTAMP]:
                    newest[key] = entry
    return newest_per_regex


def group_job_stats(all_job_stats, group_directory):
    """
    Regroup the job stats of every prefix by the name of a directory group.

    :param all_job_stats: [{(directory, prefix): job_stats, ...}, ...]
    :param group_directory: None (every job is put into the group
                            "unfiltered") or a sequence (regex, group index).
                            Jobs whose directory is not matched by the regex
                            are dropped; the others are grouped by the text
                            of the given regex group.
    :return: [{group_name: [job_stats, job_stats, ...], ...}, ...]
             with one (default)dict per entry of 'all_job_stats'
    """
    def name_of_group(directory):
        if group_directory is None:
            return "unfiltered"
        hit = group_directory[0].match(directory)
        if hit is None:
            return None
        return hit.group(group_directory[1])

    regrouped = []
    for per_prefix in all_job_stats:
        by_group = collections.defaultdict(list)
        for (directory, _prefix), stats in per_prefix.items():
            name = name_of_group(directory)
            if name is None:
                continue
            by_group[name].append(stats)
        regrouped.append(by_group)
    return regrouped



def get_sampling_times(data_set_size_cache, all_job_stats,
                       skip_test, time_per_file):
    """
    Estimate for every job the time needed to generate its training samples.

    The estimate is the total sampling time of all matching data set files
    (time_per_file hours each) scaled by the fraction of the available
    samples which the job used.

    :param data_set_size_cache: cache passed on to get_from_data_set_size
    :param all_job_stats: [{group_name: [job_stats, ...], ...}, ...]
    :param skip_test: if True, the size of the test set is not counted as
                      used samples
    :param time_per_file: sampling time per data set file in hours
    :return: [{group_name: [seconds, seconds, ...], ...}, ...] with the same
             structure and order as 'all_job_stats'
    """
    result = []
    for per_prefix in all_job_stats:
        per_group = {}
        for group, jobs in per_prefix.items():
            estimates = []
            for job in jobs:
                directory = job[CS_DIRECTORY]
                sample_type = job[CS_SAMPLE_TYPE]
                per_problem = job[CS_SAMPLES_PER_PROBLEM]

                # All three sizes are read, the test size is just not added
                # if it shall be skipped.
                size_train = job[TRAINING_DATA][TRAINING_DATA_TRAIN]
                size_valid = job[TRAINING_DATA][TRAINING_DATA_VALID]
                size_test = job[TRAINING_DATA][TRAINING_DATA_TEST]
                used = size_train + size_valid
                if not skip_test:
                    used += size_test

                nb_problems, nb_samples, nb_files = get_from_data_set_size(
                    data_set_size_cache, directory, *job[CS_STREAM_SUFFIX])

                # How many samples the job could have drawn from the files.
                if sample_type == SAMPLE_TYPE_INIT:
                    assert per_problem is None
                    available = nb_problems
                elif sample_type == SAMPLE_TYPE_INTER:
                    assert per_problem is None or per_problem == 1
                    available = nb_problems
                elif per_problem is None:
                    available = nb_samples
                elif per_problem == 1:
                    available = nb_problems
                else:
                    assert False

                # NOTE: it is not checked that used <= available.
                estimates.append(float(time_per_file) * 60 * 60 * nb_files *
                                 used / available)
            per_group[group] = estimates
        result.append(per_group)
    return result


def get_training_times(all_job_stats):
    """
    Sum up the durations of the training phases of every job.

    :param all_job_stats: [{group_name: [job_stats, ...], ...}, ...]
    :return: [{group_name: [time, time, ...], ...}, ...] with the same
             structure and order as 'all_job_stats'. A phase whose duration
             is None counts as 0.
    """
    phases = (TRAINING_TIME_PARSE, TRAINING_TIME_INIT,
              TRAINING_TIME_LOAD, TRAINING_TIME_TRAIN)

    def total_time(job):
        phase_times = job[TRAINING_TIMES]
        durations = [phase_times[phase] for phase in phases]
        return sum(0 if value is None else value for value in durations)

    return [
        {group: [total_time(job) for job in jobs]
         for group, jobs in per_prefix.items()}
        for per_prefix in all_job_stats
    ]


def boxplot_times(file_plot, title, prefixes, *times, **kwargs):
    """
    Draw one subplot with horizontal box plots per group and save the figure.

    The first subplot aggregates the times of all groups, the following ones
    show the single groups in sorted order. Every subplot contains one box
    per prefix. If several 'times' arguments are given, their values are
    added element-wise per prefix (e.g. sampling time + training time), which
    requires that they list the same jobs in the same order.

    :param file_plot: path where to store the plot
    :param title: title for subplots. Has to contain one %s where the group
                  name for the subplot will be inserted
    :param prefixes: labels of the boxes, one per prefix
    :param times: Argument list of format:
                  [{group_name: [time, time, ...], ...}, ...]
                  One {group_name: ....} dict per prefix
    :param kwargs: only 'scale' is accepted: factor with which all times are
                   multiplied before plotting (default: 1)
    :return: None
    """
    assert len(prefixes) == len(times[0])
    assert all(len(times[0]) == len(t) for t in times)
    scale = float(kwargs.pop("scale", 1))
    assert len(kwargs) == 0

    group_names = ([GROUP_NAME_ALL] +
                   sorted({g for t in times for p in t for g in p}))

    fig = plt.figure(figsize=(10, len(group_names) * len(prefixes) * 1.4))
    try:
        for no_group, group_name in enumerate(group_names):
            # The aggregated subplot uses the times of all real groups.
            selected_groups = (group_names[1:]
                               if group_name == GROUP_NAME_ALL
                               else [group_name])

            # [[1.2, 3., 0.7, ...], ...]
            times_per_prefix = [None for _ in prefixes]
            for t in times:
                for no_prefix, _ in enumerate(prefixes):
                    times_prefix = t[no_prefix]
                    new_times = []
                    for selected in selected_groups:
                        if selected in times_prefix:
                            new_times.extend(times_prefix[selected])
                    new_times = np.array(new_times)
                    if times_per_prefix[no_prefix] is None:
                        times_per_prefix[no_prefix] = new_times
                    else:
                        # Not in place: '+=' cannot store a float result in
                        # an integer array and would raise a casting error.
                        times_per_prefix[no_prefix] = (
                            times_per_prefix[no_prefix] + new_times)

            ax = fig.add_subplot(len(group_names), 1, no_group + 1)
            # NOTE(review): newer matplotlib versions rename 'labels' to
            # 'tick_labels' - confirm against the installed version.
            ax.boxplot([x * scale for x in times_per_prefix],
                       labels=prefixes, vert=False)
            ax.set_title(title % group_name)
        fig.tight_layout()
        fig.savefig(file_plot)
    finally:
        # Figures created via pyplot stay registered until they are closed.
        plt.close(fig)


def run(argv):
    """
    Entry point: compute sampling and training times and write all outputs.

    Writes "times.json" as well as the plots "total_time.pdf",
    "sampling_time.pdf" and "training_time.pdf" into the output directory
    (created if missing) and prints "Done." at the end.

    NOTE: the data set size cache is a fresh dictionary for every run, the
    option --data-set-size-cache is not evaluated here.

    :param argv: command line arguments without the program name
    """
    options = parser.parse_args(argv)
    assert options.sampling_time_per_file > 0

    size_cache = {}
    grouped_stats = group_job_stats(
        find_job_stats(
            options.directory_job_stats,
            options.directory_job_stats_recursive_depth,
            options.prefix,
            options.regex_job_stats,
            options.add_unsuccessful_trainings,
        ),
        options.group_problem_directories,
    )

    sampling_times = get_sampling_times(
        size_cache, grouped_stats,
        options.skip_test_size, options.sampling_time_per_file)
    training_times = get_training_times(grouped_stats)

    outdir = options.outdir
    if not os.path.exists(outdir):
        os.makedirs(outdir)

    patterns = [regex.pattern for regex in options.prefix]
    with open(os.path.join(outdir, "times.json"), "w") as out:
        json.dump({"prefixes": patterns,
                   "sampling_times": sampling_times,
                   "training_times": training_times}, out)

    # (file name, subplot title, plotted times, keyword arguments)
    seconds_to_hours = 1.0 / (60 * 60)
    plots = (
        ("total_time.pdf", "Total time for %s in hrs",
         (sampling_times, training_times), {"scale": seconds_to_hours}),
        ("sampling_time.pdf", "Sampling time for %s in hrs",
         (sampling_times,), {"scale": seconds_to_hours}),
        ("training_time.pdf", "Training time for %s in sec",
         (training_times,), {}),
    )
    for file_name, title, series, options_plot in plots:
        boxplot_times(os.path.join(outdir, file_name), title, patterns,
                      *series, **options_plot)

    print("Done.")



if __name__ == "__main__":
    # Show the assumptions, one per line, before doing any work. The list is
    # inserted with '%'; without the operator the whole text would be used as
    # the separator of the join.
    print("This script uses a lot of assumptions for its calculation. If you "
          "do not know if they apply to you, then do not use this script. "
          "Assumptions:\n\t%s" % "\n\t".join(ASSUMPTIONS))
    run(sys.argv[1:])
