#!/usr/bin/env python

import argparse
import collections
import json
import numpy as np
import matplotlib.pyplot as plt
import os
import re
import sys

PATH_DATA = os.path.join("experiments", "data")


def get_setting_prune_solver_sampletype(algorithm, **kwargs):
    """Build a "<prune> <solver> <sample type>" label from an algorithm name.

    The label is derived purely from marker substrings of *algorithm*:

    * ``_MK`` anywhere in the name -> the algorithm is skipped (``None``).
    * ``_pruneOff`` -> "Prune Off", otherwise "Prune On".
    * ``_opt_`` / ``_sat_`` -> "optimal" / "satisficing".
    * ``_init_`` / ``_inter_`` / ``_plan_`` -> "initial states" /
      "intermediate states" / "full plan".

    Additional keyword arguments are accepted and ignored so that all setting
    extractors can be called the same way.

    Raises:
        AssertionError: if no solver marker or no sample type marker is found.
    """
    if "_MK" in algorithm:
        return None

    pruning = "Prune Off" if "_pruneOff" in algorithm else "Prune On"

    if "_opt_" in algorithm:
        solver = "optimal"
    elif "_sat_" in algorithm:
        solver = "satisficing"
    else:
        assert False

    if "_init_" in algorithm:
        samples = "initial states"
    elif "_inter_" in algorithm:
        samples = "intermediate states"
    elif "_plan_" in algorithm:
        samples = "full plan"
    else:
        assert False

    return " ".join((pruning, solver, samples))


def get_setting_generator_ugenerator(algorithm, baseline_algorithm, **kwargs):
    """Classify an algorithm as generator, uniform generator or baseline.

    Returns "generator" if the name contains ``_gen_``, "uniform generator"
    if it contains ``_ugen_``, and the baseline name itself if *algorithm*
    equals *baseline_algorithm*. The markers are checked in that order.
    Additional keyword arguments are accepted and ignored.

    Raises:
        AssertionError: if none of the three cases applies.
    """
    if "_gen_" in algorithm:
        return "generator"
    if "_ugen_" in algorithm:
        return "uniform generator"
    if algorithm == baseline_algorithm:
        return baseline_algorithm
    assert False


def get_setting_identity(algorithm, **kwargs):
    """Use the algorithm name itself as setting label.

    Additional keyword arguments are accepted and ignored so that this
    extractor can be called like the other setting extractors.
    """
    return algorithm


# Registry of the available setting extractors: command line name -> function
# that maps an algorithm name to its setting label (or None to skip it).
SETTINGS = dict(
    identity=get_setting_identity,
    prune_solver_sampletype=get_setting_prune_solver_sampletype,
    generator_ugenerator=get_setting_generator_ugenerator,
)


# Matches a "K<amount>" marker in an algorithm name, where <amount> is either
# "all" or a number that may contain a one-digit exponent part (e.g. "1e3").
PATTERN_DATA_AMOUNT = re.compile(r".*K(all|\d+(e\d)?\d*)")


def get_data_amount(algorithm, default=None):
    """Extract the data amount encoded in an algorithm name.

    Returns ``float("inf")`` for "Kall", the number as ``int`` for a numeric
    marker, and *default* if the name carries no marker at all.

    Raises:
        AssertionError: if there is no marker and *default* is ``None``.
    """
    found = PATTERN_DATA_AMOUNT.match(algorithm)
    if found is None:
        assert default is not None
        return default

    token = found.group(1)
    if token == "all":
        return float("inf")
    # Go through float first so that exponent notation such as "1e3" parses.
    return int(float(token))


