#! /usr/bin/env python3

import argparse
from collections import defaultdict
import numpy as np
import matplotlib.pyplot as plt
import os
import sys
import re

import common
import config


PLOT_COVERAGE_EVOLUTION = "evolution"
PLOT_GEOMETRIC_MEANS = "geometric_means"

# Command line interface of the script.
parser = argparse.ArgumentParser()
parser.add_argument(
    "properties", type=common.type_is_file, nargs="+",
    help="Load Lab properties from all given paths")
parser.add_argument(
    "--algorithm-filter", "-af", type=re.compile, action="append", default=[],
    help="Keep only algorithms that satisfy ALL algorithm filters.")
parser.add_argument(
    "--paper-mode", action="store_true",
    help="Format the plots for a paper (e.g. no titles or legends, larger "
         "tick labels)."
)
# parser.add_argument(
#     "--partial-properties", type=common.type_is_file, nargs="+", default=None,
#     help="Add property files for experiments that have not yet finished."
# )
parser.add_argument(
    "--attribute", default=common.PROPERTY_TOTAL_TIME,
    choices=[common.PROPERTY_TOTAL_TIME, common.PROPERTY_EXPANSIONS],
    help="The attribute to plot."
)
parser.add_argument(
    "--commonly-solved", action="store_true",
    help="Restrict the plots to the tasks solved by all algorithms.")
parser.add_argument(
    "--minimum-coverage-fraction", type=float, default=None,
    help="Minimum fraction of a domain's tasks an algorithm has to solve to "
         "be used in the evaluation of that domain."
)
parser.add_argument(
    "--plot-type", default=PLOT_COVERAGE_EVOLUTION,
    choices=[PLOT_COVERAGE_EVOLUTION, PLOT_GEOMETRIC_MEANS],
    help="Plot the evolution of coverage for increasing values of the "
         "selected attribute, or the distribution (box plot) of the "
         "attribute per algorithm for a domain.")




# Plot styles handed out to algorithms: the i-th algorithm that asks for a
# style (see get_style) is drawn with COLORS[i] and MARKERS[i].
COLORS = ["g", "c", "m", "b", "r", "olive", "orange", "darkviolet"]
MARKERS = ["o", "x", "+", "D", "v", "s", "p", "H"]
# Index of the next unused entry in COLORS / MARKERS.
NEXT_STYLE_INDEX = 0
# algorithm name -> (index, color, marker)
STYLE_ASSIGNMENT = {}
# When True, get_style refuses to assign styles to unseen algorithms.
FIX_STYLE = False


def get_style(algorithm):
    """Return the ``(index, color, marker)`` style tuple of ``algorithm``.

    The first request for an algorithm assigns it the next free entry of
    COLORS / MARKERS and caches it in STYLE_ASSIGNMENT; later requests
    return the cached tuple. Asserts that no new style is requested once
    FIX_STYLE is set and that the style palette is not exhausted.
    """
    global NEXT_STYLE_INDEX
    if algorithm in STYLE_ASSIGNMENT:
        return STYLE_ASSIGNMENT[algorithm]

    assert not FIX_STYLE, algorithm
    slot = NEXT_STYLE_INDEX
    assert slot < len(COLORS), STYLE_ASSIGNMENT.keys()
    style = (slot, COLORS[slot], MARKERS[slot])
    STYLE_ASSIGNMENT[algorithm] = style
    NEXT_STYLE_INDEX = slot + 1
    return style


def get_color(algorithm):
    """Return the color of ``algorithm``'s style (assigned via get_style)."""
    _slot, color, _marker = get_style(algorithm)
    return color


def get_marker(algorithm):
    """Return the marker of ``algorithm``'s style (assigned via get_style)."""
    _slot, _color, marker = get_style(algorithm)
    return marker


