#!/usr/bin/env python
from __future__ import print_function

import matplotlib as mpl
mpl.use('agg')

import json
import matplotlib.pyplot as plt
import numpy as np
import os
import re
from scipy.special import comb
import sys

# Prefixes of the log lines reporting a heuristic estimate; the rest of such a
# line is parsed as "<name>: <value>" by extract_new_heuristic.
NEW_BEST_HEURISTIC = "New best heuristic value for "
NEW_REFERENCE_HEURISTIC = "New reference heuristic value for "
# Prefix of the progress line that must follow a group of heuristic lines; the
# g value is read from after this prefix up to the first comma.
G_VALUE = "[g="
# Extracts the "<number> expanded" count from that progress line. A raw string
# keeps "\d"/"\s" as regex escapes instead of invalid string escapes.
PATTERN_EXPANDED = re.compile(r"(\d+)\sexpanded")

# Prefix of the log lines parsed by analyse_choices.
CHOICES = "Choices: "


def mergeSortInversions(arr):
    """Sort ``arr`` with merge sort and count its inversions.

    An inversion is a pair of positions ``i < j`` with ``arr[i] > arr[j]``.

    :param arr: sequence of mutually comparable items supporting ``len`` and
        slicing (e.g. a list).
    :return: tuple ``(sorted_items, inversion_count)`` where ``sorted_items``
        is a new ascending list.
    """
    if len(arr) <= 1:
        # Empty and single-element inputs are already sorted. Checking only
        # ``== 1`` made an empty input recurse forever.
        return list(arr), 0

    mid = len(arr) // 2
    left, left_inv = mergeSortInversions(arr[:mid])
    right, right_inv = mergeSortInversions(arr[mid:])

    merged = []
    i = j = 0
    inversions = left_inv + right_inv
    while i < len(left) and j < len(right):
        if left[i] <= right[j]:
            merged.append(left[i])
            i += 1
        else:
            # right[j] overtakes every remaining element of ``left``; each of
            # those forms one inversion with it.
            inversions += len(left) - i
            merged.append(right[j])
            j += 1

    merged.extend(left[i:])
    merged.extend(right[j:])
    return merged, inversions


def sanitize_name(name):
    """Map every name beginning with ``"nh"`` to ``"nh"``; keep others as-is."""
    return "nh" if name.startswith("nh") else name

def extract_new_heuristic(line):
    """Parse a ``<prefix><name>: <value>`` heuristic line.

    ``<prefix>`` is NEW_BEST_HEURISTIC or NEW_REFERENCE_HEURISTIC. The name
    runs from the end of the prefix to the next ``":"``, and the rest of the
    line is parsed as an int.

    :return: ``(sanitized_name, value)``.
    :raises AssertionError: if the line starts with neither prefix.
    """
    matching = [candidate
                for candidate in (NEW_BEST_HEURISTIC, NEW_REFERENCE_HEURISTIC)
                if line.startswith(candidate)]
    assert matching
    start = len(matching[0])
    colon = line.find(":", start)
    label = sanitize_name(line[start:colon])
    value = int(line[colon + 1:])
    return label, value

def analyse_h_values(content):
    """Collect heuristic estimates together with g and expansion counts.

    Consecutive heuristic lines (see extract_new_heuristic) are gathered into
    one dict. The next line must start with G_VALUE; from it, ``"g"`` (text
    between G_VALUE and the first comma) and ``"expanded"`` (the single
    PATTERN_EXPANDED match) are added. The dict is then stored. Lines outside
    such a group are ignored, and a group not followed by a g line before the
    end of ``content`` is dropped.

    :return: list of dicts, one per completed group, in log order.
    """
    heuristic_prefixes = (NEW_BEST_HEURISTIC, NEW_REFERENCE_HEURISTIC)
    collected = []
    pending = None
    for row in content.split("\n"):
        if row.startswith(heuristic_prefixes):
            if pending is None:
                pending = {}
            key, value = extract_new_heuristic(row)
            pending[key] = value
            continue
        if pending is None:
            continue

        assert row.startswith(G_VALUE)
        pending["g"] = int(row[len(G_VALUE):row.find(",")])

        hits = PATTERN_EXPANDED.findall(row)
        assert len(hits) == 1
        pending["expanded"] = int(hits[0])

        collected.append(pending)
        pending = None
    return collected