""" Method to load the required data from properties files """
def load_data_template(
        path_dir, get_setting,
        factory_initialize,
        data_aggregator,
        regexes_eval_dir=(),
        regexes_algorithm=(),
        regexes_domain=(),
        baseline_algorithm=None,
        default_data_amount=None,
        previous_data=None):
    """Load and aggregate the runs of all matching properties files.

    Every entry ``item`` of *path_dir* is processed if its name ends with
    "-eval", matches all *regexes_eval_dir* (``re.match``, i.e. anchored at
    the start) and contains a file called "properties". That file is read as
    JSON; each of its values is treated as one run with at least the keys
    "domain" and "algorithm".

    Args:
        path_dir: directory containing the evaluation directories.
        get_setting: callable ``(algorithm, baseline_algorithm=...)`` that
            returns the setting label of an algorithm, or ``None`` to skip
            the run.
        factory_initialize: zero-argument callable creating an empty
            container for a new (domain, setting, data amount) cell.
        data_aggregator: callable ``(container, run)`` returning the updated
            container.
        regexes_eval_dir: compiled patterns the directory name has to match.
        regexes_algorithm: compiled patterns the algorithm name has to match.
        regexes_domain: compiled patterns the domain name has to match.
        baseline_algorithm: forwarded to *get_setting*.
        default_data_amount: data amount used for algorithm names without
            a data amount marker.
        previous_data: optional result of an earlier call; it is extended
            in place and returned.

    Returns:
        ``{domain: {setting: {data amount: container}}}``.
    """
    # {Domain: {Algorithm Type : {Data size : CUSTOM DATA}}}
    data = {} if previous_data is None else previous_data

    # Sorted so that the aggregation order does not depend on the file system.
    for item in sorted(os.listdir(path_dir)):
        if not item.endswith("-eval"):
            continue
        if not all(r.match(item) for r in regexes_eval_dir):
            continue
        path_properties = os.path.join(path_dir, item, "properties")
        if not os.path.isfile(path_properties):
            continue
        print(item)

        with open(path_properties, "r") as f:
            prop = json.load(f)

        for run in prop.values():
            domain = run["domain"]
            algorithm = run["algorithm"]
            if not all(r.match(algorithm) for r in regexes_algorithm):
                continue
            if not all(r.match(domain) for r in regexes_domain):
                continue

            algorithm_setting = get_setting(
                algorithm, baseline_algorithm=baseline_algorithm)
            if algorithm_setting is None:
                continue
            # Only determine the data amount for runs that are kept: the
            # lookup asserts when the name has no marker and no default is
            # given, which must not abort on algorithms that are skipped.
            data_amount = get_data_amount(algorithm, default_data_amount)

            per_size = data.setdefault(domain, {}).setdefault(
                algorithm_setting, {})
            if data_amount not in per_size:
                per_size[data_amount] = factory_initialize()
            per_size[data_amount] = data_aggregator(per_size[data_amount], run)
    return data


def load_data_coverage(path_dir, get_setting,
                       regexes_eval_dir=(),
                       regexes_algorithm=(),
                       regexes_domain=()):
    """Load the set of solved problems per domain, setting and data amount.

    A run contributes its "problem" entry if its "coverage" entry is truthy.
    Algorithm names without a data amount marker are stored under the data
    amount ``float("inf")``.

    Args:
        path_dir: directory containing the evaluation directories.
        get_setting: callable mapping an algorithm name to its setting label.
        regexes_eval_dir: compiled patterns the directory name has to match.
        regexes_algorithm: compiled patterns the algorithm name has to match.
        regexes_domain: compiled patterns the domain name has to match.

    Returns:
        ``{domain: {setting: {data amount: set of problems}}}``.
    """
    def data_aggregator(container, run):
        if run["coverage"]:
            container.add(run["problem"])
        return container

    return load_data_template(
        path_dir=path_dir,
        get_setting=get_setting,
        factory_initialize=set,
        data_aggregator=data_aggregator,
        default_data_amount=float("inf"),
        regexes_eval_dir=regexes_eval_dir,
        regexes_algorithm=regexes_algorithm,
        regexes_domain=regexes_domain)