# def plot_coverage_X_time(
#         file_out, domain, domain_data,
#         partial_data=None,
#         task_set=None,
#         xlog=False,
#         paper_mode=False):
#
#     fig = plt.figure()
#     ax = fig.add_subplot(111)
#     if not paper_mode:
#         ax.set_title(domain)
#         ax.set_xlabel("time in seconds")
#         ax.set_ylabel("coverage in %")
#
#     all_state_spaces = set()
#     all_tasks = set()
#     all_tasks_per_state_space = defaultdict(set)
#     algo_times = defaultdict(list)
#     for task, algos_results in domain_data.items():
#         if task_set is not None and os.path.join(domain, task) not in task_set:
#             continue
#         all_state_spaces.add(task)
#         for algo, results in algos_results.items():
#             all_tasks.update([os.path.join(task, task_id) for task_id in results[1]])
#             all_tasks_per_state_space[task].update([os.path.join(task, task_id) for task_id in results[1]])
#             algo_times[algo].extend(results[0].values())
#
#     max_time = 10 * 60 * 60
#     nb_markers = 2
#     max_x = max_time
#     if xlog:
#         max_log_x = np.log10(max_x)
#         marker_shift = max_log_x * 0.1
#         marker_times = np.array(
#             [max_log_x * (nm + 1) / (nb_markers + 1) for nm in
#              range(nb_markers)])
#     else:
#         marker_shift = max_x * 0.1
#         marker_times = np.array(
#             [(nm + 1) * max_x / (nb_markers + 1) for nm in range(nb_markers)])
#     if len(all_tasks) == 0:
#         return
#     if not any(k.lower().find("hgn") > -1 for k in algo_times.keys()):
#         algo_times["hgn"] = []
#     print(algo_times.keys())
#     for no_algo, (algo, times) in enumerate(algo_times.items()):
#         algo = config.rename_algorithm(algo)
#         assert len(times) <= len(all_tasks)
#         assert len(all_tasks) == len(all_state_spaces) * config.TASKS_PER_STATE_SPACE
#         ax.plot([0] + sorted(times) + [max_time],
#                 [0] +
#                 [100*(i + 1)/float(len(all_tasks)) for i in range(len(times))] +
#                 [100*len(times)/float(len(all_tasks))],
#                 color=get_color(algo))
#         individual_marker_times = marker_times + (no_algo - ((len(algo_times) - 1)/ 2.0)) * marker_shift
#         if xlog:
#             individual_marker_times = 10 ** individual_marker_times
#         ax.scatter(individual_marker_times,
#                    [100 * len([x for x in times if x <= mt])/float(len(all_tasks)) for mt in individual_marker_times],
#                    color=get_color(algo),
#                    marker=get_marker(algo),
#                    label=algo,
#                    s=15**2
#                    )
#
#     """ >>>>>>>>>>>>>>>> START HACK """
#     # if partial_data is not None:
#     #     algo_times = defaultdict(list)
#     #
#     #     for task, algos_results in partial_data.items():
#     #         if task_set is not None and os.path.join(domain, task) not in task_set:
#     #             continue
#     #         if task not in all_state_spaces:
#     #             continue
#     #         assert len(all_tasks_per_state_space[task]) == config.TASKS_PER_STATE_SPACE, f"{domain} {task} {len(all_tasks_per_state_space[task])} != {config.TASKS_PER_STATE_SPACE} {all_tasks_per_state_space.keys()}"
#     #         for algo, results in algos_results.items():
#     #             w = len(all_tasks_per_state_space[task])/len(results[1])
#     #             algo_times[algo].extend([(t, w) for t in results[0].values()])
#     #
#     #     if "hHGN" not in algo_times:
#     #         algo_times["hHGN"] = []
#     #     for algo, times in algo_times.items():
#     #         algo = config.rename_algorithm(algo)
#     #         assert len(times) <= len(all_tasks), f"{algo} {len(times)} {len(all_tasks)}"
#     #         assert len(all_tasks) == len(
#     #             all_state_spaces) * config.TASKS_PER_STATE_SPACE
#     #         times = sorted(times, key=lambda x: x[0])
#     #         y_values = [0]
#     #         for t, w in times:
#     #             y_values.append(w + y_values[-1])
#     #         y_values.append(y_values[-1])
#     #
#     #         ax.plot([0] + [t for t, _ in times] + [max_time],
#     #                 [100 * y / float(len(all_tasks)) for y in
#     #                  y_values],
#     #                 color=get_color(algo),
#     #                 label=algo)
#     #         ides = [[n for n, t in enumerate(times) if t[0] <= tm] for tm in [t_marker1, t_marker2]]
#     #         ides = [0 if len(l) == 0 else l[-1] for l in ides]
#     #         y_marker = [100 * y_values[i] / float(len(all_tasks)) for i in ides]
#     #         ax.scatter([max_time / 3.0, 2 * max_time / 3.0],
#     #                    y_marker,
#     #                    color=get_color(algo),
#     #                    marker=get_marker(algo),
#     #                    label=algo
#     #                    )
#     """ <<<<<<<<<<<<<<<< END HACK """
#     if paper_mode:
#         ax.tick_params(labelsize=20)
#     if not paper_mode:
#         ax.legend()
#     else:
#         # xticks = ax.get_xticks()
#         # ax.set_xticklabels([int(t) if n == 0 or n == len(xticks) - 2 else ""
#         #                for n, t in enumerate(xticks)])
#         x_end = 36000
#         ax.set_xticks([x_end])
#         ax.set_xticklabels(["%ss" % x_end])
#         for label in ax.xaxis.get_ticklabels():
#             label.set_horizontalalignment("right")
#         n = 6
#         ax.set_yticks([100.0 * i / (n-1) for i in range(n)])
#         ax.set_yticklabels([ round(100* i / (n-1)) for i in range(n)])
#     if xlog:
#         ax.set_xscale("log")
#     fig.tight_layout()
#     fig.savefig(file_out)
#     plt.close(fig)