def get_good_split(s):
    """Split ``s`` on commas that are not nested inside ``()`` or ``[]``.

    Nesting depth goes up on ``(``/``[`` and down on ``)``/``]``; only commas
    seen at depth 0 separate parts. The separating commas are dropped. The
    final (possibly empty) part is always included.
    """
    pieces = []
    depth = 0
    start = 0
    for pos, ch in enumerate(s):
        if ch in "([":
            depth += 1
        elif ch in ")]":
            depth -= 1
        elif ch == "," and depth == 0:
            pieces.append(s[start:pos])
            start = pos + 1
    pieces.append(s[start:])
    return pieces

def get_values_choice(s):
    """Parse ``name=value`` pairs separated by top-level commas into a dict.

    Empty pieces are skipped. Each piece is split at its last ``=``; the name
    goes through sanitize_name and the value is parsed as an int.

    :raises AssertionError: if a non-empty piece contains no ``=``.
    """
    values = {}
    for token in get_good_split(s):
        if not token:
            continue
        pieces = token.rsplit("=", 1)
        assert len(pieces) == 2
        label, raw_value = pieces
        values[sanitize_name(label)] = int(raw_value)
    return values


def analyse_choices(content):
    """Parse every ``Choices: (...)(...)`` line of a log.

    After the CHOICES prefix, the literal text ``"(NOOOOO"`` is removed. Lines
    with nothing left are skipped. The remainder must be wrapped in
    parentheses and is split on ``")("`` into groups, each parsed by
    get_values_choice.

    :return: list with one entry per parsed line; each entry is a list of
        dicts, one per parenthesised group.
    :raises AssertionError: if a line is not wrapped in parentheses. The
        offending text is part of the message.
    """
    choice_points = []
    for line in content.split("\n"):
        if not line.startswith(CHOICES):
            continue
        body = line[len(CHOICES):].replace("(NOOOOO", "")
        if body == "":
            continue
        # Report the malformed text through the assertion message instead of
        # printing it to stdout right before failing.
        assert body[0] == "(" and body[-1] == ")", \
            "Malformed choices line: %r" % body
        choice_points.append(
            [get_values_choice(part) for part in body[1:-1].split(")(")])
    return choice_points


def convert_h_progress(h_progress):
    """Turn a list of per-step dicts into a 2-D int64 array.

    :param h_progress: list of dicts mapping names to ints (as produced by
        analyse_h_values). All dicts must share the same keys, not counting
        ``"expanded"``, which is ignored.
    :return: ``(data, labels)``. ``labels`` is the sorted key list without
        ``"expanded"``, and ``data[row, col] == h_progress[row][labels[col]]``.
        Empty input yields a ``(0, 0)`` array and an empty label list.
    :raises ValueError: if a dict's keys differ from those of the first dict.
    """
    if not h_progress:
        return np.zeros((0, 0), dtype=np.int64), []

    labels = sorted(key for key in h_progress[0] if key != "expanded")
    rows = []
    for no_row, progress in enumerate(h_progress):
        keys = sorted(key for key in progress if key != "expanded")
        if keys != labels:
            # An uninitialised np.ndarray used to leave garbage cells here, or
            # values were silently shifted into the wrong column.
            raise ValueError("row %d has keys %r, expected %r"
                             % (no_row, keys, labels))
        rows.append([progress[key] for key in labels])
    return np.array(rows, dtype=np.int64), labels

def convert_choices_plain(choices):
    """Flatten nested choice points into one int64 row per state.

    :param choices: list of choice points; each is a list of dicts mapping
        names to ints (as produced by analyse_choices). All dicts must have
        the same keys.
    :return: ``(data, labels)``. ``labels`` is the sorted key list, and
        ``data`` has one row per state in input order with columns ordered
        like ``labels``. If there are no states at all, a ``(0, 0)`` array and
        an empty label list are returned.
    :raises AssertionError: if a state's keys differ from the first state's.
    """
    labels = None
    rows = []
    for choice in choices:
        for state in choice:
            # Compare sorted key lists, not dict views: on Python 2 ``keys()``
            # is a list whose order may differ between equal key sets.
            if labels is None:
                labels = sorted(state)
            else:
                assert sorted(state) == labels
            rows.append([state[key] for key in labels])

    if labels is None:
        # Previously ``sorted(None)`` raised TypeError here.
        return np.zeros((0, 0), dtype=np.int64), []
    return np.array(rows, dtype=np.int64), labels


