#!/usr/bin/env python3
import collections
import matplotlib.pyplot as plt
import numpy as np
import re

PATTERN_EXPAND_NUMBERS = re.compile(r".*K(\d+)e(\d+)")
#HACK
def _convert_label(label):
    m = PATTERN_EXPAND_NUMBERS.match(label)
    if m:
        expanded = int(m.group(1)) * 10 ** int(m.group(2))
        return label[:-(len(m.group(1)) + len(m.group(2)) + 1)] + str(expanded)
    return label

def default_dict_int():
    """Return an empty ``collections.defaultdict(int)``.

    A named module-level factory (presumably instead of a lambda so that
    nested defaultdicts built with it stay picklable).
    """
    return collections.defaultdict(int)


def get_relevant_universes(structured, algorithms, baseline_algorithms,
                           min_algorithms=1):
    """Count, per (domain, universe), how many non-baseline algorithms ran there.

    :param structured: nested mapping ``algorithm -> domain -> universe -> ...``
    :param algorithms: algorithm names to consider; names missing from
        *structured* are skipped
    :param baseline_algorithms: algorithm names excluded from the count, or
        ``None`` to count every algorithm
    :param min_algorithms: minimum count a universe needs to be kept
        (default 1, i.e. keep every universe that was seen)
    :return: ``defaultdict`` mapping ``domain -> {universe: count}``; domains
        left without any universe are removed
    """
    relevant = collections.defaultdict(default_dict_int)
    for algorithm in algorithms:
        if algorithm not in structured:
            continue
        if baseline_algorithms is not None and algorithm in baseline_algorithms:
            continue
        for domain, struct_universe in structured[algorithm].items():
            for universe in struct_universe:
                relevant[domain][universe] += 1

    # Iterate over snapshots of the keys: deleting from a dict while
    # iterating its live key view raises RuntimeError in Python 3.
    for domain in list(relevant):
        universes = relevant[domain]
        for universe in list(universes):
            if universes[universe] < min_algorithms:
                del universes[universe]
        if not universes:
            del relevant[domain]
    return relevant

# Keys looked up in each per-problem property dict of ``structured``.
EXPANSIONS, PROBLEM, SEARCH_TIME = "expansions", "problem", "search_time"
def plot_ax(ax, baseline_algorithms, data, labels):
    """Draw horizontal box plots of *data* on *ax*, ordered by median.

    :param ax: matplotlib axes to draw on
    :param baseline_algorithms: algorithm names for which a vertical line at
        their median is drawn (if present in *labels*), or ``None``
    :param data: one sequence of values per algorithm (lengths may differ)
    :param labels: algorithm names, parallel to *data*
    """
    order = np.argsort([np.median(values) for values in data])
    # Reorder with plain lists: ``np.array`` on ragged per-algorithm data
    # raises ValueError on NumPy >= 1.24.
    data = [data[i] for i in order]
    raw_labels = [labels[i] for i in order]
    display_labels = [_convert_label(label) for label in raw_labels]

    # Whiskers span the 5th-95th percentiles; outliers are not drawn.
    bp_dict = ax.boxplot(data, labels=display_labels, whis=[5, 95], vert=False,
                         showfliers=False)
    for line in bp_dict['medians']:
        # Boxes are horizontal, so each median line is vertical; annotate its
        # top end with the median value.
        x, y = line.get_xydata()[1]
        ax.text(x, y, '%.1f' % x, horizontalalignment='center')

    if baseline_algorithms is not None:
        for baseline_algorithm in baseline_algorithms:
            # Compare against the unconverted names: the displayed labels may
            # have been rewritten by _convert_label.
            matches = [i for i, label in enumerate(raw_labels)
                       if label == baseline_algorithm]
            assert len(matches) <= 1
            if matches:
                ax.axvline(np.median(data[matches[0]]))


