#!/usr/bin/env python


from __future__ import print_function

import argparse
from collections import defaultdict
import gzip
import json
from matplotlib import pyplot as plt
import numpy as np
import os
import re
import shutil
import sys


SUFFIX_PDDL = ".pddl"

# Name of the per-domain JSON file holding the statistics of each data set.
FILE_STATS = "data_set_sizes.json"

# Keys read from every data set entry of the statistics file.
KEY_SAMPLES = "#samples"
KEY_PROBLEMS = "#problems"

# Number of folds, and number of consecutive problem ids forming one fold.
FOLDS_COUNTS = 10
FOLDS_SIZE = 20

parser = argparse.ArgumentParser()
# NOTE: adjacent string literals are concatenated verbatim, so every
# continued line of a help text has to carry its own separating space.
parser.add_argument("-d", "--directory", type=str, action="append", default=[],
                    help="Directory under which all directories containing "
                         "domain.pddl shall be analysed.")
parser.add_argument("-s", "--suffix", type=str, action="append", default=[],
                    help="Only files with a given suffix are read")
parser.add_argument("--missing-as-zero", action="store_true",
                    help="Counts problems where no data file was generated for "
                         "a problem as 0 samples have been generated ("
                         "otherwise only problems where all suffixes have "
                         "generated data are counted)")
parser.add_argument("--dry", action="store_true",
                    help="Shows only the directories it would analyse")


# Matches names of the form "p<digits>.<rest>": group 1 is the problem name
# ("p<digits>"), group 2 the numeric problem id, group 3 everything after the
# first dot following the id.
PATTERN_PROBLEM = re.compile(r"(p(\d+))\.(.*)")


def get_problem_and_fold(problem):
    """Split a data set name into its problem name and fold index.

    The fold index is derived from the numeric problem id: ids
    ``1..FOLDS_SIZE`` belong to fold 0, the next ``FOLDS_SIZE`` ids to fold 1,
    and so on. Valid fold indices are ``0..FOLDS_COUNTS - 1``.

    :param problem: name to analyse, e.g. ``"p17.<suffix>"``.
    :return: ``(problem name, fold index)``, or ``(None, None)`` if the name
             does not match ``PATTERN_PROBLEM`` or its id lies outside of the
             configured folds (the latter is also reported on stdout).
    """
    match = PATTERN_PROBLEM.match(problem)
    if match is None:
        return None, None
    # Floor division (instead of int(a / b), which truncates towards zero on
    # Python 3) maps the id 0 to fold -1 so that it is rejected below.
    fold = (int(match.group(2)) - 1) // FOLDS_SIZE
    # The upper bound is exclusive: fold indices are used to index lists of
    # length FOLDS_COUNTS.
    if fold < 0 or fold >= FOLDS_COUNTS:
        print("Invalid problem id for folds: %s" % problem)
        return None, None
    return match.group(1), fold


def get_suffix(data_set):
    """Return the part of *data_set* starting at its first ``"."``.

    :param data_set: data set name; it has to contain at least one dot.
    :return: the suffix including the leading dot, e.g. ``".data.gz"`` for
             ``"p1.data.gz"``.
    """
    _, dot, remainder = data_set.partition(".")
    # partition() yields an empty separator when no dot is present.
    assert dot
    return dot + remainder

def find_domains(path):
    """Collect all directories at or below *path* that contain "domain.pddl".

    The directory tree is traversed iteratively with an explicit stack.

    :param path: directory at which the search starts.
    :return: list of directory paths which directly contain a regular file
             named "domain.pddl".
    """
    found = []
    pending = [path]
    while pending:
        current = pending.pop()
        if os.path.isfile(os.path.join(current, "domain.pddl")):
            found.append(current)
        # Queue every sub directory for a later visit.
        children = (os.path.join(current, name)
                    for name in os.listdir(current))
        pending.extend(child for child in children if os.path.isdir(child))
    return found