def plot_evolution_h_g(data_labels, main_heuristic, base_path):
    """Draw one labelled line per column of ``data`` and save it as a PDF.

    :param data_labels: ``(data, labels)`` pair; column ``i`` of the 2-D
        ``data`` is plotted against its row index with legend entry
        ``labels[i]``.
    :param main_heuristic: name shown in the title.
    :param base_path: output prefix; the figure is written to
        ``base_path + "_evolution_progress.pdf"`` and then closed.
    """
    values, names = data_labels
    figure = plt.figure()
    axes = figure.add_subplot(111)

    for column, name in enumerate(names):
        axes.plot(values[:, column], label=name)

    axes.legend()
    axes.set_title("Progress of heuristic estimate using as heuristic %s" %
                   main_heuristic)
    axes.set_xlabel("")
    axes.set_ylabel("heuristic estimate (resp. g value)")

    figure.savefig(base_path + "_evolution_progress.pdf")
    plt.close(figure)


def plot_barplot_accurracy(h_progress, main_heuristic, base_path,
                           diff_mean_to_prediction=False):
    """Box-plot the error ``main_heuristic - h*``, grouped by h* value.

    One box is drawn per integer h* from the smallest to the largest observed
    value. Values with no samples get a NaN-only box so the x axis stays
    contiguous. Ticks with samples also show the sample count ``n``.

    :param h_progress: ``(data, labels)`` pair; ``labels`` must contain
        ``main_heuristic`` and ``"h*"``.
    :param main_heuristic: column compared against h*.
    :param base_path: output prefix; the figure is written to
        ``base_path + "_deviation_wrt_hstar.pdf"`` and then closed.
    :param diff_mean_to_prediction: if true, also scatter ``mean(h*) - h*``
        per tick, i.e. the error of always predicting the mean h*.
    :raises ValueError: if a required label is missing or there are no rows.
    """
    data, labels = h_progress
    predicted = data[:, labels.index(main_heuristic)]
    original = data[:, labels.index("h*")]

    deviations_by_h = {}
    for h, p in zip(original, predicted):
        deviations_by_h.setdefault(h, []).append(p - h)

    tick_labels = []
    boxes = []
    mean_deviations = []
    mean = original.mean() if diff_mean_to_prediction else None
    for h_star in range(min(original), max(original) + 1):
        if diff_mean_to_prediction:
            mean_deviations.append(mean - h_star)
        if h_star in deviations_by_h:
            tick_labels.append("%d\n$n=%d$" % (h_star,
                                               len(deviations_by_h[h_star])))
            boxes.append(deviations_by_h[h_star])
        else:
            tick_labels.append("%d" % h_star)
            boxes.append([float('nan')])

    fig = plt.figure()
    ax = fig.add_subplot(1, 1, 1)
    ax.boxplot(boxes)
    ax.set_xticklabels(tick_labels)
    if diff_mean_to_prediction:
        ax.scatter(np.arange(len(boxes)) + 1, mean_deviations, color='r',
                   alpha=0.6, label="Deviation to Data Mean")
        # The scatter carries a label; without a legend it was never shown.
        ax.legend()
    ax.set_xlabel("h*")
    # The boxes hold ``predicted - h*``, not raw estimates of main_heuristic.
    ax.set_ylabel("%s - h*" % main_heuristic)
    ax.set_title("Estimations of %s w.r.t. h*" % main_heuristic)

    fig.tight_layout()
    fig.savefig(base_path + "_deviation_wrt_hstar.pdf")
    plt.close(fig)

