#!/usr/bin/env python
from collections import defaultdict
import itertools
import json
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
import numpy as np
import os
import re
import seaborn as sns
import sys

import style_cache

# Python 3 has no separate ``unicode`` type. Alias it to ``str`` so that the
# isinstance check in natural_sort works on both Python 2 and Python 3.
if sys.version_info >= (3,):
    unicode = str


def natural_sort(l):
    """
    Return the elements of ``l`` as a new list in natural (human) order.

    Strings are compared case-insensitively and runs of the digits 0-9 are
    compared by their integer value, so "a2" sorts before "a10". Iterable
    elements (e.g. tuples of strings) are compared element-wise using the
    same rule; any other element is used as its own sort key.

    :param l: iterable of strings, (nested) iterables of strings or other
    sortable values
    :return: new sorted list
    """
    digit_runs = re.compile('([0-9]+)')

    def _chunk(piece):
        # Digit runs compare numerically, everything else case-insensitively.
        if piece.isdigit():
            return int(piece)
        return piece.lower()

    def _text_key(text):
        return [_chunk(piece) for piece in digit_runs.split(text)]

    def _sort_key(item):
        if isinstance(item, str) or isinstance(item, unicode):
            return _text_key(item)
        try:
            return [_sort_key(part) for part in item]
        except TypeError:
            # Not iterable: fall back to the element itself.
            return item

    return sorted(l, key=_sort_key)


# Default matplotlib colours used by plot_attribute_evolution, which assigns
# them to the algorithms in sorted-name order. "o" is not a valid matplotlib
# colour abbreviation, hence the fourth colour is spelled out.
COLORS = ["r", "g", "b", "orange"]


def plot_attribute_evolution(file, experiment, attribute, xlabel, ylabel, colors=None):
    """
    Plot, for every task, the sequence of attribute values of each algorithm.

    The runs are read from the ``properties`` JSON file in the evaluation
    directory (``experiment.path + "-eval"``). One subplot is created per
    (domain, problem) pair. Within a subplot every algorithm contributes one
    line; shorter value sequences are stretched along the x-axis so that all
    lines span the same range as the longest sequence of that task.

    :param file: filename in the eval directory where to store the plot
    :param experiment: experiment object for which the plots shall be created
    (only its ``path`` attribute is read)
    :param attribute: attribute name or list of attributes which are
    concatenated
    :param xlabel: label for x-axis
    :param ylabel: label for y-axis
    :param colors: list of matplotlib colors; the algorithms receive them in
    the sorted order of their names. Defaults to COLORS
    :return: None
    """
    attribute = attribute if isinstance(attribute, list) else [attribute]
    colors = COLORS if colors is None else colors
    path_eval = experiment.path + "-eval"
    path_properties = os.path.join(path_eval, "properties")
    path_plot = os.path.join(path_eval, file)
    assert os.path.isfile(path_properties)
    with open(path_properties, "r") as f:
        properties = json.load(f)

    # {Task : { Algorithm : Values }}
    stats = defaultdict(dict)
    algorithms = set()
    for props in properties.values():
        task_key = "({}, {})".format(props["domain"], props["problem"])
        algorithm = props["algorithm"]
        # Concatenate all requested attributes; scalar entries count as a
        # single value and missing attributes are skipped.
        values = []
        for a in attribute:
            if a not in props:
                continue
            entry = props[a]
            values.extend(entry if isinstance(entry, list) else [entry])
        stats[task_key][algorithm] = values
        algorithms.add(algorithm)

    color_index = {name: no for no, name in enumerate(sorted(algorithms))}
    assert len(color_index) <= len(colors)

    fig = plt.figure(figsize=(5, 5 * len(stats)))
    try:
        for idx_task, (task, algorithm_values) in enumerate(stats.items()):
            ax = fig.add_subplot(len(stats), 1, idx_task + 1)
            ax.set_title(task)
            max_nb_values = max(len(x) for x in algorithm_values.values())
            for algorithm, values in algorithm_values.items():
                if len(values) == 0:
                    continue
                # float() keeps the stretch factor exact under Python 2,
                # where int / int would floor.
                scale = float(max_nb_values) / len(values)
                xs = [i * scale for i in range(len(values))]
                color = colors[color_index[algorithm]]
                ax.plot(xs, values, label=algorithm, color=color)
                ax.scatter(xs, values, color=color)
            ax.legend()
            ax.set_xlabel(xlabel)
            ax.set_ylabel(ylabel)
        fig.savefig(path_plot)
    finally:
        # pyplot keeps every figure alive until it is closed explicitly.
        plt.close(fig)


