from __future__ import division

from collections import Counter
import math
from matplotlib import pyplot as plt
import numpy as np
import os
import random
import re

# Matches a label that contains "K<digits>e<digits>"; group 1 is the factor
# and group 2 the decimal exponent.
PATTERN_EXPAND_NUMBERS = re.compile(r".*K(\d+)e(\d+)")
# Dictionary keys looked up in the per-problem properties and in the
# evolution mapping stored under REF_H_EVOLUTION.
REF_H_EVOLUTION = "ref_h_evolution"
REFERENCE_KEY = "Sum Costs"

# Matplotlib single-letter color codes, used one per plotted series.
COLORS = ["r", "g", "b", "m", "c", "y", "k"]

# Captures the text between "path =" and ".pb" in a heuristic description.
# Raw string: "\s" and "\." are regex escapes, not string escapes.
PATTERN_MODEL = re.compile(r"path\s*=\s*(.*?)\.pb")

def natural_sort(list, key=lambda s:s):
    """
    Sort the list in place into natural alphanumeric order.

    Runs of ASCII digits are compared by numeric value and the text between
    them is compared as strings, so "file2" sorts before "file10".

    :param list: the list to sort; it is modified in place. (The name shadows
        the builtin but is kept so keyword callers keep working.)
    :param key: function mapping an element to the string that is compared.
    :return: None
    """
    def natural_key(item):
        # Splitting on a capturing group yields text, digits, text, ... so the
        # odd positions are exactly the [0-9]+ runs. Converting by position
        # (rather than by str.isdigit) avoids int() failing on characters such
        # as superscript digits, which isdigit() accepts but int() rejects.
        parts = re.split('([0-9]+)', key(item))
        return [int(part) if index % 2 else part
                for index, part in enumerate(parts)]

    list.sort(key=natural_key)


def get_prefix_of_heuristic_model(heuristic):
    """
    Return the text captured by PATTERN_MODEL in the heuristic description.

    The description is expected to contain exactly one match; this is
    checked with an assertion.
    """
    matches = PATTERN_MODEL.findall(heuristic)
    nb_matches = len(matches)
    assert nb_matches == 1
    return matches[0]

def convert_to_relative_evolution(structured):
    """
    Regroup the recorded heuristic evolutions by domain, universe and model.

    ``structured`` is walked as
    {algorithm: {domain: {fixed universe: {problem: properties}}}}. For every
    problem whose properties hold a non-None REF_H_EVOLUTION entry, the
    reference values (stored under REFERENCE_KEY) are paired element-wise
    with the values of each other heuristic in that entry. The pairs are
    collected under the model prefix extracted from the heuristic name.

    Returns {domain: {fixed universe: {prefix: [(reference h, predicted h)]}}}
    with every empty list and empty dictionary removed.
    """
    gathered = {}

    # The top-level (algorithm) key is not needed for the regrouping.
    for per_domain in structured.values():
        for domain, per_universe in per_domain.items():
            by_universe = gathered.setdefault(domain, {})
            for fixed_universe, per_problem in per_universe.items():
                by_prefix = by_universe.setdefault(fixed_universe, {})

                for properties in per_problem.values():
                    if REF_H_EVOLUTION not in properties:
                        continue
                    evolution = properties[REF_H_EVOLUTION]
                    if evolution is None:
                        continue

                    reference = evolution[REFERENCE_KEY]
                    for heuristic, predicted in evolution.items():
                        if heuristic == REFERENCE_KEY:
                            continue
                        prefix = get_prefix_of_heuristic_model(heuristic)
                        by_prefix.setdefault(prefix, []).extend(
                            zip(reference, predicted))

    # Prune bottom-up; iterate over key snapshots because entries are deleted.
    for domain in list(gathered):
        by_universe = gathered[domain]
        for fixed_universe in list(by_universe):
            by_prefix = by_universe[fixed_universe]
            for prefix in list(by_prefix):
                if not by_prefix[prefix]:
                    del by_prefix[prefix]
            if not by_prefix:
                del by_universe[fixed_universe]
        if not by_universe:
            del gathered[domain]

    return gathered