def plot_speed_ax(ax, baseline_algorithms, data, labels):
    """Scatter (time, expansions) points per algorithm and fit a line to each.

    :param ax: matplotlib axes to draw on
    :param baseline_algorithms: unused; kept so the signature matches
        :func:`plot_ax`
    :param data: per algorithm, a sequence of ``(search_time, expansions)``
        pairs (lengths may differ between algorithms)
    :param labels: algorithm names, parallel to *data*
    :return: ``(labels, polys, max_x_value)``: the converted labels ordered by
        the original names, the coefficients (lowest degree first) of each
        degree-1 fit, and the largest search time per algorithm
    """
    order = np.argsort(labels)
    # Reorder with plain lists: ``np.array`` on ragged per-algorithm data
    # raises ValueError on NumPy >= 1.24.
    data = [data[i] for i in order]
    labels = [_convert_label(labels[i]) for i in order]

    psteps = 100
    polys = []
    max_x_value = []
    for label, points in zip(labels, data):
        x, y = zip(*points)
        poly = np.polynomial.polynomial.Polynomial.fit(x, y, 1)
        # ``fit`` works in a scaled window; ``convert`` returns the
        # coefficients for the unscaled data domain.
        coefs = poly.convert().coef
        polys.append(coefs)

        # Sample the fit across the whole data range, endpoints included.
        xp = np.linspace(poly.domain[0], poly.domain[1], psteps)
        ax.plot(xp, poly(xp))
        terms = "+".join(
            ("%.1f" % coef) if deg == 0 else ("%.1fx^%i" % (coef, deg))
            for deg, coef in enumerate(coefs) if np.abs(coef) > 0.01)
        ax.scatter(x, y, label=label + "\n%s" % terms, marker="x")
        max_x_value.append(max(x))
    ax.legend()
    return labels, polys, max_x_value

def _universe_results(structured, algorithm, domain, universe):
    """Return the per-problem property dicts of *algorithm* on (domain, universe), or None."""
    if (algorithm in structured
            and domain in structured[algorithm]
            and universe in structured[algorithm][domain]):
        return structured[algorithm][domain][universe].values()
    return None


def get_current_data_and_labels(structured, domain, universe, algorithms,
                                baseline_algorithms, nb_columns=None,
                                column_assigner=None, common_problems=True,
                                add_time=False):
    """Collect per-algorithm results for one (domain, universe) and split them into columns.

    :param structured: nested mapping
        ``algorithm -> domain -> universe -> {key: problem_properties}``; each
        ``problem_properties`` dict is read for the ``PROBLEM``,
        ``EXPANSIONS`` and ``SEARCH_TIME`` keys
    :param domain: domain to extract
    :param universe: universe to extract
    :param algorithms: algorithm names, in the order results are emitted
    :param baseline_algorithms: names appended to every non-empty column, or
        ``None`` for no baselines
    :param nb_columns: number of columns (default 1)
    :param column_assigner: maps an algorithm name to its 1-based column
        index (default: everything goes to column 1)
    :param common_problems: if true, keep only problems for which every
        algorithm that ran on this universe reported expansions
    :param add_time: if true, emit ``(search_time, expansions)`` pairs instead
        of bare expansion counts
    :return: ``(columns, common_problems_set)``. ``columns`` has
        ``nb_columns`` entries, each ``None`` or a ``(data, labels)`` pair of
        parallel lists. ``common_problems_set`` is the set of shared problem
        names, or ``None`` if not computed or no algorithm ran here.
    """
    common_problems_set = None
    if common_problems:
        for algorithm in algorithms:
            results = _universe_results(structured, algorithm, domain, universe)
            if results is None:
                continue
            # NOTE(review): membership here only requires EXPANSIONS, while
            # collection below also requires SEARCH_TIME -- confirm intended.
            my_problems = {props[PROBLEM] for props in results
                           if EXPANSIONS in props}
            common_problems_set = (my_problems if common_problems_set is None
                                   else common_problems_set & my_problems)

    # Keep plain lists: ``np.array`` on ragged per-algorithm lists raises
    # ValueError on NumPy >= 1.24.
    data = []
    labels = []
    for algorithm in algorithms:
        results = _universe_results(structured, algorithm, domain, universe)
        if results is None:
            continue
        values = []
        for props in results:
            if EXPANSIONS not in props or SEARCH_TIME not in props:
                continue
            if (common_problems_set is not None
                    and props[PROBLEM] not in common_problems_set):
                continue
            if add_time:
                values.append((props[SEARCH_TIME], props[EXPANSIONS]))
            else:
                values.append(props[EXPANSIONS])
        # Algorithms without any usable result are left out entirely.
        if values:
            data.append(values)
            labels.append(algorithm)

    column_assigner = (lambda x: 1) if column_assigner is None else column_assigner
    nb_columns = 1 if nb_columns is None else nb_columns
    columns = {}
    baselines = []
    for values, label in zip(data, labels):
        # baseline_algorithms defaults to None in compare_expansions; guard
        # the membership test so that does not raise TypeError.
        if baseline_algorithms is not None and label in baseline_algorithms:
            baselines.append((values, label))
            continue
        col_data, col_labels = columns.setdefault(column_assigner(label), ([], []))
        col_data.append(values)
        col_labels.append(label)
    # Baselines are added to every column that has a non-baseline algorithm.
    for col_data, col_labels in columns.values():
        for values, label in baselines:
            col_data.append(values)
            col_labels.append(label)
    return ([columns.get(k) for k in range(1, nb_columns + 1)],
            common_problems_set)