def load_data_expansions(path_dir, get_setting, regexes_eval_dir=(),
                         default_data_amount=None, baseline_algorithm=None):
    """Load the expansions of solved problems per domain, setting and amount.

    A run contributes ``problem -> expansions`` if its "coverage" entry is
    truthy; a later run on the same problem overwrites an earlier one.

    Args:
        path_dir: directory containing the evaluation directories.
        get_setting: callable mapping an algorithm name to its setting label.
        regexes_eval_dir: compiled patterns the directory name has to match.
        default_data_amount: data amount used for algorithm names without
            a data amount marker.
        baseline_algorithm: forwarded to *get_setting*.

    Returns:
        ``{domain: {setting: {data amount: {problem: expansions}}}}``.
    """
    def data_aggregator(container, run):
        if run["coverage"]:
            container[run["problem"]] = run["expansions"]
        return container

    return load_data_template(
        path_dir=path_dir,
        get_setting=get_setting,
        factory_initialize=dict,
        data_aggregator=data_aggregator,
        regexes_eval_dir=regexes_eval_dir,
        baseline_algorithm=baseline_algorithm,
        default_data_amount=default_data_amount)


def load_data_time(path_dir, get_setting,
                   regexes_eval_dir=(), regexes_algorithm=(), regexes_domain=(),
                   default_data_amount=None):
    """Load the total times of solved runs per domain, setting and amount.

    A run contributes its "total_time" entry if its "coverage" entry is
    truthy.

    Args:
        path_dir: directory containing the evaluation directories.
        get_setting: callable mapping an algorithm name to its setting label.
        regexes_eval_dir: compiled patterns the directory name has to match.
        regexes_algorithm: compiled patterns the algorithm name has to match.
        regexes_domain: compiled patterns the domain name has to match.
        default_data_amount: data amount used for algorithm names without
            a data amount marker.

    Returns:
        ``{domain: {setting: {data amount: list of total times}}}``.
    """
    def data_aggregator(container, run):
        if run["coverage"]:
            container.append(run["total_time"])
        return container

    return load_data_template(
        path_dir=path_dir,
        get_setting=get_setting,
        factory_initialize=list,
        data_aggregator=data_aggregator,
        regexes_eval_dir=regexes_eval_dir,
        regexes_algorithm=regexes_algorithm,
        regexes_domain=regexes_domain,
        default_data_amount=default_data_amount)


""" Method to make the coverage csv table """
def make_csv(data, path_csv, path_diff, diff_size=None):
    """Write the coverage table and the unique-coverage table as csv files.

    Args:
        data: ``{domain: {algorithm type: {size: collection of problems}}}``.
        path_csv: output path of the table with the number of solved problems
            per domain, algorithm type and size.
        path_diff: output path of the table with entries "unique(total)",
            where "unique" counts the problems that no other algorithm type
            solved for the same size.
        diff_size: if given, *path_diff* receives a compact domain x algorithm
            table for this size only instead of the table over all sizes.

    Raises:
        AssertionError: if *diff_size* is given but is not a size in *data*.
    """
    def blank_table(rows, columns):
        table = np.ndarray(shape=(rows, columns), dtype=object)
        table[:, :] = ""
        return table

    sizes = set()
    algo_names = set()
    solved_per_size = collections.defaultdict(list)
    row_count = 0
    for domain, per_algo in data.items():
        # Per domain: title row, size header row, one row per algorithm type
        # and two separating blank rows.
        row_count += 4 + len(per_algo)
        algo_names.update(per_algo)
        for per_size in per_algo.values():
            sizes.update(per_size)
            for size, solved in per_size.items():
                solved_per_size[size].extend((domain, p) for p in solved)

    assert diff_size is None or diff_size in sizes
    algo_index = {name: pos for pos, name in enumerate(sorted(algo_names))}

    # For every size: the (domain, problem) pairs solved by more than one
    # algorithm type.
    shared_per_size = {}
    for size, entries in solved_per_size.items():
        seen, repeated = set(), set()
        for entry in entries:
            if entry in seen:
                repeated.add(entry)
            else:
                seen.add(entry)
        shared_per_size[size] = repeated

    column_count = len(sizes) + 1
    # Largest size first; column 0 is reserved for the row labels.
    size_column = {
        size: pos + 1
        for pos, size in enumerate(sorted(sizes, reverse=True))
    }

    total_table = blank_table(row_count, column_count)
    unique_table = blank_table(row_count, column_count)

    single_size_table = blank_table(len(data) + 1, len(algo_index) + 1)
    single_size_table[0, 0] = "Domains"
    for name, pos in algo_index.items():
        single_size_table[0, pos + 1] = name

    row = 0
    for domain_no, (domain, per_algo) in enumerate(sorted(data.items())):
        for table in (total_table, unique_table):
            table[row, 0] = domain
        single_size_table[domain_no + 1, 0] = domain
        row += 1

        for size, column in size_column.items():
            for table in (total_table, unique_table):
                table[row, column] = size
        row += 1

        for algo, per_size in sorted(per_algo.items()):
            for table in (total_table, unique_table):
                table[row, 0] = algo
            for size, solved in sorted(per_size.items()):
                column = size_column[size]
                only_here = set((domain, p) for p in solved)
                only_here -= shared_per_size[size]
                cell = "%s(%s)" % (len(only_here), len(solved))

                total_table[row, column] = len(solved)
                unique_table[row, column] = cell
                if diff_size is not None and diff_size == size:
                    single_size_table[
                        domain_no + 1, algo_index[algo] + 1] = cell
            row += 1
        row += 2

    np.savetxt(path_csv, total_table, fmt="%s", delimiter=";")
    diff_table = unique_table if diff_size is None else single_size_table
    np.savetxt(path_diff, diff_table, fmt="%s", delimiter=";")




