import argparse
import json
import os
import re
import sys


import matplotlib.pyplot as plt
import numpy as np

import common
import config


# Command line interface. `run()` consumes the namespace produced by this
# parser.
parser = argparse.ArgumentParser()
# Path to the JSON file with the data set sizes (loaded in `run()`).
parser.add_argument("input", type=common.type_is_file)
# Only data entries whose key ends with this suffix are counted.
parser.add_argument("--suffix", default=".generator.plan.sat.data.gz")
# The pattern is compiled by argparse and applied with `re.match`, i.e. it is
# anchored at the beginning of the data entry key.
parser.add_argument("--filter", type=re.compile,
                    help="filter which data sets to consider")
parser.add_argument(
    "--per-hour-divisor", type=float, default=None,
    help="Use this value to scale the data generation to samples per hour (e.g."
         " if you sampled 2 hours per data batch and have 3 batches, divide by "
         "6).")
parser.add_argument("--paper-mode", action="store_true",
                    help="modify plots for paper")
# Must contain a "{}" placeholder; it is replaced by the domain name.
parser.add_argument("--output", default="plots/data_set_sizes_{}.pdf")

parser.add_argument(
    "--properties", type=common.type_is_file, nargs="+", default=None,
    help="Load Lab properties from all given paths. Used to overlay the "
         "coverage of an algorithm on the tasks. Only a single algorithm may "
         "be selected from those files.")
parser.add_argument(
    "--algorithm-filter", "-af", type=re.compile, action="append", default=[],
    help="Keep only algorithms that satisfy ALL algorithm filters. Only a "
         "single algorithm may remain.")

# Scatter style per task set, indexed by the position of the set among the
# plotted task sets.
COLORS = ["blue", "orange"]
MARKERS = ["s", "^"]


def _assign_tasks_to_sets(domain, task_sets):
    """Return `(set_names, task2set)` for the tasks of `domain`.

    `set_names` lists every task set except "all"; `task2set` maps a task name
    (file name without directory and extension) to an index into `set_names`.
    """
    set_names = []
    task2set = {}
    for set_name, set_tasks in task_sets.items():
        if set_name == "all":
            continue
        # Index into `set_names`, not into `task_sets`: the skipped "all"
        # entry must not shift the indices of the sets that follow it.
        set_index = len(set_names)
        set_names.append(set_name)
        for task_path in set_tasks:
            if os.path.dirname(task_path) != domain:
                continue
            task = os.path.splitext(os.path.basename(task_path))[0]
            task2set[task] = set_index
    return set_names, task2set


def _count_samples(task_sizes, suffix, key_filter, per_hour_divisor):
    """Sum "#problems" over the selected data entries of a single task."""
    total = sum(
        entry["#problems"] for key, entry in task_sizes.items()
        if key.endswith(suffix)
        and (key_filter is None or key_filter.match(key)))
    if per_hour_divisor is not None:
        total /= per_hour_divisor
    return total


def _task_coverage(domain, task, set_name, coverages):
    """Return the coverage fraction of the single algorithm on `task`."""
    key = task + ".pddl"
    if key not in coverages:
        # Sanity check kept from the original script: missing results are
        # only expected for these known cases.
        assert (set_name == "hard" or domain == "storage"
                or task == "p20"), f"{domain} {task}"
        return 0
    algos_results = coverages[key]
    assert len(algos_results) == 1, f"{task} {len(algos_results)}"
    result = next(iter(algos_results.values()))
    # NOTE(review): result[0] is presumably the collection of solved problems
    # of the task - confirm against common.load_properties.
    return len(result[0]) / config.TASKS_PER_STATE_SPACE