def plot(path, structured, domain_universes, algorithms, baseline_algorithms,
         nb_columns=None, column_assigner=None):
    """Save a grid of expansion box plots, one row per (domain, universe).

    :param path: output file passed to ``Figure.savefig``
    :param structured: nested results mapping, passed through to
        get_current_data_and_labels
    :param domain_universes: mapping ``domain -> {universe: ...}`` selecting
        the rows to plot (both levels are plotted in sorted order)
    :param algorithms: algorithm names to plot
    :param baseline_algorithms: baseline algorithm names, or ``None``
    :param nb_columns: number of columns; must be given together with
        *column_assigner*
    :param column_assigner: maps an algorithm name to its 1-based column
    """
    assert (nb_columns is None and column_assigner is None) or (
            nb_columns is not None and column_assigner is not None)
    nb_columns = 1 if nb_columns is None else nb_columns
    nb_rows = sum(len(x) for x in domain_universes.values())

    fig = plt.figure(figsize=(14 * nb_columns, 5 * nb_rows))
    try:
        idx_row = 0
        for domain in sorted(domain_universes.keys()):
            for universe in sorted(domain_universes[domain].keys()):
                columns, common_problems = get_current_data_and_labels(
                    structured, domain, universe, algorithms,
                    baseline_algorithms, nb_columns, column_assigner)
                for idx_col, data_labels in enumerate(columns):
                    ax = fig.add_subplot(nb_rows, nb_columns,
                                         idx_row * nb_columns + idx_col + 1)
                    ax.set_title(
                        "Expansions in domain %s, universe %s%s" %
                        (domain,
                         universe,
                         ("\n(#problems %i)" % len(common_problems))
                         if common_problems is not None else ""))
                    ax.set_xlabel("#expansions")
                    # Empty columns keep their titled axes so the grid stays aligned.
                    if data_labels is None:
                        continue
                    data, labels = data_labels
                    plot_ax(ax, baseline_algorithms, data, labels)
                idx_row += 1

        fig.tight_layout()
        fig.savefig(path)
    finally:
        # pyplot keeps every figure created by plt.figure() alive until it is
        # closed; close it so repeated calls do not accumulate figures.
        plt.close(fig)