#HACK
def _convert_label(label):
    """
    Expand a "K<m>e<n>" token in the label to its plain integer value.

    If PATTERN_EXPAND_NUMBERS matches, the "<m>e<n>" part is replaced by
    str(m * 10 ** n), e.g. "fooK5e3" becomes "fooK5000". Labels that do not
    match are returned unchanged.
    """
    m = PATTERN_EXPAND_NUMBERS.match(label)
    if m is None:
        return label
    expanded = int(m.group(1)) * 10 ** int(m.group(2))
    # Slice by the match offsets instead of counting back from the end of the
    # label, so text following the number is preserved rather than corrupted.
    return label[:m.start(1)] + str(expanded) + label[m.end(2):]

def _plot_deviation(ax, title, xlabel, ylabel, data, ylog=False, hlines=()):
    """
    Draw one deviation panel: per series a jittered scatter and a mean line.

    :param ax: axes object to draw on (set_title, scatter, plot, ... are used).
    :param title: axes title.
    :param xlabel: x axis label.
    :param ylabel: y axis label.
    :param data: list of ``[algorithm, label, x, y]`` entries, where x and y
        are equally long sequences of numbers. The list is sorted in place
        into natural order of the expanded algorithm names.
    :param ylog: if true, the y axis uses a logarithmic scale.
    :param hlines: y positions at which light gray horizontal lines are drawn.
    """
    ax.set_title(title)
    ax.set_xlabel(xlabel)
    ax.set_ylabel(ylabel)

    if ylog:
        ax.set_yscale("log")

    for hline in hlines:
        ax.axhline(hline, c="lightgray")

    natural_sort(data, key=lambda entry: _convert_label(entry[0]))
    for no, (algorithm, label, x, y) in enumerate(data):
        # Cycle through the palette so more series than colors do not crash.
        color = COLORS[no % len(COLORS)]
        label = label.replace(algorithm, _convert_label(algorithm))

        # add again for ylog = True: mask = y != 0
        # HACK: identical points are merged and drawn as one marker whose
        # area grows with the square root of their multiplicity.
        counts = Counter(zip(x, y))
        if not counts:
            # Nothing to draw for an empty series.
            continue
        xx, yy = zip(*counts.keys())
        sizes = [2 * math.sqrt(counts[point]) for point in counts]

        # Each series gets its own fixed offset plus a little random jitter so
        # that series sharing the same points do not hide each other.
        x_offset = 0.25 if no % 2 == 0 else -0.25
        y_step = no // 2
        ax.scatter([v + x_offset + random.random() * 0.2 for v in xx],
                   [v + y_step * 0.5 - 0.25 + random.random() * 0.2 for v in yy],
                   c=color, marker=",", s=sizes, alpha=0.75,
                   linestyle="--" if "pruneOff" in label else "-",
                   linewidth=1, label=label)

        # Plot average prediction line: count-weighted mean of y per x value.
        weighted_sum = {}
        weight = {}
        for (xp, yp), cp in counts.items():
            weighted_sum[xp] = weighted_sum.get(xp, 0) + yp * cp
            weight[xp] = weight.get(xp, 0) + cp
        mean_x = sorted(weighted_sum)
        mean_y = [weighted_sum[xp] / weight[xp] for xp in mean_x]
        ax.plot(mean_x, mean_y, c=color)

    ax.legend()