# Captures the number N of algorithm names which contain "-steps-N-";
# plot_density interprets it as the walk length of the run.
regex_step_size = re.compile(r".*-steps-(\d+)-.*")

def plot_density(file, experiment, attribute="h_samples", rename=None):
    """
    Plot kernel density estimates of a sampled attribute per task.

    Two figures are written. In the first ("algo") every column is one
    algorithm and every curve one walk length; in the second ("distance")
    every column is one walk length and every curve one algorithm. Every row
    is one (domain, problem) task.

    Only runs whose algorithm name matches ``regex_step_size`` or contains
    "uniform" are used. Runs without a walk length are shown for every walk
    length found in the data.

    :param file: filename template in the eval directory; it is formatted
    with ``file % name`` where name is "algo" resp. "distance"
    :param experiment: experiment object for which the plots shall be created
    (only its ``path`` attribute is read)
    :param attribute: name of the property with the samples. The value is
    unpacked with ``zip(*value)`` into (value, count) pairs, i.e. presumably
    a pair of parallel lists [values, counts] - verify against the producer
    :param rename: optional list of callables which are applied in order to
    the (normalised) algorithm name
    :return: None
    """
    rename = [] if rename is None else rename
    path_eval = experiment.path + "-eval"
    path_properties = os.path.join(path_eval, "properties")
    path_plot = os.path.join(path_eval, file)
    assert os.path.isfile(path_properties)

    with open(path_properties, "r") as f:
        properties = json.load(f)
    # {task: {algorithm: { walk length : [distances] } } }
    data = defaultdict(lambda: defaultdict(lambda: defaultdict(lambda: None)))
    # {task: {algorithm: [distances] } }
    data_none_walk_length = defaultdict(lambda: defaultdict(lambda: None))
    all_algorithms = set()
    all_tasks = set()
    all_walk_lengths = set()

    for props in properties.values():
        algo = props["algorithm"]
        walk_length = regex_step_size.match(algo)
        # Hide the concrete walk length so that runs which differ only in it
        # share one algorithm name.
        # NOTE(review): the pattern has no leading "-", so the result contains
        # "--steps-X-". This is kept as is because rename callables may
        # depend on the exact name.
        algo = re.sub(r"steps-\d+-", "-steps-X-", algo)
        task = (props["domain"], props["problem"])

        if walk_length is None and "uniform" not in algo:
            continue

        for r in rename:
            algo = r(algo)
        all_algorithms.add(algo)
        all_tasks.add(task)
        if walk_length is None:
            data_none_walk_length[task][algo] = props.get(attribute, [])
        else:
            walk_length = int(walk_length.group(1))
            all_walk_lengths.add(walk_length)
            data[task][algo][walk_length] = props.get(attribute, [])

    # Runs without a walk length are replicated for every walk length.
    for task, algo_plans in data_none_walk_length.items():
        for algo, plans in algo_plans.items():
            for walk_length in all_walk_lengths:
                data[task][algo][walk_length] = plans

    def get_algo(task, snd, trd):
        return data[task][snd][trd]

    def get_walk(task, snd, trd):
        return data[task][trd][snd]

    def get_linestyle(label):
        # The first matching marker wins.
        if "Boff" in label:
            return "-"
        if "BhffMax" in label:
            return "--"
        if "BhffProb" in label:
            return "-."
        return "-"

    # [figure name, column keys, curve keys, accessor]
    second_argument = [
        ["algo", all_algorithms, all_walk_lengths, get_algo],
        ["distance", all_walk_lengths, all_algorithms, get_walk]
    ]
    for snd_name, snd_keys, trd_keys, getter in second_argument:
        fig = plt.figure(figsize=(5*len(snd_keys), 5*len(all_tasks)))
        try:
            for no_task, task in enumerate(sorted(all_tasks)):
                for no_snd, snd in enumerate(sorted(snd_keys)):
                    ax = fig.add_subplot(len(all_tasks), len(snd_keys),
                                         no_task * len(snd_keys) + no_snd + 1)
                    if no_snd == 0:
                        ax.set_ylabel(task)
                    else:
                        ax.set_ylabel("density")
                    if no_task == 0:
                        ax.set_title(snd)
                    ax.set_xlabel("H value plans")
                    for trd in sorted(trd_keys):
                        tmp = getter(task, snd, trd)
                        if tmp is None:
                            continue
                        # Expand the (value, count) pairs into single samples.
                        elems = [d for d, c in zip(*tmp) for _ in range(c)]
                        label = "%s(%i)" % (trd, len(elems))
                        # NOTE(review): distplot is deprecated in recent
                        # seaborn releases; kept to not change the output.
                        sns.distplot(elems, hist=False, kde=True,
                                     kde_kws={'linewidth': 3,
                                              "linestyle": get_linestyle(label)},
                                     label=label, ax=ax)
            fig.tight_layout()
            fig.savefig(path_plot % snd_name)
        finally:
            # pyplot keeps every figure alive until it is closed explicitly.
            plt.close(fig)



