from collections import defaultdict
import json
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
import numpy as np
import os
import re
import seaborn as snss
import sys

if sys.version_info < (3,):
    input = raw_input

if sys.version_info < (3,):
    import subprocess32 as subprocess
else:
    import subprocess

import style_cache


# Keys of the per-model property records read from the "properties" file.
KEY_PATH_MODEL = "path_model"
KEY_COVERAGE = "coverage"
KEY_DOMAIN = "domain"

# Keys of the per-model attribute dict built by ``parse_attributes``.
KEY_LOSS = "loss"
KEY_INPUT_MIN = "input_min"
KEY_INPUT_MAX = "input_max"
KEY_PREDICTION_MIN = "prediction_min"
KEY_PREDICTION_MAX = "prediction_max"
KEY_STORED = "stored"
KEY_INCREASED_WALK_LENGTH = "increased_walk_length"


# A signed decimal or scientific float literal, e.g. "1", "-0.5", "3.", "1e-4".
PATTERN_FLOAT = r'[-+]?(?:(?:\d*\.\d+)|(?:\d+\.?))(?:[Ee][+-]?\d+)?'
# Either a comma-separated "a:b" list ("1:2, 3:4") or an integer range ("0-5").
PATTERN_INPUTS_PREDICTIONS = r"(?:\d+:\d+\s*(?:,\s*\d+:\d+\s*)*|\d+-\d+)"
REGEX_INPUTS_PREDICTIONS = re.compile(PATTERN_INPUTS_PREDICTIONS)
# One log entry "epoch: N - loss: F - inputs: L - predictions: L -" with the
# four fields captured as groups. Raw strings so that \s and \d reach the
# regex engine instead of being (invalid) Python string escapes.
REGEX_EPOCH = re.compile(
    r"epoch:\s*(\d+)\s*-\s*loss:\s*(%s)\s*-\s*inputs:\s*(%s)\s*-\s*"
    r"predictions:\s*(%s)\s*-"
    % (PATTERN_FLOAT, PATTERN_INPUTS_PREDICTIONS, PATTERN_INPUTS_PREDICTIONS))
REGEX_EPOCH_IDX = re.compile(r"epoch:\s*(\d+)\s*")
# "Increased max scrambles to 5" or "... to 5, 10, 20": the ", N" group is
# repeated so that any number of comma-separated values is captured.
REGEX_INCREASE_MAX_SCRAMBLES = re.compile(r"Increased max scrambles to (\d+(?:, \d+)*)")


def natural_sort(l):
    """Return the strings of ``l`` sorted in "natural" order.

    Runs of the digits 0-9 compare as integers and all other text compares
    case-insensitively, so "model2" sorts before "model10".

    :param l: iterable of strings
    :return: a new sorted list
    """
    def _natural_key(text):
        chunks = re.split('([0-9]+)', text)
        return [int(chunk) if chunk.isdigit() else chunk.lower()
                for chunk in chunks]

    return sorted(l, key=_natural_key)


def parse_min_max_label(s):
    """Parse an inputs/predictions label into a ``(min, max)`` pair of floats.

    Two formats are accepted:

    * list  ``"a:b, c:d, ..."`` -> min and max of the numbers before each colon
    * range ``"lo-hi"``         -> ``(lo, hi)``; ``lo < hi`` is required

    Malformed labels fail an assertion.
    """
    label = s.strip()
    assert REGEX_INPUTS_PREDICTIONS.match(label)
    has_colon = ":" in label
    has_dash = "-" in label
    # Exactly one of the two formats must be present.
    assert has_colon != has_dash

    if has_colon:
        heads = [float(entry.strip().split(":")[0])
                 for entry in label.split(",")
                 if entry.strip() != ""]
        return min(heads), max(heads)

    if has_dash:
        bounds = [float(part.strip()) for part in label.split("-") if part.strip()]
        assert len(bounds) == 2
        assert bounds[0] < bounds[1]
        return tuple(bounds)

    assert False


def get_previous_epoch(content, idx):
    """Return the epoch number of the last "epoch: N" entry before ``idx``.

    :param content: full training log text
    :param idx: position in ``content``; only an "epoch: " marker lying
        entirely before this position is considered
    :return: the epoch number as int, or 0 if there is no such marker
    """
    # Search and match in place (end bound / start position) instead of
    # slicing, so the potentially large log is not copied on every call.
    idx_epoch = content.rfind("epoch: ", 0, idx)
    if idx_epoch == -1:
        return 0
    epoch = REGEX_EPOCH_IDX.match(content, idx_epoch)
    assert epoch is not None, content[idx_epoch: idx_epoch + 100]
    return int(epoch.group(1))