def make_figure(domain, domain_sizes, suffix, filter, task_sets,per_hour_divisor, coverages, path_out, paper_mode):
    """Plot the number of samples per task of `domain` and save the figure.

    Arguments:
        domain: domain name; selects the tasks of `task_sets` whose directory
            equals it and is used as plot title.
        domain_sizes: mapping task name -> mapping data entry key -> dict
            with a "#problems" count. Task names are passed through
            `config.rename_task` before use.
        suffix: only data entry keys ending with this string are counted.
        filter: compiled regular expression applied with `match` to the data
            entry keys, or None to accept all keys.
        task_sets: mapping set name -> iterable of task paths of the form
            "<domain>/<task>.<ext>". The set named "all" is ignored.
        per_hour_divisor: if not None, every sample count is divided by it.
        coverages: None, or mapping "<task>.pddl" -> mapping with exactly one
            entry (the selected algorithm). If given, the coverage is drawn
            as gray bars on a second y axis.
        path_out: path of the image file to write; a missing parent
            directory is created.
        paper_mode: if true, title, y label and x tick labels are omitted.
    """
    domain_sizes = {config.rename_task(k): v for k, v in domain_sizes.items()}
    sets, task2set = _assign_tasks_to_sets(domain, task_sets)

    fig = plt.figure()
    ax = fig.add_subplot(111)
    if not paper_mode:
        ax.set_title(domain)
        ax.set_ylabel("# Samples" + ("" if per_hour_divisor is None else "/hour"))

    task_order = common.natural_sort(task2set.keys())
    size = []
    set_assignment = []
    cov_values = None if coverages is None else []
    for task in task_order:
        set_index = task2set[task]
        set_assignment.append(set_index)

        if task in domain_sizes:
            size.append(_count_samples(
                domain_sizes[task], suffix, filter, per_hour_divisor))
        else:
            size.append(0)

        if cov_values is not None:
            cov_values.append(
                _task_coverage(domain, task, sets[set_index], coverages))

    if cov_values is not None:
        print(">>", cov_values)
        assert len(cov_values) == len(size)

        ax_cov = ax.twinx()
        ax_cov.bar(np.arange(len(cov_values)), cov_values, width=0.8, color="gray", alpha=0.5)
    ax.plot(size, color="black", linestyle="--", linewidth=1)

    for no_set, set_name in enumerate(sets):
        indices = [n for n, assigned in enumerate(set_assignment)
                   if assigned == no_set]
        if not indices:
            # A set without tasks in this domain has nothing to draw, and
            # min()/max() below would raise on the empty list.
            print(f"\t{set_name}: no tasks")
            continue
        values = [size[n] for n in indices]
        # Cycle through the styles so that more task sets than configured
        # styles do not raise an IndexError.
        ax.scatter(indices, values, s=200,
                   marker=MARKERS[no_set % len(MARKERS)],
                   color=COLORS[no_set % len(COLORS)],
                   label=set_name)
        print(f"\t{set_name}: {min(values)} - {np.mean(values)} - {max(values)}")

    ax.set_xticks(list(range(len(task_order))))
    ax.set_xticklabels(["" if paper_mode else t for t in task_order], rotation="vertical")
    ax.tick_params(labelsize=20)
    fig.tight_layout()

    # savefig does not create missing directories (e.g. the default "plots/").
    out_dir = os.path.dirname(path_out)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    fig.savefig(path_out)
    plt.close(fig)


def run(options):
    """Create one data set size figure per domain of the input file.

    `options` is the namespace produced by the module-level `parser`. The
    JSON file `options.input` maps domain names to the per-task size data
    handed to `make_figure`. If `options.properties` is given, the properties
    are loaded via `common` and the entry of each domain is overlaid as
    coverage.

    Raises:
        ValueError: if `options.output` contains no "{}" placeholder for the
            domain name (all figures would be written to the same file).
    """
    # Explicit check instead of `assert`: asserts are removed under `-O`.
    if "{}" not in options.output:
        raise ValueError(
            "--output requires a '{}' placeholder for the domain name: "
            + options.output)

    with open(options.input, "r", encoding="utf-8") as f:
        all_sizes = json.load(f)

    data_coverage = None
    if options.properties is not None:
        data_coverage = common.load_properties(options.properties, options.algorithm_filter)
        data_coverage = common.reduce_retraining_iterations(data_coverage)

    for domain in common.natural_sort(list(all_sizes)):
        print(f"Domain: {domain}")
        domain_coverage = None if data_coverage is None else data_coverage[domain]
        make_figure(domain, all_sizes[domain],
                    suffix=options.suffix,
                    filter=options.filter,
                    task_sets=config.TASK_SETS,
                    per_hour_divisor=options.per_hour_divisor,
                    path_out=options.output.format(domain),
                    coverages=domain_coverage,
                    paper_mode=options.paper_mode)


if __name__ == "__main__":
    run(parser.parse_args(sys.argv[1:]))