def plot_coverage_X_something(
        file_out, domain,
        all_state_spaces, algo_solved, algo_attempted,
        xlabel, ylabel, max_x=None,
        task_set=None,
        paper_mode=False, xlog=False, fillX=False):
    """Plot the cumulative coverage of every algorithm over an attribute.

    For each algorithm the attribute values of its solved tasks are sorted
    and drawn as a line from (0, 0) through ``(v_i, 100 * i / N)``, where
    ``v_i`` is the i-th smallest value and ``N`` the number of distinct tasks
    attempted by any algorithm. Two markers per algorithm (shifted
    horizontally per algorithm so they do not overlap) distinguish the lines.

    :param file_out: path the figure is saved to.
    :param domain: plot title (not shown in paper mode).
    :param all_state_spaces: state spaces of the domain; only used to assert
        that ``len(all_state_spaces) * config.TASKS_PER_STATE_SPACE`` tasks
        were attempted.
    :param algo_solved: algorithm -> attribute values of its solved tasks.
        The dictionary is not modified.
    :param algo_attempted: algorithm -> ids of its attempted tasks.
    :param xlabel: x axis label (not shown in paper mode).
    :param ylabel: y axis label (not shown in paper mode).
    :param max_x: right end used by ``fillX`` and for placing the markers;
        defaults to the largest value in ``algo_solved``.
    :param task_set: unused; kept for interface compatibility.
    :param paper_mode: omit title, axis labels and legend, enlarge tick
        labels and put x ticks every 3600 units labelled 0, 1, 2, ...
    :param xlog: logarithmic x axis (markers are spaced in log space).
    :param fillX: extend every line horizontally up to ``max_x``.
    """
    if len(algo_solved) == 0:
        return

    all_tasks = set(task for tasks in algo_attempted.values() for task in tasks)
    if len(all_tasks) == 0:
        # Nothing attempted (e.g. every algorithm was filtered out): the
        # coverage percentages below would divide by zero.
        return
    nb_tasks = float(len(all_tasks))
    # Checked before creating the figure so a failure does not leak it.
    assert len(all_tasks) == len(all_state_spaces) * config.TASKS_PER_STATE_SPACE

    if max_x is None:
        max_x = max(0 if len(v) == 0 else max(v) for v in algo_solved.values())
    nb_markers = 2

    # Work on a copy so the placeholder below does not leak to the caller.
    algo_solved = dict(algo_solved)
    if not any("hgn" in k.lower() for k in algo_solved):
        # Presumably ensures an HGN entry is always drawn (as a flat line at
        # 0% coverage) even without results for this domain.
        algo_solved["hgn"] = []

    fig = plt.figure()
    ax = fig.add_subplot(111)
    if not paper_mode:
        ax.set_title(domain)
        ax.set_xlabel(xlabel)
        ax.set_ylabel(ylabel)

    # Markers of the algorithms are spread symmetrically around their
    # nominal positions so that they do not hide each other.
    center_offset = (len(algo_solved) - 1) / 2.0
    for no_algo, (algo, times) in enumerate(algo_solved.items()):
        algo = config.rename_algorithm(algo)
        assert len(times) <= len(all_tasks)
        times = sorted(times)
        xs = [0] + times
        ys = [0] + [100 * (i + 1) / nb_tasks for i in range(len(times))]
        if fillX:
            xs.append(max_x)
            ys.append(100 * len(times) / nb_tasks)
        ax.plot(xs, ys, color=get_color(algo))
        if len(times) == 0:
            continue

        max_marker_x = max_x if fillX else times[-1]
        if xlog:
            if max_marker_x <= 0:
                # log10 is undefined; keep the line but skip its markers.
                continue
            max_marker_x = np.log10(max_marker_x)
        marker_shift = max_marker_x * 0.1
        marker_times = np.array(
            [max_marker_x * (nm + 1) / (nb_markers + 1)
             for nm in range(nb_markers)])
        marker_times += (no_algo - center_offset) * marker_shift
        if xlog:
            marker_times = 10 ** marker_times
        ax.scatter(marker_times,
                   [100 * sum(1 for x in times if x <= mt) / nb_tasks
                    for mt in marker_times],
                   color=get_color(algo),
                   marker=get_marker(algo),
                   label=algo,
                   s=15**2
                   )

    if paper_mode:
        ax.tick_params(labelsize=20)
        # One tick every 3600 units, labelled 0, 1, 2, ... (hours when the
        # attribute is measured in seconds).
        x_ticks_hours = list(range(int(max_x / 3600) + 1))
        ax.set_xticks([x * 3600 for x in x_ticks_hours])
        ax.set_xticklabels([str(x) for x in x_ticks_hours])
    else:
        ax.legend()
    if xlog:
        ax.set_xscale("log")
    ax.set_ylim([-1, 105])
    fig.tight_layout()
    fig.savefig(file_out)
    plt.close(fig)