def plot_facts_distribution(file, experiment, rename=None, alpha=0.6, limit_max_tuple_size=1):
    """
    Plot for every task how frequently each fact tuple occurs per algorithm.

    The runs are read from the ``properties`` JSON file in the evaluation
    directory (``experiment.path + "-eval"``). Every row of the figure is one
    (domain, problem) task, every column one entry of the
    ``fact_frequencies`` list of the runs. The x-axis enumerates the sorted
    fact tuples of the task, the y-axis shows the count of the tuple divided
    by ``nb_samples`` of the run. Runs without ``nb_samples`` are skipped.

    :param file: filename in the eval directory where to store the plot
    :param experiment: experiment object for which the plots shall be created
    (only its ``path`` attribute is read)
    :param rename: optional list of callables which are applied in order to
    the algorithm name
    :param alpha: opacity of the plotted lines
    :param limit_max_tuple_size: maximum number of ``fact_frequencies``
    entries (columns) to plot
    :return: None
    :raises ValueError: if no run provides ``nb_samples``
    """
    rename = [] if rename is None else rename
    path_eval = experiment.path + "-eval"
    path_properties = os.path.join(path_eval, "properties")
    path_plot = os.path.join(path_eval, file)
    assert os.path.isfile(path_properties)

    with open(path_properties, "r") as f:
        properties = json.load(f)

    # [{task: {algorithm_incl_walklength: {(fact tuple): chance} } }_tuple_size_i]
    data = None
    all_algorithms = set()
    all_tasks = set()

    max_tuple_size = None
    # {task : [set((facts,...), ...)_tuple_size_i, ...]}
    all_facts = {}
    for ppk in list(properties.keys()):
        # Remove every run from the loaded properties once it is processed
        # to keep the memory footprint low.
        props = properties.pop(ppk)
        algo = props["algorithm"]
        task = (props["domain"], props["problem"])

        for r in rename:
            algo = r(algo)
        all_algorithms.add(algo)
        all_tasks.add(task)

        # Check for a missing value BEFORE converting: float(None) raises.
        nb_states = props.get("nb_samples")
        if nb_states is None:
            continue
        nb_states = float(nb_states)
        fact_frequencies = props["fact_frequencies"]

        # All runs have to provide the same number of tuple sizes.
        tuple_sizes = min(limit_max_tuple_size, len(fact_frequencies))
        if max_tuple_size is None:
            max_tuple_size = tuple_sizes
        else:
            assert tuple_sizes == max_tuple_size

        if data is None:
            data = [defaultdict(lambda: defaultdict(dict))
                    for _ in range(max_tuple_size)]

        if task not in all_facts:
            all_facts[task] = [set() for _ in range(max_tuple_size)]

        for no_tuple_size in range(max_tuple_size):
            ff = fact_frequencies[no_tuple_size]

            # Every (task, algorithm) pair may occur only once.
            assert len(data[no_tuple_size][task][algo]) == 0
            for fact_tuple, fact_tuple_count in ff.items():
                # Keys are facts which are each terminated by "@".
                fact_tuple = fact_tuple.split("@")
                assert fact_tuple[-1] == ""
                fact_tuple = tuple(sorted(fact_tuple[:-1]))
                all_facts[task][no_tuple_size].add(fact_tuple)
                data[no_tuple_size][task][algo][fact_tuple] = fact_tuple_count/nb_states

    if max_tuple_size is None:
        raise ValueError(
            "No run in %s provides 'nb_samples'; nothing to plot."
            % path_properties)

    fig = plt.figure(figsize=(5*max_tuple_size, 5*len(all_tasks)))
    try:
        for no_task, task in enumerate(natural_sort(all_tasks)):
            # None if every run of the task was skipped above.
            task_facts = all_facts.get(task)
            for no_tuple_size, tuple_size in enumerate(range(max_tuple_size)):
                if task_facts is None:
                    all_facts_sorted = []
                else:
                    all_facts_sorted = sorted(task_facts[no_tuple_size])

                ax = fig.add_subplot(len(all_tasks), max_tuple_size,
                                     no_task * max_tuple_size + no_tuple_size + 1)
                if no_tuple_size == 0:
                    ax.set_ylabel(task)
                else:
                    ax.set_ylabel("probability")
                if no_task == 0:
                    ax.set_title("Fact Tuple Size: %i" % tuple_size)
                ax.set_xlabel("fact x-tuple")
                plotted = False
                for no_algo, algo in enumerate(natural_sort(all_algorithms)):
                    # Tuples which an algorithm never produced count as 0.
                    tmp = np.array([data[no_tuple_size][task][algo].get(ft, 0)
                                    for ft in all_facts_sorted])
                    if len(tmp) == 0:
                        continue
                    print(no_task,no_tuple_size, no_algo, len(tmp))

                    label = algo
                    color = None
                    if "uniform" in label:
                        color = "green"
                    elif "progression" in label:
                        color = "black"
                    elif "regression" in label:
                        color = "red" if "200" in label else "blue"

                    linestyle = "--" if "h2" in label else "-"

                    ax.plot(np.arange(len(tmp)), tmp, label=label, color=color,
                            linestyle=linestyle, alpha=alpha)
                    plotted = True
                if plotted:
                    ax.legend()
        fig.tight_layout()
        fig.savefig(path_plot)
    finally:
        # pyplot keeps every figure alive until it is closed explicitly.
        plt.close(fig)