def parse_model_stored(content):
    """Return the epoch preceding every "variables to const ops." occurrence.

    For each occurrence, the epoch is taken from ``get_previous_epoch``
    (the closest preceding "epoch: N" entry, or 0 if there is none).

    :param content: full training log text
    :return: list of ints, one per occurrence, in log order
    """
    marker = "variables to const ops."
    stored = []
    # First search starts at position 0 so an occurrence at the very start of
    # the log is not skipped.
    idx_stored = content.find(marker)
    while idx_stored != -1:
        stored.append(get_previous_epoch(content, idx_stored))
        idx_stored = content.find(marker, idx_stored + 1)
    return stored


def parse_max_scrambles(content):
    """Return the values of every REGEX_INCREASE_MAX_SCRAMBLES match in ``content``.

    Each match is split on commas into a list of ints; all lists must have
    the same length (checked by assertion).

    :param content: full training log text
    :return: list of int lists, in log order
    """
    lengths = []
    for match in REGEX_INCREASE_MAX_SCRAMBLES.findall(content):
        values = [int(token.strip()) for token in match.split(",")
                  if token.strip() != ""]
        lengths.append(values)
    if lengths:
        expected = len(lengths[0])
        assert all(len(values) == expected for values in lengths)
    return lengths


def parse_attributes(content):
    """Extract per-epoch curves and log events from one training log.

    :param content: full text of a ``run.log``
    :return: dict with
        * KEY_LOSS, KEY_INPUT_MIN, KEY_INPUT_MAX, KEY_PREDICTION_MIN,
          KEY_PREDICTION_MAX: one value per line matched by REGEX_EPOCH
          (min/max come from ``parse_min_max_label``),
        * KEY_STORED: result of ``parse_model_stored``,
        * KEY_INCREASED_WALK_LENGTH: result of ``parse_max_scrambles``.
        A log without any epoch line yields empty per-epoch lists.
    """
    raw_epochs = REGEX_EPOCH.findall(content)
    # findall yields one (epoch, loss, inputs, predictions) tuple per line.
    # Only sanity-check the shape when there is a first element, so that a log
    # without epoch lines does not crash with an IndexError.
    if raw_epochs:
        assert len(raw_epochs[0]) == 4
    epochs = [(
        int(ea[0]), float(ea[1]),
        parse_min_max_label(ea[2]), parse_min_max_label(ea[3]))
        for ea in raw_epochs]

    attributes = {}
    attributes[KEY_LOSS] = [ea[1] for ea in epochs]
    attributes[KEY_INPUT_MIN] = [ea[2][0] for ea in epochs]
    attributes[KEY_INPUT_MAX] = [ea[2][1] for ea in epochs]
    attributes[KEY_PREDICTION_MIN] = [ea[3][0] for ea in epochs]
    attributes[KEY_PREDICTION_MAX] = [ea[3][1] for ea in epochs]
    attributes[KEY_STORED] = parse_model_stored(content)
    attributes[KEY_INCREASED_WALK_LENGTH] = parse_max_scrambles(content)
    return attributes


def fetch_model_attributes(experiment, file, filter):
    """Parse the training logs of an experiment's models into one JSON file.

    Reads ``<experiment.path>-eval/properties`` (JSON). For every record kept
    by ``filter`` that has KEY_PATH_MODEL and KEY_COVERAGE, the coverage is
    summed per (KEY_DOMAIN, model path). For each such model,
    ``run.log.tar.xz`` next to the model file is extracted with ``tar`` and
    the resulting ``run.log`` is parsed with ``parse_attributes``. The result
    ``{domain: {model path: [coverage, attributes]}}`` is written as JSON to
    ``<experiment.path>-eval/<file>``.

    :param experiment: object with a ``path`` attribute
    :param file: name of the output file inside the eval directory
    :param filter: callable taking a property record and returning a
        (possibly modified) record, or None to skip it
    :return: None. If the output file already exists the user is asked
        whether to overwrite it; answering "s" returns without doing anything.
    :raises KeyError: if a kept record has no KEY_DOMAIN entry
    :raises subprocess.CalledProcessError: if extracting a log archive fails
    """
    path_eval = experiment.path + "-eval"
    assert os.path.isdir(path_eval), path_eval
    path_properties = os.path.join(path_eval, "properties")
    assert os.path.isfile(path_properties), path_properties
    path_attributes = os.path.join(path_eval, file)
    assert not os.path.isdir(path_attributes), path_attributes
    if os.path.exists(path_attributes):
        answer = input("The model attribute file %s exists. What shall we do? "
                       "(o)verwrite, (s)top: ").strip().lower()
        assert answer in ["o", "s"], answer
        if answer == "s":
            return

    print("Load properties...")
    # {domain : {model file : [coverage, {attribute: [val1, ...]}]}}
    attributes = defaultdict(lambda: defaultdict(lambda: [0, None]))
    with open(path_properties, "r") as f:
        properties = json.load(f)
    print("Load models to check...")
    for props in properties.values():
        props = filter(props)
        if props is None:
            continue
        if KEY_PATH_MODEL not in props or KEY_COVERAGE not in props:
            continue
        path_model = props[KEY_PATH_MODEL]
        coverage = props[KEY_COVERAGE]
        domain = props[KEY_DOMAIN]
        attributes[domain][path_model][0] += coverage
    # Drop the reference to the (possibly large) properties dict early.
    properties = None
    todo = [(d, m) for d, x in attributes.items() for m in x.keys()]
    print("To parse:", len(todo))
    last_fraction = 0
    for no, (domain, path_model) in enumerate(todo):
        fraction = no / float(len(todo))
        if fraction > last_fraction + 0.01:
            print("%.1f %% done." % (fraction * 100))
            last_fraction = fraction

        assert os.path.exists(path_model), path_model
        path_run = os.path.dirname(path_model)
        path_log_xz = os.path.join(path_run, "run.log.tar.xz")
        path_log = os.path.join(path_run, "run.log")
        assert os.path.exists(path_log_xz), path_log_xz
        # check_call (not call): a failed extraction must abort instead of
        # silently parsing a stale run.log left over from an earlier run.
        subprocess.check_call(["tar", "-xf", path_log_xz, "--directory", path_run])
        assert os.path.exists(path_log), path_log
        with open(path_log, "r") as f:
            log = f.read()
        attributes[domain][path_model][1] = parse_attributes(log)

    with open(path_attributes, "w") as f:
        json.dump(attributes, f)