# def plot_coverage_X_instance(file_out, domain, domain_data, task_set=None):
#     fig = plt.figure()
#     ax = fig.add_subplot(111)
#     ax.set_title(domain)
#     ax.set_xlabel("instance")
#     ax.set_ylabel("coverage in %")
#
#     all_tasks = defaultdict(set)
#     algo_task_cov = defaultdict(lambda: defaultdict(set))
#
#     for task, algos_results in domain_data.items():
#         if task_set is not None and os.path.join(domain, task) not in task_set:
#             continue
#         for algo, results in algos_results.items():
#             all_tasks[task].update(results[1])
#             algo_task_cov[algo][task] = results[0].keys()
#
#     task_order = common.natural_sort(all_tasks.keys())
#
#     nb_algos = len(algo_task_cov)
#     width = 1.0/(nb_algos + 2)
#     for no, (algo, task_cov) in enumerate(algo_task_cov.items()):
#         algo = config.rename_algorithm(algo)
#         values = []
#         for task in task_order:
#             values.append(len(task_cov[task])/float(len(all_tasks[task])) if task in task_cov else 0)
#         ax.plot(np.arange(len(task_order)), values, color=get_color(algo), linewidth=1,
#                 linestyle="--")
#         ax.scatter(
#             np.arange(len(task_order)), values,
#             color=get_color(algo),
#             label=algo)
#         # ax.bar(np.arange(len(task_order)) + (no - 0.5 * nb_algos + 0.5) * width,
#         #        values,
#         #        color=get_color(algo),
#         #        label=config.rename_algorithm(algo),
#         #        width=width)
#
#     ax.set_xticks(np.arange(len(task_order)))
#     ax.set_xticklabels(["" for _ in task_order])
#     ax.legend()
#     fig.tight_layout()
#     fig.savefig(file_out)
#     plt.close(fig)