""" Make expansion comparison table """
def make_expansion_specific_csv(data, baseline, path_csv):
    """Write a csv table comparing expansions against a baseline.

    Only domains that contain the *baseline* setting are listed. For each
    such domain the table holds the number of problems with baseline
    expansions and, for every other (algorithm, size) pair, three values:
    the number of problems shared with the baseline that needed at most as
    many expansions ("<="), the number that did not (">"), and the total
    number of problems with expansions for that pair ("TOTAL").

    Args:
        data: ``{domain: {algorithm: {size: {problem: expansions}}}}``.
        baseline: name of the baseline algorithm; its expansions are read
            from the size key "all".
        path_csv: output path of the csv file.

    Raises:
        AssertionError: if a listed domain has no "all" entry for the
            baseline.
    """
    stats_per_entry = 3

    domains = []
    sizes_by_algo = collections.defaultdict(set)
    for domain in sorted(data):
        per_algo = data[domain]
        if baseline not in per_algo:
            continue
        domains.append(domain)
        for algo, per_size in per_algo.items():
            sizes_by_algo[algo].update(per_size)

    # The column count is computed over all algorithms including the
    # baseline, although the baseline only occupies column 1.
    entry_count = sum(len(sizes) for sizes in sizes_by_algo.values())
    table = np.ndarray(
        shape=(len(domains) + 2, entry_count * stats_per_entry + 2),
        dtype=object)
    table[:, :] = ""
    table[0, 0] = "Domain"
    table[0, 1] = baseline

    entries = [
        (algo, size)
        for algo in sorted(sizes_by_algo) if algo != baseline
        for size in sorted(sizes_by_algo[algo])
    ]
    for pos, (algo, size) in enumerate(entries):
        first = stats_per_entry * pos + 2
        table[0, first] = "%s-%s" % (algo, size)
        table[1, first] = "<="
        table[1, first + 1] = ">"
        table[1, first + 2] = "TOTAL"

    for row, domain in enumerate(domains, 2):
        per_algo = data[domain]
        table[row, 0] = domain
        assert baseline in per_algo
        assert "all" in per_algo[baseline]
        reference = per_algo[baseline]["all"]
        table[row, 1] = len(reference)

        for pos, (algo, size) in enumerate(entries):
            if algo not in per_algo or size not in per_algo[algo]:
                continue
            candidate = per_algo[algo][size]

            shared = [
                (count, reference[problem])
                for problem, count in candidate.items()
                if problem in reference
            ]
            not_worse = sum(1 for count, base in shared if count <= base)
            worse = len(shared) - not_worse

            first = stats_per_entry * pos + 2
            table[row, first] = not_worse
            table[row, first + 1] = worse
            table[row, first + 2] = len(candidate)

    np.savetxt(path_csv, table, fmt="%s", delimiter=";")