def plot_speed(path, structured, domain_universes, algorithms, baseline_algorithms,
               nb_columns=None, column_assigner=None):
    """Save expansion-vs-time scatter plots with linear fits, plus a summary panel.

    The first grid row is reserved for a text panel. For each algorithm it
    lists the average ratio of the algorithm's fitted slope (expansions per
    second) to the fitted slope of a reference algorithm. The reference is
    the first label of the first non-empty panel. A panel contributes only if
    the reference's largest search time there is at least 10; each
    algorithm's ratio is likewise only counted when its own largest search
    time is at least 10. The remaining rows hold one panel per
    (domain, universe) and column.

    :param path: output file passed to ``Figure.savefig``
    :param structured: nested results mapping, passed through to
        get_current_data_and_labels
    :param domain_universes: mapping ``domain -> {universe: ...}`` selecting
        the rows to plot (both levels are plotted in sorted order)
    :param algorithms: algorithm names to plot
    :param baseline_algorithms: baseline algorithm names, or ``None``
    :param nb_columns: number of columns; must be given together with
        *column_assigner*
    :param column_assigner: maps an algorithm name to its 1-based column
    """
    assert (nb_columns is None and column_assigner is None) or (
            nb_columns is not None and column_assigner is not None)
    nb_columns = 1 if nb_columns is None else nb_columns
    # One extra row on top for the speed-factor summary.
    nb_rows = sum(len(x) for x in domain_universes.values()) + 1

    factor_sums = collections.defaultdict(int)
    factor_counts = collections.defaultdict(int)
    reference = None

    def update_average_speed_factor(labels, coefs, max_x_value):
        """Accumulate slope ratios of one panel relative to the reference algorithm."""
        nonlocal reference
        if len(labels) == 0:
            return
        if reference is None:
            reference = labels[0]
        if reference not in labels:
            return
        assert all(len(coef) == 2 for coef in coefs)
        idx_reference = labels.index(reference)
        if max_x_value[idx_reference] < 10:
            return
        reference_slope = coefs[idx_reference][1]
        for label, coef, max_x in zip(labels, coefs, max_x_value):
            if max_x < 10:
                continue
            factor_sums[label] += coef[1] / reference_slope
            factor_counts[label] += 1

    fig = plt.figure(figsize=(14 * nb_columns, 5 * nb_rows))
    try:
        idx_row = 0
        for domain in sorted(domain_universes.keys()):
            for universe in sorted(domain_universes[domain].keys()):
                columns, common_problems = get_current_data_and_labels(
                    structured, domain, universe, algorithms, baseline_algorithms,
                    nb_columns, column_assigner, add_time=True)
                for idx_col, data_labels in enumerate(columns):
                    # Row 0 holds the summary, so data rows start one full row
                    # down; this keeps columns aligned when nb_columns > 1.
                    ax = fig.add_subplot(nb_rows, nb_columns,
                                         (idx_row + 1) * nb_columns + idx_col + 1)
                    ax.set_title(
                        "Expansion speed in domain %s, universe %s%s" %
                        (domain,
                         universe,
                         ("\n(#problems %i)" % len(common_problems))
                         if common_problems is not None else ""))
                    ax.set_xlabel("time in s")
                    ax.set_ylabel("#expansions")
                    if data_labels is None:
                        continue
                    data, labels = data_labels
                    labels, coefs, max_x_value = plot_speed_ax(
                        ax, baseline_algorithms, data, labels)
                    update_average_speed_factor(labels, coefs, max_x_value)
                idx_row += 1

        speed_factors = {label: factor_sum / factor_counts[label]
                         for label, factor_sum in factor_sums.items()}
        str_speed_factors = "Average Expansion Speed Multiplier:\n%s" % (
            "\n".join("%s: %.2f" % (label, factor)
                      for label, factor in speed_factors.items())
        )
        ax = fig.add_subplot(nb_rows, nb_columns, 1)
        ax.text(0, 0, str_speed_factors)
        fig.tight_layout()
        fig.savefig(path)
    finally:
        # pyplot keeps every figure created by plt.figure() alive until it is
        # closed; close it so repeated calls do not accumulate figures.
        plt.close(fig)



def compare_expansions(path, structured, filter_algorithms,
                       baseline_algorithms=None,
                       nb_columns=None, column_assigner=None,
                       path_speed=None):
    """Plot expansion statistics for every algorithm selected by a regex.

    :param path: output file for the expansion box plots
    :param structured: nested mapping ``algorithm -> domain -> universe -> ...``
    :param filter_algorithms: compiled regex; algorithms whose name it
        ``match``es are included (in sorted order)
    :param baseline_algorithms: baseline algorithm names, or ``None``
    :param nb_columns: number of plot columns; give together with
        *column_assigner*
    :param column_assigner: maps an algorithm name to its 1-based column
    :param path_speed: if given, the expansion-speed plot is also written here
    """
    both_unset = nb_columns is None and column_assigner is None
    both_set = nb_columns is not None and column_assigner is not None
    assert both_unset or both_set

    selected = sorted(name for name in structured
                      if filter_algorithms.match(name))
    relevant = get_relevant_universes(structured, selected, baseline_algorithms)

    plot(path, structured, relevant, selected, baseline_algorithms,
         nb_columns, column_assigner)
    if path_speed is None:
        return
    plot_speed(path_speed, structured, relevant, selected,
               baseline_algorithms, nb_columns, column_assigner)


if __name__ == "__main__":
    # This module only defines plotting helpers (e.g. compare_expansions) and
    # is meant to be imported; refuse to run it as a script.
    raise RuntimeError("Not supposed to be started as __main__")