def plot_legend(path):
    """Save a stand-alone single-row legend of all styled algorithms.

    Every algorithm in STYLE_ASSIGNMENT, ordered by ``common.my_algo_sort``,
    gets one entry drawn with its color and marker.

    :param path: file the legend figure is saved to.
    """
    sorted_algo = common.my_algo_sort(STYLE_ASSIGNMENT.keys())
    print("SORTED ALGO", sorted_algo)

    # Scratch figure: only used to create scatter handles carrying each
    # algorithm's color and marker; it is never saved.
    fig = plt.figure()
    ax = fig.add_subplot(111)
    lines = [ax.scatter([0, 1], [0, 1], marker=get_marker(algo),
                        color=get_color(algo), label=algo)
             for algo in sorted_algo]

    figlegend = plt.figure(figsize=(8, 0.5))
    # `loc` is passed by keyword: recent matplotlib versions no longer accept
    # it as a positional argument. ncol must be at least 1.
    figlegend.legend(lines, list(sorted_algo), loc="center",
                     ncol=max(1, len(sorted_algo)))
    figlegend.tight_layout()
    figlegend.savefig(path)
    # Release both figures; pyplot otherwise keeps them alive.
    plt.close(figlegend)
    plt.close(fig)


def restrict_data(domain, domain_data, task_set, commonly_solved,
                  minimum_coverage_fraction, tasks_per_state_space=None):
    """Collect the per-algorithm results of one domain, optionally filtered.

    ``domain_data`` maps a state space to a dict mapping each algorithm to a
    pair ``(solved, attempted)``: ``solved`` maps task ids to the recorded
    attribute value, ``attempted`` is an iterable of task ids. Task ids are
    qualified as ``os.path.join(state_space, task_id)`` in the results.

    :param domain: domain name, used to match the entries of ``task_set``.
    :param domain_data: results of the domain as described above.
    :param task_set: optional collection of ``os.path.join(domain,
        state_space)`` entries; state spaces not contained are skipped.
    :param commonly_solved: keep only tasks solved by every remaining
        algorithm.
    :param minimum_coverage_fraction: if not None, drop every algorithm that
        solves a smaller fraction of the domain's tasks.
    :param tasks_per_state_space: number of tasks per state space used for
        the coverage fraction; defaults to ``config.TASKS_PER_STATE_SPACE``.
    :return: ``(state_spaces, algo_attributes, algo_attempted)``: the used
        state spaces, algorithm -> attribute values of its solved tasks, and
        algorithm -> ids of its attempted tasks.
    """
    def load(_common=None, _algos=None):
        # Gather the results, keeping only tasks in `_common` and algorithms
        # in `_algos` when those are given.
        state_spaces = set()
        algo_attributes = defaultdict(list)
        algo_solved = defaultdict(list)
        algo_attempted = defaultdict(list)
        for state_space, algos_results in domain_data.items():
            if task_set is not None and os.path.join(domain, state_space) not in task_set:
                continue
            state_spaces.add(state_space)
            for algo, results in algos_results.items():
                if _algos is not None and algo not in _algos:
                    continue
                tasks_solved, tasks_attempted = results
                tasks_solved = {os.path.join(state_space, task_id): attr
                                for task_id, attr in tasks_solved.items()}
                tasks_attempted = [os.path.join(state_space, task_id)
                                   for task_id in tasks_attempted]

                if _common is not None:
                    tasks_solved = {task: attr
                                    for task, attr in tasks_solved.items()
                                    if task in _common}
                    tasks_attempted = [task for task in tasks_attempted
                                       if task in _common]

                algo_attempted[algo].extend(tasks_attempted)
                algo_solved[algo].extend(tasks_solved.keys())
                algo_attributes[algo].extend(tasks_solved.values())
        return state_spaces, algo_solved, algo_attributes, algo_attempted

    all_state_spaces, algo_solved, algo_attributes, algo_attempted = load()

    if minimum_coverage_fraction is not None:
        if tasks_per_state_space is None:
            tasks_per_state_space = config.TASKS_PER_STATE_SPACE
        nb_tasks = len(all_state_spaces) * tasks_per_state_space
        rm_algos = [algo for algo, solved in algo_solved.items()
                    if len(solved) / nb_tasks < minimum_coverage_fraction]
        for algo in rm_algos:
            # Remove the algorithm from every result dict; the attribute
            # values are what is returned, so they must be filtered too.
            del algo_solved[algo]
            del algo_attributes[algo]
            del algo_attempted[algo]

    if commonly_solved:
        # Named to avoid shadowing the `common` module.
        common_tasks = None
        for tasks in algo_solved.values():
            common_tasks = set(tasks) if common_tasks is None else (common_tasks & set(tasks))
        all_state_spaces, algo_solved, algo_attributes, algo_attempted = load(
            common_tasks, set(algo_solved.keys()))
    return all_state_spaces, algo_attributes, algo_attempted