""" Plot coverage time relation"""
def make_coverage_time_relation(data, path_plot):
    """Plot the number of solved problems over time and write a csv summary.

    The figure contains one subplot over all domains followed by one subplot
    per domain. Each curve belongs to one "<algorithm>-<data size>" label and
    shows how many problems were solved up to a given time. In addition, a
    csv table with the coverage at fixed time points is written next to the
    plot: a trailing ".pdf" of *path_plot* is replaced by ".csv", otherwise
    ".csv" is appended.

    Args:
        data: ``{domain: {algorithm: {data size: iterable of times}}}``.
        path_plot: output path of the figure (passed to ``Figure.savefig``).
    """
    algo2times = collections.defaultdict(list)
    fig = plt.figure(figsize=(10, (len(data) + 1) * 5))

    def plot_ax(ax, domain, all_times, all_labels,
                timepoints=(10, 60, 300, 900, 1800)):
        """Draw one coverage subplot and return its coverage table."""
        # Row 0: title, row 1: time points, then one row per label; the last
        # row stays empty and separates the tables in the csv file.
        ary = np.ndarray(shape=(len(all_labels) + 3, len(timepoints) + 1),
                         dtype=object)
        ary[:, :] = ""
        ary[0, 0] = domain
        ary[1, 0] = "timepoints in s"
        for no, label in enumerate(all_labels):
            ary[no + 2, 0] = label

        for no, timepoint in enumerate(timepoints):
            ary[1, no + 1] = timepoint

        for times, label in zip(all_times, all_labels):
            # times is sorted, so the i-th time is the moment the i-th
            # problem got solved.
            ax.plot(times, list(range(1, len(times) + 1)),
                    label=label + " (%i)" % len(times))
        ax.legend()
        ax.set_title("Coverage for %s" % domain)
        ax.set_xlabel("time in s")
        ax.set_ylabel("#solved problems")

        previous_coverage = collections.defaultdict(int)
        for no_tp, timepoint in enumerate(timepoints):
            ax.axvline(timepoint, color="black", alpha=0.3)
            for no_t, times in enumerate(all_times):
                cov = sum(1 for t in times if t <= timepoint)
                # Annotate only when the coverage grew since the last
                # annotated time point.
                if cov > previous_coverage[no_t]:
                    ax.text(timepoint, cov, "%i" % cov)
                    previous_coverage[no_t] = cov
                ary[no_t + 2, no_tp + 1] = cov
        return ary

    try:
        coverage_arys = []
        for no, (domain, algo_datasize_times) in enumerate(sorted(data.items())):
            # Subplot 1 is reserved for the summary over all domains.
            ax = fig.add_subplot(len(data) + 1, 1, no + 2)

            all_times = []
            all_labels = []
            for algo, datasize_times in sorted(algo_datasize_times.items()):
                for datasize, times in sorted(datasize_times.items()):
                    times = sorted(times)
                    label = "%s-%s" % (algo, datasize)

                    all_times.append(times)
                    all_labels.append(label)
                    algo2times[label].extend(times)
            coverage_arys.append(plot_ax(ax, domain, all_times, all_labels))

        ax = fig.add_subplot(len(data) + 1, 1, 1)
        all_times = [sorted(times) for _, times in sorted(algo2times.items())]
        coverage_arys.insert(
            0, plot_ax(ax, "all domains", all_times, sorted(algo2times.keys())))

        fig.tight_layout()
        fig.savefig(path_plot)
    finally:
        # pyplot keeps every figure alive until it is closed explicitly.
        plt.close(fig)

    if path_plot.endswith(".pdf"):
        path_csv = "%s.csv" % path_plot[:-4]
    else:
        path_csv = "%s.csv" % path_plot
    np.savetxt(path_csv, np.concatenate(coverage_arys, axis=0),
               fmt="%s", delimiter=";")