def plot_model_performance(file, experiment, max_coverage,
                           colormap=None, postprocess_plot_data=None,
                           log_scale_attribute=None,
                           filter_model=None):
    """Plot every parsed model attribute over epochs, one subplot per domain/attribute.

    Reads ``<experiment.path>-eval/model_attributes.json`` (layout
    ``{domain: {model: [coverage, {attribute: values}]}}``, as produced by
    ``fetch_model_attributes``). The top row shows the colour scale. Every
    further row is one domain and every column one attribute. Each model's
    values are drawn with ``Axes.plot``, coloured by coverage / max_coverage.

    :param file: filename in the eval directory where to store the plot
    :param experiment: experiment object (needs a ``path`` attribute) for
        which the plots shall be created
    :param max_coverage: coverage mapped to the top of the colour scale;
        every model's coverage must lie in ``[0, max_coverage]``
    :param colormap: matplotlib colormap or colormap name; "viridis" if None
    :param postprocess_plot_data: optional callable taking
        ``{model: [scaled coverage, values]}`` and returning
        ``(modified, data)``; a true ``modified`` marks the subplot title
        with "(summarized)"
    :param log_scale_attribute: optional predicate on the attribute name;
        if it returns true, that subplot gets a logarithmic y-axis
    :param filter_model: optional predicate on the model path; models for
        which it returns false are ignored. None keeps all models.
    :return: None
    """
    max_coverage = float(max_coverage)
    # Honour a caller-supplied colormap (name or instance); default to viridis.
    colormap = plt.get_cmap("viridis" if colormap is None else colormap)

    def _accept(_model):
        """True if ``_model`` passes ``filter_model`` (or there is no filter)."""
        return filter_model is None or filter_model(_model)

    path_eval = experiment.path + "-eval"
    assert os.path.isdir(path_eval), path_eval
    path_attributes = os.path.join(path_eval, "model_attributes.json")
    assert os.path.isfile(path_attributes), path_attributes
    with open(path_attributes, "r") as f:
        model_attributes = json.load(f)

    all_domains, all_attributes = set(), set()
    for domain, model_misc in model_attributes.items():
        all_domains.add(domain)
        for model, coverage_attributes in model_misc.items():
            if not _accept(model):
                continue
            all_attributes.update(coverage_attributes[1].keys())
    all_domains = natural_sort(all_domains)
    all_attributes = natural_sort(all_attributes)

    plot_size = 5
    fig = plt.figure(figsize=(len(all_attributes) * plot_size,
                              len(all_domains) * plot_size))

    def _prepare_data(_model_attributes, _domain, _attribute):
        """Return ``{model: [coverage / max_coverage, values]}`` for one subplot."""
        assert _domain in _model_attributes, _domain
        _model_attributes = _model_attributes[_domain]
        _curr_models = natural_sort([
            _f for _f, _cov_attr in _model_attributes.items()
            if _attribute in _cov_attr[1] and _accept(_f)])
        return {_f: [_model_attributes[_f][0] / max_coverage,
                     _model_attributes[_f][1][_attribute]]
                for _f in _curr_models}

    def _plot(_ax, _domain, _attribute, _data):
        """Draw the models in ``_data`` onto ``_ax`` and label the subplot."""
        print(_domain, _attribute)
        # Without a postprocessor nothing is summarized.
        _modified = False
        if postprocess_plot_data is not None:
            _modified, _data = postprocess_plot_data(_data)

        for _model, _score_attribute in _data.items():
            assert 0 <= _score_attribute[0] <= 1
            # KEY_STORED is not a per-epoch curve, so mark the individual points.
            marker = "x" if _attribute == KEY_STORED else None
            _ax.plot(_score_attribute[1], color=colormap(_score_attribute[0]),
                     marker=marker)
        t = "{} {}".format(_domain, _attribute)
        _ax.set_title(t + (" (summarized)" if _modified else ""))
        _ax.set_xlabel("epochs")
        _ax.set_ylabel(_attribute)

        if log_scale_attribute is not None and log_scale_attribute(_attribute):
            _ax.set_yscale('log')

    def _plot_colormap(_ax):
        """Draw the coverage colour scale as a horizontal bar."""
        # np.linspace needs an integer sample count (max_coverage is a float);
        # a fixed 256 samples gives a smooth bar for any max_coverage.
        gradient = np.linspace(0, 1, 256)
        gradient = np.vstack((gradient, gradient))
        # extent maps the bar's x-axis onto actual coverage values.
        _ax.imshow(gradient, aspect='auto', cmap=colormap,
                   extent=(0, max_coverage, 0, 1))
        _ax.set_yticks([])
        _ax.set_xlabel("coverage")

    _plot_colormap(fig.add_subplot(len(all_domains) + 1, 1, 1))

    for no_domain, domain in enumerate(all_domains):
        for no_attribute, attribute in enumerate(all_attributes):
            ax = fig.add_subplot(
                len(all_domains) + 1, len(all_attributes),
                (no_domain + 1) * len(all_attributes) + no_attribute + 1)
            _plot(ax, domain, attribute,
                  _prepare_data(model_attributes, domain, attribute))

    fig.tight_layout()
    print("Save...")
    fig.savefig(os.path.join(path_eval, file))
    # pyplot keeps every figure alive until closed; release it.
    plt.close(fig)
    print("Done.")