def load_properties(path_properties, attribute, rename_algorithms=None, filter=None):
    """
    Collect the values of an attribute for all solved runs.

    Only runs whose ``coverage`` property equals 1 are used.

    :param path_properties: path to the properties JSON file
    :param attribute: name of the property to collect
    :param rename_algorithms: either a dictionary {algorithm: new name} or a
    callable which receives the run properties and returns properties whose
    ``algorithm`` entry is the new name. Runs whose new name is None are
    skipped
    :param filter: optional callable which receives the run properties and
    returns the properties to use or None to skip the run
    :return: ({domain: {algorithm: [attribute values]}},
    naturally sorted list of the algorithm names)
    """
    assert os.path.isfile(path_properties)
    if rename_algorithms is None:
        rename_algorithms = {}

    with open(path_properties, "r") as handle:
        runs = json.load(handle)

    # {domain: {algo: [attribute]}}
    per_domain = defaultdict(lambda: defaultdict(list))
    seen_algorithms = set()
    rename_by_call = callable(rename_algorithms)

    for run in runs.values():
        if filter is not None:
            run = filter(run)
            if run is None:
                continue

        # A missing coverage (None) is different from 1 as well.
        if run.get("coverage") != 1:
            continue

        if rename_by_call:
            name = rename_algorithms(run)["algorithm"]
        else:
            original_name = run["algorithm"]
            name = rename_algorithms.get(original_name, original_name)
        if name is None:
            continue

        per_domain[run["domain"]][name].append(run[attribute])
        seen_algorithms.add(name)

    return per_domain, natural_sort(seen_algorithms)