def load_data_dir(dir_data, suffixes, missing_as_zero=False):
    """Sum up the problem and sample counts of one domain directory.

    The counts are read from the JSON file ``FILE_STATS`` inside *dir_data*,
    which maps data set names to entries providing ``KEY_PROBLEMS`` and
    ``KEY_SAMPLES``. Data sets are grouped by problem; names for which
    ``get_problem_and_fold`` yields no fold are skipped.

    :param dir_data: directory expected to contain the statistics file.
    :param suffixes: data set suffixes (as returned by ``get_suffix``) to
                     count.
    :param missing_as_zero: if false, a problem is only counted if it has a
                            data set for every requested suffix; if true, the
                            available data sets of every problem are counted
                            and missing ones contribute nothing.
    :return: {suffix: [total problem count, total sample count]},
             {suffix: [[fold problem count, fold sample count], ...]}
             Both dictionaries are empty if the statistics file does not
             exist or not a single data set was counted.
    """
    stats_file = os.path.join(dir_data, FILE_STATS)
    totals = {suffix: [0, 0] for suffix in suffixes}
    per_fold = {suffix: [[0, 0] for _ in range(FOLDS_COUNTS)]
                for suffix in suffixes}

    if not os.path.exists(stats_file):
        return {}, {}

    with open(stats_file, "r") as stream:
        stats_by_data_set = json.load(stream)

    # Group the data set names by the (problem, fold) they belong to.
    grouped = defaultdict(list)
    for name in stats_by_data_set.keys():
        problem, fold = get_problem_and_fold(name)
        if fold is not None:
            grouped[(problem, fold)].append(name)

    counted_any = False
    for (_, fold), names in grouped.items():
        # Requested suffix -> data set name of the current problem.
        by_suffix = {}
        for name in names:
            name_suffix = get_suffix(name)
            if name_suffix in suffixes:
                by_suffix[name_suffix] = name

        # Incomplete problems are ignored unless missing data counts as zero.
        if len(by_suffix) != len(suffixes) and not missing_as_zero:
            continue

        for suffix in suffixes:
            if suffix not in by_suffix:
                continue
            counted_any = True
            entry = stats_by_data_set[by_suffix[suffix]]
            problem_count = entry[KEY_PROBLEMS]
            sample_count = entry[KEY_SAMPLES]
            totals[suffix][0] += problem_count
            totals[suffix][1] += sample_count
            per_fold[suffix][fold][0] += problem_count
            per_fold[suffix][fold][1] += sample_count

    if not counted_any:
        return {}, {}
    return totals, per_fold


def make_boxplot(file_out, data, title, folds):
    """Draw horizontal box plots of *data* and save the figure to *file_out*.

    Within each subplot the boxes are ordered by the median of their values,
    every box is labelled with its suffix plus the minimum and maximum value,
    and the median value is written next to the median line.

    :param file_out: path the figure is saved to.
    :param data: if *folds* is false ``{suffix: [value, ...]}``, otherwise
                 ``{suffix: [[value, ...], ...]}`` with one inner list per
                 fold (indexed ``0..FOLDS_COUNTS - 1``). Every list of values
                 has to be non-empty, because its min and max are computed.
    :param title: title of the whole figure.
    :param folds: if true, one subplot per fold is drawn, otherwise a single
                  subplot.
    """
    subplot_count = FOLDS_COUNTS if folds else 1
    fig = plt.figure(figsize=(5 * subplot_count, 5))
    try:
        for i in range(subplot_count):
            ax = fig.add_subplot(1, subplot_count, i + 1)
            if folds:
                curr_data = {suffix: per_fold[i]
                             for suffix, per_fold in data.items()}
            else:
                curr_data = data
            curr_data = sorted(curr_data.items(),
                               key=lambda kv: np.median(kv[1]))
            values = [x[1] for x in curr_data]
            labels = ["%s\n(min %i, max %i)" % (x[0], min(x[1]), max(x[1]))
                      for x in curr_data]

            bp_dict = ax.boxplot(values, vert=False, labels=labels)
            for line in bp_dict['medians']:
                # get position data for median line
                x, y = line.get_xydata()[1]  # top of median line
                # overlay median value
                ax.text(x, y, '%.1f' % x, horizontalalignment='center')

        fig.tight_layout()
        fig.suptitle(title)
        fig.savefig(file_out)
    finally:
        # pyplot keeps every figure it created registered until it is closed
        # explicitly; release it once it was written (or plotting failed).
        plt.close(fig)