def geo_mean_overflow(iterable):
    """Return the geometric mean of the values in ``iterable``.

    Each value is first clamped from below to 1. The mean is taken over the
    natural logarithms and then exponentiated, which avoids the overflow a
    direct product of many large values would cause. An empty input yields
    ``nan`` (numpy's mean of an empty array).
    """
    logs = np.log(np.array([max(1, value) for value in iterable]))
    return np.exp(np.mean(logs))

def plot_geometric_means(outfile, algo_solved, paper_mode):
    """Draw one horizontal box plot per algorithm of its attribute values.

    Despite the name, the figure shows box plots (without outliers, on a
    logarithmic x axis) rather than only geometric means. Every algorithm in
    STYLE_ASSIGNMENT gets a row; algorithms without data get an empty one.

    :param outfile: path the figure is saved to.
    :param algo_solved: algorithm -> attribute values; keys are passed
        through ``config.rename_algorithm`` once.
    :param paper_mode: hide the algorithm names on the y axis and enlarge
        the tick labels.
    """
    if len(algo_solved) == 0:
        return

    data = defaultdict(list)
    for algo, values in algo_solved.items():
        data[config.rename_algorithm(algo)] = values
    # Keys are already renamed; they must match the names styles were
    # assigned to.
    for algo in data:
        assert algo in STYLE_ASSIGNMENT, algo
    # Reversed so that the first algorithm of the sort order ends up on top
    # (boxplot puts the first data set at the bottom).
    algos = common.my_algo_sort(STYLE_ASSIGNMENT.keys())[::-1]

    fig = plt.figure()
    ax = fig.add_subplot(111)
    handles = ax.boxplot(
        [data[a] for a in algos],
        labels=algos,
        vert=False,
        showfliers=False
    )

    base_width = 4
    # boxplot creates one box / median but two whiskers / caps per data set.
    for per_box, linewidth, element in [(1, base_width, "boxes"),
                                        (2, base_width, "whiskers"),
                                        (2, base_width, "caps"),
                                        (1, 1.5 * base_width, "medians")]:
        owners = [a for a in algos for _ in range(per_box)]
        for algo, artist in zip(owners, handles[element]):
            artist.set_color("black" if element == "medians" else get_color(algo))
            artist.set_linewidth(linewidth)
    ax.set_xscale("log")

    if paper_mode:
        # boxplot places the data sets at y = 1 .. len(algos).
        positions = list(range(1, len(algos) + 1))
        ax.set_yticks(positions)
        ax.set_yticklabels(["" for _ in positions])
        ax.tick_params(labelsize=20)

    fig.tight_layout()
    fig.savefig(outfile)
    plt.close(fig)