def plot_coverageXattribute(file, experiment, attribute, rename_algorithms=None,
                            only_summary=False, plot_domains=False, x_log=False, filter=None,
                            color_by_coverage=None, colormap=matplotlib.cm.viridis):
    """
    Plot for every algorithm the number of solved runs over an attribute.

    For each algorithm the attribute values of its solved runs are sorted and
    plotted as a cumulative curve (x: attribute value, y: number of runs with
    at most this value). The data is loaded with load_properties.

    :param file: filename in the eval directory where to store the plot. If
    plot_domains is set, it is a ``str.format`` template which is formatted
    with "general" resp. the domain name
    :param experiment: experiment object for which the plots shall be created
    (only its ``path`` attribute is read)
    :param attribute: name of the property shown on the x-axis
    :param rename_algorithms: forwarded to load_properties
    :param only_summary: plot only the summary over all domains
    :param plot_domains: write the summary and every domain (directory name
    of the ``domain`` property) to an own figure. Mutually exclusive with
    only_summary
    :param x_log: use a logarithmic x-axis
    :param filter: forwarded to load_properties
    :param color_by_coverage: None or False: use the styles of the style
    cache. True: color each curve by its coverage relative to the best
    coverage in the subplot. Number: color each curve by its coverage
    relative to this coverage
    :param colormap: callable which maps an integer to a color
    :return: None
    """
    assert not only_summary or not plot_domains
    path_eval = experiment.path + "-eval"
    path_properties = os.path.join(path_eval, "properties")

    data, all_algorithms = load_properties(path_properties, attribute, rename_algorithms, filter)
    sc = style_cache.StyleCache(
        main_algorithm_attributes=[style_cache.COLOR],
        algorithm_groups=style_cache.GROUP_ALGORITHMS_BRACKETS)

    # Request the styles in sorted order up front.
    # NOTE(review): presumably so that every subplot uses the same style per
    # algorithm - confirm against style_cache.
    for algo in all_algorithms:
        sc.get_style(algo)

    # False disables the coloring like None does (it used to end up as a
    # maximum coverage of 0 and thus a division by zero).
    use_coverage_colors = (color_by_coverage is not None
                           and color_by_coverage is not False)

    def _plot(ax, pdata, title):
        print("Plot", title)
        present = [algorithm for algorithm in all_algorithms
                   if algorithm in pdata and len(pdata[algorithm]) > 0]
        # Only algorithms with data determine the value range. A missing
        # algorithm must not pull the minimum to 0, because the minimum is
        # the start of the curves on a logarithmic x-axis.
        if present:
            max_attr_value = max(max(pdata[algorithm]) for algorithm in present)
            min_attr_value = min(min(pdata[algorithm]) for algorithm in present)
        else:
            max_attr_value = 0
            min_attr_value = 0

        max_coverage = None
        if use_coverage_colors:
            if color_by_coverage is True:
                max_coverage = max(
                    [len(pdata[algorithm]) for algorithm in present] or [0])
            else:
                max_coverage = color_by_coverage

        for algorithm in present:
            sorted_data = sorted(pdata[algorithm])
            nb_solved = len(sorted_data)
            style = sc.get_style(algorithm)
            color = style[style_cache.COLOR]
            linestyle = style[style_cache.LINE_STYLE]
            if use_coverage_colors:
                coverage = min(nb_solved, max_coverage)
                color_value = int(256 * coverage/float(max_coverage))
                color = colormap(color_value)
                linestyle = "-" if "avi" in algorithm else "--"
            # Cumulative curve which starts with zero solved runs ...
            ax.plot([min_attr_value if x_log else 0] + sorted_data,
                    np.arange(nb_solved + 1), color=color,
                    linestyle=linestyle, markevery=0.25, label=algorithm,
                    fillstyle="none")
            # ... and is extended faintly to the largest value in the plot.
            ax.plot([sorted_data[-1], max_attr_value], [nb_solved, nb_solved],
                    color=color, linestyle=linestyle, fillstyle="none",
                    alpha=0.3)
        ax.set_title(title)
        ax.set_xlabel(attribute.replace("_", " "))
        ax.set_ylabel("coverage")
        ax.legend(bbox_to_anchor=(1.05, 1),
                  loc='upper left')
        if x_log:
            ax.set_xscale('log')

    def _render(path, subplots):
        # Draw the (title, data) pairs below each other and store the figure.
        fig = plt.figure(figsize=(4, (len(subplots) * 4)))
        try:
            for no_subplot, (title, pdata) in enumerate(subplots):
                ax = fig.add_subplot(len(subplots), 1, no_subplot + 1)
                _plot(ax, pdata, title)
            fig.tight_layout()
            fig.savefig(path)
        finally:
            # pyplot keeps every figure alive until it is closed explicitly.
            plt.close(fig)

    # Summary over all domains
    all_data = defaultdict(list)
    for domain, algo_attributes in data.items():
        for algo, attributes in algo_attributes.items():
            all_data[algo].extend(attributes)
    general = ("General: Coverage X %s" % attribute, all_data)

    if plot_domains:
        _render(os.path.join(path_eval, file.format("general")), [general])

        true_domains = {os.path.dirname(x) for x in data}
        for domain in sorted(true_domains):
            # Merge all entries which belong to the same directory name.
            domain_data = defaultdict(list)
            for domain_task, algo_attributes in data.items():
                if domain == os.path.dirname(domain_task):
                    for algo, attributes in algo_attributes.items():
                        domain_data[algo].extend(attributes)
            _render(os.path.join(path_eval, file.format(domain)),
                    [("%s: Coverage X %s" % (domain, attribute), domain_data)])

    else:
        subplots = [general]
        if not only_summary:
            for domain in natural_sort(data.keys()):
                subplots.append(
                    ("%s: Coverage X %s" % (domain, attribute), data[domain]))
        _render(os.path.join(path_eval, file), subplots)