# Command line interface of the coverage table pipeline (see run_coverage).
# The text is passed as description; the first positional parameter of
# ArgumentParser is the program name shown in the usage line.
parser_coverage_table = argparse.ArgumentParser(
    description="Creates a coverage csv table")

parser_coverage_table.add_argument(
    "setting", choices=SETTINGS.keys(), action="store",
    help="Select method to extract setting (e.g. intermediate samples, "
         "plan states) from algorithm name in properties file")
parser_coverage_table.add_argument(
    "--regex-data", type=str, action="append", default=[],
    help="Regex to filter which property file directories to process")

parser_coverage_table.add_argument(
    "--regex-algorithm", type=str, action="append", default=[],
    help="Regex to filter which algorithms to consider")

parser_coverage_table.add_argument(
    "--regex-domain", type=str, action="append", default=[],
    help="Regex to filter which domains to consider")

parser_coverage_table.add_argument(
    "--output", type=str, action="store", default="coverage.csv",
    help="Path to the location for storing the output csv file.")

parser_coverage_table.add_argument(
    "--diff-output", type=str, action="store", default="coverage_diff.csv",
    help="Path to the location for storing the output csv file showing the "
         "problems solved by only one algorithm")

parser_coverage_table.add_argument(
    "--diff-size", type=float, action="store", default=None,
    help="if set, the difference coverage table is not written for all "
         "sizes, but only for the specified size")


# Command line interface of the expansion comparison pipeline (see
# run_expansions). The text is passed as description; the first positional
# parameter of ArgumentParser is the program name shown in the usage line.
parser_expansions = argparse.ArgumentParser(
    description="Creates coverage comparison of algorithm when using the "
                "same number of expansions as a baseline")

parser_expansions.add_argument(
    "setting", choices=SETTINGS.keys(), action="store",
    help="Select method to extract setting (e.g. intermediate samples, "
         "plan states) from algorithm name in properties file")
parser_expansions.add_argument(
    "baseline", type=str, action="store",
    help="Name of the baseline algorithm")
parser_expansions.add_argument(
    "--regex", type=str, action="append", default=[],
    help="Regex to filter which property file directories to process")
parser_expansions.add_argument(
    "--output", type=str, action="store", default="output.csv",
    help="Path to the location for storing the output csv file.")


# Command line interface of the coverage-over-time pipeline (see run_time).
# The text is passed as description; the first positional parameter of
# ArgumentParser is the program name shown in the usage line.
parser_coverage_time = argparse.ArgumentParser(
    description="Creates a plot comparing the coverage of different "
                "algorithms depending on the time run.")

parser_coverage_time.add_argument(
    "--regex-data", type=str, action="append", default=[],
    help="Regex to filter which property file directories to process")

parser_coverage_time.add_argument(
    "--regex-algorithm", type=str, action="append", default=[],
    help="Regex to filter which algorithms to consider")

parser_coverage_time.add_argument(
    "--regex-domain", type=str, action="append", default=[],
    help="Regex to filter which domains to consider")

parser_coverage_time.add_argument(
    "--regex-baseline", type=str, action="append", default=[],
    help="Regex to detect which algorithms are marked as baseline")

# The output path is handed to the plotting routine as the figure file, so
# the default has to carry an image/document extension rather than ".csv".
parser_coverage_time.add_argument(
    "--output", type=str, action="store", default="output.pdf",
    help="Path to the location for storing the output plot. The coverage "
         "table is stored next to it as csv file.")


def run_coverage(argv):
    """Run the coverage table pipeline for the given argument list.

    Parses *argv* with ``parser_coverage_table``, resolves the chosen setting
    extractor, compiles all regex filters, loads the coverage data from
    ``PATH_DATA`` and writes both csv tables.
    """
    options = parser_coverage_table.parse_args(argv)
    options.setting = SETTINGS.get(options.setting)

    # Replace the pattern strings by compiled patterns inside the same lists.
    for patterns in (options.regex_data,
                     options.regex_algorithm,
                     options.regex_domain):
        patterns[:] = [re.compile(pattern) for pattern in patterns]

    coverage = load_data_coverage(
        path_dir=PATH_DATA,
        get_setting=options.setting,
        regexes_eval_dir=options.regex_data,
        regexes_algorithm=options.regex_algorithm,
        regexes_domain=options.regex_domain,
    )
    make_csv(coverage, options.output, options.diff_output, options.diff_size)