def run(options):
    """Load the property files and write all requested plots to ``plots/``.

    For every domain and each task subset (``config.ALL_TASKS``,
    ``config.MODERATE_TASKS``, ``config.HARD_TASKS``) the data is restricted
    via ``restrict_data`` and plotted according to ``options.plot_type``.
    A shared legend is written at the end.

    :param options: namespace produced by the module-level ``parser``.
    """
    global FIX_STYLE
    data = common.load_properties(
        options.properties, options.algorithm_filter,
        property_name=options.attribute)
    data = common.reduce_retraining_iterations(data)
    # partial_data = common.load_properties(options.partial_properties, options.algorithm_filter)
    # partial_data = common.reduce_retraining_iterations(partial_data)
    # if partial_data is None:
    #     partial_data = defaultdict(lambda: None)

    # Assign styles up front in sorted order so every plot uses the same
    # color/marker per algorithm; afterwards no new styles may be created.
    all_algorithms = set()
    for domain_data in data.values():
        for algo_results in domain_data.values():
            all_algorithms.update(config.rename_algorithm(a) for a in algo_results.keys())
    for algo in common.my_algo_sort(all_algorithms):
        get_style(algo)
    FIX_STYLE = True

    # savefig does not create missing directories.
    os.makedirs("plots", exist_ok=True)
    # The fixed range of 36000 (10 h in seconds) only fits the time
    # attribute; for other attributes (e.g. expansions) let the plot derive
    # it from the data instead of drawing every line back to x=36000.
    max_x = 36000 if options.attribute == common.PROPERTY_TOTAL_TIME else None

    print(f"Plot {options.attribute} {options.plot_type}")
    for domain, domain_data in data.items():
        print(f"\t{domain}")
        for task_set_name, task_set in [("all", config.ALL_TASKS),
                                        ("moderate", config.MODERATE_TASKS),
                                        ("hard", config.HARD_TASKS)]:
            curr_state_spaces, curr_algo_solved, curr_algo_attempted = restrict_data(
                domain, domain_data, task_set, options.commonly_solved,
                options.minimum_coverage_fraction)

            if options.plot_type == PLOT_COVERAGE_EVOLUTION:
                plot_coverage_X_something(
                    f"plots/Cx{common.ATTRIBUTE_ABBREVIATION[options.attribute]}"
                    f"_{task_set_name}_{domain}.pdf", domain,
                    curr_state_spaces, curr_algo_solved, curr_algo_attempted,
                    options.attribute, "coverage",
                    task_set=task_set, paper_mode=options.paper_mode,
                    xlog=False, fillX=True, max_x=max_x
                )
            elif options.plot_type == PLOT_GEOMETRIC_MEANS:
                plot_geometric_means(
                    f"plots/GM_{common.ATTRIBUTE_ABBREVIATION[options.attribute]}"
                    f"_{task_set_name}_{domain}.pdf",
                    curr_algo_solved,
                    paper_mode=options.paper_mode
                )
    plot_legend("plots/legend.pdf")


if __name__ == "__main__":
    # Parse the command line (without the program name) and run the plots.
    cli_options = parser.parse_args(sys.argv[1:])
    run(cli_options)