class TMP(object):
    """Minimal stand-in for an experiment object.

    ``fetch_model_attributes`` and ``plot_model_performance`` only read
    ``experiment.path``; this class provides just that attribute.
    """

    # Experiment used by the commented-out example calls below.
    DEFAULT_PATH = "/home/ferber/repositories/DeePDown/misc/reinforcement_learning/experiments/data/2020-07-07-Merge-ECAI+Harder-Randomness.py"

    def __init__(self, path=None):
        """:param path: experiment path; ``DEFAULT_PATH`` if None."""
        self.path = self.DEFAULT_PATH if path is None else path


tmp = TMP()

def postprocess_plot_data_summarize_many_entries(data, window=500):
    """Smooth long curves by averaging over overlapping windows.

    If ``data`` is non-empty and every value list has more than
    ``2 * window`` entries, each value list is replaced in place by the means
    of the slices ``[c - window, c + window)`` for the centres
    ``c = window, 2 * window, ...`` (each slice ends at or before the end of
    the list). Otherwise ``data`` is returned untouched.

    :param data: ``{model: [coverage, values]}``; ``values`` must be sliceable
    :param window: positive half-width of an averaging slice, which is also
        the step between consecutive centres
    :return: ``(modified, data)`` where ``modified`` tells whether the values
        were summarized
    """
    # An empty dict has nothing to summarize, so do not report it as modified.
    if not data or not all(len(score_attribute[1]) > 2 * window
                           for score_attribute in data.values()):
        return False, data
    for score_attribute in data.values():
        values = score_attribute[1]
        centres = range(window, (len(values) // window) * window, window)
        score_attribute[1] = [np.mean(values[c - window:c + window])
                              for c in centres]
    return True, data

def log_scale_attribute(attribute):
    """Return True if ``attribute`` should get a logarithmic y-axis (only "loss")."""
    log_scaled = ("loss",)
    return attribute in log_scaled

# Example invocations. filter_avi / filter_search are model-path predicates
# (passed as filter_model) that must be defined before uncommenting these.
# plot_model_performance("avi.pdf", tmp, 10, filter_model=filter_avi, log_scale_attribute=log_scale_attribute, postprocess_plot_data=postprocess_plot_data_summarize_many_entries)
# plot_model_performance("search.pdf", tmp, 10, filter_model=filter_search, log_scale_attribute=log_scale_attribute, postprocess_plot_data=postprocess_plot_data_summarize_many_entries)