def run_expansions(argv):
    """Run the expansion comparison pipeline for the given argument list.

    Parses *argv* with ``parser_expansions``, resolves the chosen setting
    extractor, compiles the directory filters, loads the expansion data from
    ``PATH_DATA`` (algorithms without data amount marker are stored under
    "all") and writes the comparison csv table.
    """
    options = parser_expansions.parse_args(argv)
    options.setting = SETTINGS.get(options.setting)

    # Replace the pattern strings by compiled patterns inside the same list.
    options.regex[:] = [re.compile(pattern) for pattern in options.regex]

    expansions = load_data_expansions(
        path_dir=PATH_DATA,
        get_setting=options.setting,
        regexes_eval_dir=options.regex,
        default_data_amount="all",
        baseline_algorithm=options.baseline)
    make_expansion_specific_csv(expansions, options.baseline, options.output)


def run_time(argv):
    """Run the coverage-over-time pipeline for the given argument list.

    Parses *argv* with ``parser_coverage_time``, compiles all regex filters,
    loads the run times from ``PATH_DATA`` using the algorithm name itself as
    setting (algorithms without data amount marker are stored under "NA")
    and creates the plot.
    """
    options = parser_coverage_time.parse_args(argv)

    # Replace the pattern strings by compiled patterns inside the same lists.
    # The baseline patterns are compiled as well, but are not passed on to
    # the data loading below.
    for patterns in (options.regex_data,
                     options.regex_algorithm,
                     options.regex_domain,
                     options.regex_baseline):
        patterns[:] = [re.compile(pattern) for pattern in patterns]

    times = load_data_time(
        path_dir=PATH_DATA,
        get_setting=get_setting_identity,
        regexes_eval_dir=options.regex_data,
        regexes_algorithm=options.regex_algorithm,
        regexes_domain=options.regex_domain,
        default_data_amount="NA")
    make_coverage_time_relation(times, options.output)


# Pipeline keyword on the command line -> (argument parser, entry point that
# receives the arguments following the keyword).
PARSERS = {
    "+coverage": (parser_coverage_table, run_coverage),
    "+expansions": (parser_expansions, run_expansions),
    "+time": (parser_coverage_time, run_time),
}

if __name__ == "__main__":
    # The command line is a sequence of pipeline keywords (the keys of
    # PARSERS), each followed by the arguments of that pipeline.
    if "-h" in sys.argv or "--help" in sys.argv or len(sys.argv) == 1:
        print("Multiple executions can be done in one call. The arguments will "
              "be split at the words %s. And the arguments behind it send to "
              "the associated pipeline." % ", ".join(PARSERS.keys()))
        for key, (parser, _) in PARSERS.items():
            print("Use: %s" % key)
            parser.print_help()
        print("Available SETTING NAME EXTRACTORs: %s"
              % ", ".join(SETTINGS.keys()))
        # Help was requested (or nothing was given): do not run any pipeline.
        sys.exit(0)

    runs = []
    run = []
    method = None
    for arg in sys.argv[1:]:
        if arg in PARSERS:
            if method is not None:
                # Close the previous pipeline, even if it got no arguments;
                # its parser decides whether that is valid.
                runs.append((method, run))
            elif run:
                sys.exit("Arguments not matched to a pipeline: %s"
                         % " ".join(run))
            method = PARSERS[arg][1]
            run = []
        else:
            run.append(arg)
    if method is not None:
        runs.append((method, run))
    elif run:
        sys.exit("Arguments not matched to a pipeline: %s" % " ".join(run))

    for run_method, run_args in runs:
        run_method(run_args)
