#! /usr/bin/env python3

from parse_output import parse_fairfind_output, parse_anant_output, \
    parse_aprove_output, parse_nuxmv_output

import argparse
import matplotlib
import matplotlib.pyplot as plt

# Settings for Matplotlib's PGF backend: compile with pdflatex and load the
# input/font encoding packages in the LaTeX preamble.
# "pgf.preamble" is given as ONE newline-joined string: list values for this
# rcParam were deprecated and recent Matplotlib releases expect a string.
pgf_with_pdflatex = {
    "pgf.texsystem": "pdflatex",
    "pgf.preamble": "\n".join([
        r"\usepackage[utf8x]{inputenc}",
        r"\usepackage[T1]{fontenc}",
    ]),
}
matplotlib.rcParams.update(pgf_with_pdflatex)


def _assert_valid_entry(entry):
    """Assert that *entry* is a usable ``(name, status, runtime)`` triple.

    The name must be a string, the status must not be ``False`` and the
    runtime must be a float.
    """
    assert entry[0] is not None
    assert entry[1] is not False
    assert entry[2] is not None
    assert isinstance(entry[0], str)
    assert isinstance(entry[2], float), "{}".format(entry)


def _is_undef(status):
    """Return True when *status* is ``None`` (unknown) or ``"TO"`` (timeout)."""
    return status is None or status == "TO"


def plot_scatter(x, y, title=None, legend=None, title_size=12, labels_size=20,
                 marks_size=10, ticks_size=20, legend_size=20, line_width=3,
                 x_label="", y_label="",
                 marker_colors=('g', 'darkorange', 'm', 'r'),
                 marker_types=('o', 'v', '<', 'P'),
                 timeout=600):
    """Show a log-log scatter plot comparing the runtimes of two tools.

    ``x`` and ``y`` are paired positionally.  Every item is either ``None``
    or a ``(name, status, runtime)`` triple where ``status`` is ``True``
    (answered), ``None`` (unknown) or ``"TO"`` (timeout) and ``runtime`` is a
    float.  Pairs in which either item is ``None`` are skipped.

    The remaining pairs are split into four groups, drawn in this order with
    ``marker_colors[i]`` / ``marker_types[i]`` / ``legend[i]``:

    0. both statuses are ``True``;
    1. only the x status is undefined (``None`` or ``"TO"``);
    2. only the y status is undefined;
    3. both statuses are undefined.

    Args:
        x, y: iterables of entries as described above.
        title: optional plot title.
        legend: optional sequence of group labels; groups without a
            corresponding label are drawn unlabelled.
        title_size, labels_size, ticks_size, legend_size: font sizes.
        marks_size: marker size handed to ``Axes.scatter``.
        line_width: width of the diagonal and timeout guide lines.
        x_label, y_label: axis labels (omitted when empty).
        marker_colors, marker_types: per-group colours and marker styles.
        timeout: runtime threshold; when some x (resp. y) runtime reaches it
            a dashed red vertical (resp. horizontal) line is drawn at that
            value.

    Raises:
        AssertionError: if an entry is malformed, the names of a pair differ,
            or a pair has a status combination that fits no group.
    """
    # Every group is a pair of parallel lists: (x runtimes, y runtimes).
    # TODO: separate status None (unknown) from "TO" (timeout).
    both_ok = ([], [])
    x_undef_pts = ([], [])
    y_undef_pts = ([], [])
    both_undef_pts = ([], [])
    show_x_timeout = False
    show_y_timeout = False

    for it_x, it_y in zip(x, y):
        if it_x is None or it_y is None:
            continue
        _assert_valid_entry(it_x)
        _assert_valid_entry(it_y)
        assert it_x[0] == it_y[0]

        if it_x[2] >= timeout:
            show_x_timeout = True
        if it_y[2] >= timeout:
            show_y_timeout = True

        x_undef = _is_undef(it_x[1])
        y_undef = _is_undef(it_y[1])
        if it_x[1] is True and it_y[1] is True:
            group = both_ok
        elif x_undef and y_undef:
            if it_x[1] is None and it_y[1] is None:
                print("both unknown: {}".format(it_x[0]))
            elif it_x[1] is None:
                print("x unknown, y TO: {}".format(it_x[0]))
            elif it_y[1] is None:
                print("x TO, y unknown: {}".format(it_x[0]))
            else:
                print("both timeout: {}".format(it_x[0]))
            group = both_undef_pts
        elif x_undef:
            group = x_undef_pts
        elif y_undef:
            group = y_undef_pts
        else:
            print(it_x)
            print(it_y)
            print()
            # Explicit raise (not `assert False`) so it survives `python -O`.
            raise AssertionError("{} - {}".format(it_x, it_y))
        group[0].append(it_x[2])
        group[1].append(it_y[2])

    groups = (both_ok, x_undef_pts, y_undef_pts, both_undef_pts)
    if not any(x_data for x_data, _ in groups):
        # Nothing to draw: avoid min()/max() on empty data and an empty window.
        print("plot_scatter: no data points to plot")
        return

    plt.figure(0)
    ax = plt.gca()

    for idx, (x_data, y_data) in enumerate(groups):
        if not x_data:
            continue
        assert len(x_data) == len(y_data)
        # The legend is optional and may be shorter than the group list.
        label = legend[idx] if legend and idx < len(legend) else None
        ax.scatter(x_data, y_data,
                   c=marker_colors[idx], marker=marker_types[idx],
                   s=marks_size,
                   label=label)

    if title:
        plt.title(title, fontsize=title_size)
    if legend:
        plt.legend(loc="best", prop={'size': legend_size}).set_draggable(True)
    if x_label:
        plt.xlabel(x_label, fontsize=labels_size)
    if y_label:
        plt.ylabel(y_label, fontsize=labels_size)
    plt.xticks(fontsize=ticks_size)
    plt.yticks(fontsize=ticks_size)

    # Dashed diagonal (x == y) spanning the runtimes of every pair in which
    # at least one tool answered; pairs undefined on both sides are ignored.
    diag_values = [value
                   for x_data, y_data in (both_ok, y_undef_pts, x_undef_pts)
                   for value in x_data + y_data]
    if diag_values:
        min_v = min(diag_values)
        max_v = max(diag_values)
        ax.plot((min_v, max_v), (min_v, max_v), color="gray",
                linewidth=line_width, linestyle='--', alpha=0.5)

    # Timeout guide lines, drawn at the configured threshold.
    if show_y_timeout:
        x_bounds = plt.xlim()
        ax.plot((x_bounds[0], x_bounds[1]), (timeout, timeout), color="red",
                linewidth=line_width, linestyle='--', alpha=0.5)
    if show_x_timeout:
        y_bounds = plt.ylim()
        ax.plot((timeout, timeout), (y_bounds[0], y_bounds[1]), color="red",
                linewidth=line_width, linestyle='--', alpha=0.5)

    ax.set_yscale('log')
    ax.set_xscale('log')
    ax.set_aspect('equal', adjustable='box')

    plt.subplots_adjust(top=0.98, bottom=0.16, right=1, left=0,
                        hspace=0, wspace=0)
    plt.show()