def parse_argv(argv):
    """Parse the command line arguments and validate the given directories.

    :param argv: list of argument strings (without the program name).
    :return: the namespace produced by the module level ``parser``.

    If a value given via ``--directory`` is not an existing directory, the
    error is reported through ``parser.error``, which prints a usage message
    and exits the program.
    """
    options = parser.parse_args(argv)
    for path_dir in options.directory:
        # Explicit check instead of an assert: asserts are stripped when
        # Python runs with -O and give the user no hint about what is wrong.
        if not os.path.isdir(path_dir):
            parser.error("Not a directory: %s" % path_dir)
    return options


def run(argv):
    """Analyse all domains below the given directories and plot the counts.

    For every domain directory found via ``find_domains`` the problem and
    sample counts per suffix are loaded with ``load_data_dir``. The counts of
    all domains are then summarised in four box plot files (total and per
    fold, for problems and for samples) written to the current working
    directory. With ``--dry`` only the detected domain directories are
    printed and the program exits.

    :param argv: list of command line argument strings (without the program
                 name).
    """
    options = parse_argv(argv)

    domains = set()
    for d in options.directory:
        domains.update(find_domains(d))

    if options.dry:
        print(domains)
        sys.exit()

    suffix_2_problem_count = {suffix: [] for suffix in options.suffix}
    suffix_2_sample_count = {suffix: [] for suffix in options.suffix}
    suffix_fold_2_problem_count = {
        suffix: [[] for _ in range(FOLDS_COUNTS)] for suffix in options.suffix}
    suffix_fold_2_sample_count = {
        suffix: [[] for _ in range(FOLDS_COUNTS)] for suffix in options.suffix}

    for dir_domain in sorted(domains):
        print("Domain: %s" % dir_domain)

        new_suffix_2_counts, new_suffix_folds_2_counts = load_data_dir(
            dir_domain, options.suffix, options.missing_as_zero)

        for suffix, (prob_count, sample_count) in new_suffix_2_counts.items():
            suffix_2_problem_count[suffix].append(prob_count)
            suffix_2_sample_count[suffix].append(sample_count)

        for suffix, fold_counts in new_suffix_folds_2_counts.items():
            for fold in range(FOLDS_COUNTS):
                suffix_fold_2_problem_count[suffix][fold].append(
                    fold_counts[fold][0])
                suffix_fold_2_sample_count[suffix][fold].append(
                    fold_counts[fold][1])

    # A domain either contributes one entry to the lists of all suffixes or
    # to none of them, so a single non-empty list means all are non-empty.
    # Without any data there is nothing to plot (the box plot labels need the
    # min/max of every list, which fails for empty lists).
    if any(suffix_2_problem_count.values()):
        make_boxplot("_problem_counts.pdf", suffix_2_problem_count,
                     "Total problem counts", False)
        make_boxplot("_problem_fold_counts.pdf", suffix_fold_2_problem_count,
                     "Fold problem counts", True)

        make_boxplot("_sample_counts.pdf", suffix_2_sample_count,
                     "Total sample counts", False)
        make_boxplot("_sample_fold_counts.pdf", suffix_fold_2_sample_count,
                     "Fold sample counts", True)
    else:
        print("No data found, skipping plots.")
    print("Done.")


if __name__ == "__main__":
    run(sys.argv[1:])