def plot_histogram(h_progress, main_heuristic, base_path):
    """Save a histogram of the ``main_heuristic`` column as a PDF.

    :param h_progress: ``(data, labels)`` pair; ``labels`` must contain
        ``main_heuristic``. An ``"h*"`` column is no longer required, because
        it was read but never used.
    :param main_heuristic: column to histogram.
    :param base_path: output prefix; the figure is written to
        ``base_path + "_histogramm.pdf"`` (file name kept for compatibility)
        and then closed.
    """
    data, labels = h_progress
    predicted = data[:, labels.index(main_heuristic)]

    fig = plt.figure()
    ax = fig.add_subplot(1, 1, 1)

    ax.hist(predicted)

    ax.set_xlabel(main_heuristic)
    ax.set_ylabel("count")
    ax.set_title("Histogramm of estimated heuristic value of %s" % main_heuristic)

    fig.tight_layout()
    fig.savefig(base_path + "_histogramm.pdf")
    plt.close(fig)

def analyse_inversion(conv_choices, main_heuristic, base_dir, prefix):
    """Store the fraction of state pairs that ``main_heuristic`` misorders.

    States are sorted by the heuristic's estimate, and the inversions of the
    resulting h* sequence are counted (mergeSortInversions). The count is
    divided by the number of state pairs ``n * (n - 1) / 2``. The result is
    stored in ``<base_dir>/inversions.json`` under
    ``[prefix][main_heuristic]``; the file is created if missing, otherwise
    read, updated and rewritten.

    With fewer than two states there are no pairs, so ``None`` (JSON null) is
    stored. Previously the division by zero produced NaN, which json.dump
    writes as invalid JSON.

    :param conv_choices: ``(data, labels)`` pair; ``labels`` must contain
        ``main_heuristic`` and ``"h*"``.
    """
    data, labels = conv_choices
    predicted = data[:, labels.index(main_heuristic)]
    original = data[:, labels.index("h*")]

    n = len(original)
    pairs = n * (n - 1) // 2
    if pairs:
        # A stable sort keeps states with equal estimates in input order, so
        # the count does not depend on numpy's unstable default sort.
        idx_sort = np.argsort(predicted, kind="mergesort")
        _, inversions = mergeSortInversions(list(original[idx_sort]))
        ratio = float(inversions) / pairs
    else:
        ratio = None

    path = os.path.join(base_dir, "inversions.json")
    d = {}
    if os.path.exists(path):
        with open(path, "r") as f:
            d = json.load(f)
    d.setdefault(prefix, {})[main_heuristic] = ratio
    with open(path, "w") as f:
        json.dump(d, f)

def analyse(content, main_heuristic, out_dir, prefix, problem="problem"):
    """Run every analysis for one log's text and write results to ``out_dir``.

    ``out_dir`` is created if it does not exist. The heuristic-progress plot is
    always produced. When the log contains choice lines, the deviation
    box-plot and the histogram are also produced, and the inversion ratio is
    recorded under the key ``problem``. Plot files are prefixed with
    ``os.path.join(out_dir, prefix)``.
    """
    if not os.path.exists(out_dir):
        os.makedirs(out_dir)
    base_path = os.path.join(out_dir, prefix)

    progress = convert_h_progress(analyse_h_values(content))
    choice_points = analyse_choices(content)

    plot_evolution_h_g(progress, main_heuristic, base_path)
    if not choice_points:
        return

    flat_choices = convert_choices_plain(choice_points)
    plot_barplot_accurracy(flat_choices, main_heuristic, base_path,
                           diff_mean_to_prediction=False)
    plot_histogram(flat_choices, main_heuristic, base_path)
    analyse_inversion(flat_choices, main_heuristic, out_dir, problem)




if __name__ == "__main__":
    # Logs in the current directory named "<problem>_<heuristic>.log" are
    # analysed, where <problem> matches "p<digits>_<digits>". A raw string keeps
    # "\d" and "\." as regex escapes instead of invalid string escapes.
    pattern_prefix = re.compile(r"(p\d+_\d+)_([^\.]+)\.log")
    out_dir = "results"
    # sorted() makes the processing order and console output deterministic.
    for filename in sorted(os.listdir(".")):
        if not filename.endswith(".log"):
            continue
        m = pattern_prefix.match(filename)
        if m is None:
            continue
        problem, main_heuristic = m.groups()
        with open(filename, "r") as f:
            print("Start: ", filename)
            content = f.read()
        # The file is already closed here; it no longer stays open during
        # the whole analysis.
        analyse(content, main_heuristic, out_dir,
                "%s_%s" % (problem, main_heuristic), problem)
        print("Done.")