def _fairfind_runtime(ff_vals):
    """Return the FairFind runtime: the sum of the first four parsed values."""
    return ff_vals[0] + ff_vals[1] + ff_vals[2] + ff_vals[3]


def _scatter_scaling(ax, out_file, color, marker, label, marks_size):
    """Scatter the FairFind runtimes parsed from *out_file* onto *ax*.

    The x coordinate of a point is the integer formed by the first two
    characters of the benchmark name (presumably the number of AG-skeletons,
    as the axis label suggests -- confirm against the benchmark naming).
    Every benchmark is asserted to have status ``True``.
    """
    out_scaling = parse_fairfind_output(out_file)
    x_data = []
    y_data = []
    for name, ff_vals in sorted(out_scaling.items(),
                                key=lambda item: item[0]):
        ff_runtime = _fairfind_runtime(ff_vals)
        assert ff_vals[6] is True
        x_data.append(int(name[0:2]))
        y_data.append(ff_runtime)
    ax.scatter(x_data, y_data, c=color, marker=marker,
               s=marks_size, label=label)


def main(opts):
    """Parse the tools' outputs and show the runtime comparison plots.

    Reads the FairFind, nuXmv, Anant and AProVe results from the locations
    in *opts*, checks that the benchmark sets are consistent, orders the
    benchmarks by FairFind runtime and shows four pairwise scatter plots.
    When at least one scaling output file is given, an extra plot of the
    FairFind runtimes of those runs is shown.
    """
    # parameters for plots
    plot_sizes = {
        "title_size": 12,
        "labels_size": 60,
        "marks_size": 400,
        "ticks_size": 60,
        "legend_size": 40,
    }

    out_fairfind = parse_fairfind_output(opts.fairfind_out_file)
    out_nuxmv = parse_nuxmv_output(opts.nuxmv_out_dir)
    out_anant = parse_anant_output(opts.anant_out_dir)
    out_aprove = parse_aprove_output(opts.aprove_out_dir)

    # Anant/AProVe benchmarks must be known to both FairFind and nuXmv.
    for name in out_anant:
        assert name in out_fairfind, "{}".format(name)
        assert name in out_nuxmv, "{}".format(name)

    for name in out_aprove:
        assert name in out_fairfind, "{}".format(name)
        assert name in out_nuxmv, "{}".format(name)

    # FairFind and nuXmv must cover exactly the same benchmarks.
    for name in out_fairfind:
        assert name in out_nuxmv, "{}".format(name)

    for name in out_nuxmv:
        assert name in out_fairfind, "{}".format(name)

    # sort according to FairFind
    fairfind_data = []
    aprove_data = []
    anant_data = []
    nuxmv_data = []
    for name, ff_vals in sorted(
            out_fairfind.items(),
            key=lambda item: _fairfind_runtime(item[1])):
        fairfind_data.append((name, ff_vals[6], _fairfind_runtime(ff_vals)))
        # A tool without a result for this benchmark gets a None placeholder.
        aprove_data.append(
            (name, *out_aprove[name]) if name in out_aprove else None)
        anant_data.append(
            (name, *out_anant[name]) if name in out_anant else None)
        nuxmv_data.append(
            (name, *out_nuxmv[name]) if name in out_nuxmv else None)

    print("Order: ")
    # Unpacking order matches the zip order (AProVe before Anant).
    for idx, (ff, aprove, anant, nuxmv) in \
            enumerate(zip(fairfind_data, aprove_data, anant_data,
                          nuxmv_data)):
        assert anant is None or ff[0] == anant[0]
        assert anant is None or len(ff) == len(anant)
        assert aprove is None or ff[0] == aprove[0]
        assert aprove is None or len(ff) == len(aprove)
        assert nuxmv is None or ff[0] == nuxmv[0]
        assert nuxmv is None or len(ff) == len(nuxmv)
        print("{} - {}".format(idx, ff[0]))

    assert len([a for a in aprove_data if a is not None]) == len(out_aprove)
    assert len([a for a in anant_data if a is not None]) == len(out_anant)
    assert len([a for a in nuxmv_data if a is not None]) == len(out_nuxmv)

    # Every legend has four entries, one per point group of plot_scatter, so
    # that a benchmark undefined for both tools cannot index past its end.
    print("Anant vs AProVe")
    plot_scatter(x=anant_data, y=aprove_data,
                 x_label="Anant (s)", y_label="AProVe (s)",
                 legend=["both answer", "Anant undef", "AProVe undef",
                         "both undef"],
                 **plot_sizes)

    print("FairFind vs Anant")
    plot_scatter(x=fairfind_data, y=anant_data,
                 x_label="FairFind (s)", y_label="Anant (s)",
                 legend=["both answer", "FairFind undef", "Anant undef",
                         "both undef"],
                 **plot_sizes)

    print("FairFind vs AProVe")
    plot_scatter(x=fairfind_data, y=aprove_data,
                 x_label="FairFind (s)", y_label="AProVe (s)",
                 legend=["both answer", "FairFind undef", "AProVe undef",
                         "both undef"],
                 **plot_sizes)

    print("FairFind vs nuXmv")
    plot_scatter(x=fairfind_data, y=nuxmv_data,
                 x_label="FairFind (s)", y_label="nuXmv (s)",
                 legend=["both answer", "FairFind undef", "nuXmv undef",
                         "both undef"],
                 **plot_sizes)

    # (output file, colour, marker, label) of each optional scaling series.
    scaling_series = (
        (opts.scaling_bouncing_ball_out_file, 'g', 'o', "bouncing-ball"),
        (opts.scaling_bench19_out_file, 'r', '^', "bench-19"),
        (opts.scaling_example2_out_file, 'b', 's', "example2"),
    )
    if any(out_file for out_file, _, _, _ in scaling_series):
        print("FairFind scaling AG-skeletons")
        plt.figure(0)
        ax = plt.gca()
        for out_file, color, marker, label in scaling_series:
            if out_file:
                _scatter_scaling(ax, out_file, color, marker, label,
                                 plot_sizes["marks_size"])

        plt.xlabel("# AG-skeletons", fontsize=plot_sizes["labels_size"])
        plt.ylabel("FairFind (s)", fontsize=plot_sizes["labels_size"])
        plt.xticks(fontsize=plot_sizes["ticks_size"])
        plt.yticks(fontsize=plot_sizes["ticks_size"])
        plt.subplots_adjust(top=0.98, bottom=0.15, right=0.99, left=0.14,
                            hspace=0, wspace=0)
        plt.legend(loc="best",
                   prop={'size': plot_sizes["legend_size"]}
                   ).set_draggable(True)
        plt.show()


def getopts():
    """Parse the command-line options of the script.

    The four tool-output locations are mandatory string options; the three
    scaling output files are optional string options that default to None.
    Returns the namespace produced by ``argparse``.
    """
    mandatory = (
        "--aprove_out_dir",
        "--anant_out_dir",
        "--nuxmv_out_dir",
        "--fairfind_out_file",
    )
    optional = (
        "--scaling_bouncing_ball_out_file",
        "--scaling_bench19_out_file",
        "--scaling_example2_out_file",
    )

    parser = argparse.ArgumentParser()
    for flag in mandatory:
        parser.add_argument(flag, type=str, required=True)
    for flag in optional:
        parser.add_argument(flag, type=str, required=False)
    return parser.parse_args()


# Script entry point: parse the command line and show the plots only when the
# file is executed directly, not when it is imported.
if __name__ == "__main__":
    main(getopts())