def plot_per_universe(path, data, filter_algorithms=None, nb_data_columns=None,
                      data_column_assigner=None, skip_median_deviation=False,
                      filter_universe=None):
    """
    Plot the prediction deviations per (domain, universe) and save the figure.

    Every (domain, universe) pair with at least one selected algorithm becomes
    one row of subplots. Within a row each data column shows the deviation of
    the predictions from h* and, unless skipped, a second panel with the
    deviation of the predictions from their own mean.

    :param path: target passed to ``fig.savefig``.
    :param data: {domain: {universe: {algorithm: [(h*, predicted h), ...]}}}.
    :param filter_algorithms: optional object with a ``match`` method (e.g. a
        compiled regex); algorithms for which it returns None are skipped.
    :param nb_data_columns: number of data columns per row. Has to be given
        together with ``data_column_assigner``; defaults to a single column.
    :param data_column_assigner: function mapping an algorithm name to its
        1-based data column. Algorithms mapped outside
        1..nb_data_columns are not drawn.
    :param skip_median_deviation: if true, only the deviation-from-h* panel is
        drawn. (The second panel is computed with the mean, not the median.)
    :param filter_universe: optional object with a ``match`` method; universes
        for which it returns a falsy value are skipped.
    """
    # Both column options have to be supplied together or not at all.
    assert (nb_data_columns is None) == (data_column_assigner is None)
    nb_data_columns = 1 if nb_data_columns is None else nb_data_columns
    if data_column_assigner is None:
        data_column_assigner = lambda algorithm: 1

    # One entry per plotted row:
    # (title rel, title shift, {data column: (rel series, shift series)}).
    # Keeping both series of an algorithm side by side makes it impossible for
    # the two panels to get out of sync.
    rows = []
    for domain, struct_universes in data.items():
        for universe, struct_algorithms in struct_universes.items():
            if filter_universe is not None and not filter_universe.match(universe):
                continue

            columns = {}
            for algorithm, data_algorithm in struct_algorithms.items():
                if (filter_algorithms is not None
                        and filter_algorithms.match(algorithm) is None):
                    continue
                # Transpose the (h*, predicted) pairs into two arrays.
                series = [np.array(values) for values in zip(*data_algorithm)]
                reference, predicted = series[0], series[1]
                residual = reference - predicted

                rel = [algorithm,
                       algorithm + " (mae %.2f, mse %.2f)" %
                       (np.absolute(residual).mean(), (residual ** 2).mean()),
                       reference,
                       predicted - reference]
                shift = [algorithm,
                         algorithm + " (mean %.2f)" % predicted.mean(),
                         reference,
                         predicted - predicted.mean()]

                idx_data_column = data_column_assigner(algorithm)
                if idx_data_column not in columns:
                    columns[idx_data_column] = ([], [])
                columns[idx_data_column][0].append(rel)
                columns[idx_data_column][1].append(shift)

            if not columns:
                # No algorithm passed the filter: this pair gets no row.
                continue
            rows.append((
                "Deviation of predicted heuristics to h*\n in %s-%s" %
                (domain, universe),
                "Deviation of predicted heuristics to prediction mean in\n %s-%s" %
                (domain, universe),
                columns,
            ))

    panels_per_column = 1 if skip_median_deviation else 2
    nb_rows = len(rows)
    nb_columns = panels_per_column * nb_data_columns
    fig = plt.figure(figsize=(10 * nb_columns, 10 * nb_rows))

    for idx_row, (title_rel, title_shift, columns) in enumerate(rows):
        for idx_col in range(nb_data_columns):
            # Data columns are numbered from 1 by the assigner.
            if idx_col + 1 not in columns:
                continue
            series_rel, series_shift = columns[idx_col + 1]
            # Subplot slots are 1-based and counted row by row.
            slot = idx_row * nb_columns + idx_col * panels_per_column + 1

            ax = fig.add_subplot(nb_rows, nb_columns, slot)
            _plot_deviation(ax, title_rel, "h* value", "deviation",
                            series_rel, hlines=[0])

            if not skip_median_deviation:
                ax = fig.add_subplot(nb_rows, nb_columns, slot + 1)
                _plot_deviation(ax, title_shift, "h* value", "deviation",
                                series_shift)

    fig.tight_layout()
    fig.savefig(path)
    plt.close(fig)


def compare_h_deviations(path_dir, structured, filter_algorithms=None):
    """
    Convert ``structured`` and plot its deviations into ``path_dir``.

    The figure is written to "deviations_per_fixed_universe.pdf" inside
    ``path_dir``; ``filter_algorithms`` is forwarded to plot_per_universe.
    """
    relative_evolution = convert_to_relative_evolution(structured)
    target = os.path.join(path_dir, "deviations_per_fixed_universe.pdf")
    plot_per_universe(target, relative_evolution, filter_algorithms)

# Inert sample data kept as a bare string literal, so it is never executed.
# Its layout is {domain: {universe: {algorithm: [(x, y), ...]}}} -- presumably
# a manual test input for plot_per_universe; confirm before relying on it.
"""
# Test data
import random
import math
data = {
    "domain": {
        "universe": {
            "algo1": [(x,x + 1) for x in range(10)],
            "algo2": [(x, 8 + (random.random()*0.25 - 0.125) - math.sqrt(10-x)) for x in range(10)],
            "algo3": [(x, 8 + (random.random() * 0.25 - 0.125)) for x in range(10)]
        }
    }
}
"""
# This module is meant to be imported; running it directly is an error.
if __name__ == "__main__":
    raise RuntimeError("This script shall not be started as __main__")