#!/usr/bin/env python3

"""
Plot results from perf_analysis
"""

from utils import expect, check_minimum_python_version
check_minimum_python_version(3, 5)

import matplotlib.pyplot as plt

from collections import OrderedDict
import argparse, sys, os, math

###############################################################################
def parse_command_line(args, description):
###############################################################################
    """
    Parse the command line for the plotting script.

    args is the full argument vector (program name first, as in sys.argv);
    the basename of args[0] is substituted into the usage text and the
    remaining entries are parsed. description is shown in the --help output.

    Returns the argparse.Namespace holding the parsed options. argparse exits
    the process itself on --help or on invalid arguments.
    """
    parser = argparse.ArgumentParser(
        usage="""\n{0} <DATAFILE> [<DATAFILE>]
OR
{0} --help

\033[1mEXAMPLES:\033[0m
    \033[1;32m# Plot data file \033[0m
    > {0} data

    \033[1;32m# Plot data for multiple architectures \033[0m
    > {0} blake bowman weaver  -d Fortran -D Initial

    \033[1;32m# Very stripped-down plot for all key archs \033[0m
    > {0} blake bowman weaver  -d Fortran -t ref -t final -v blake:CPU -v weaver:GPU -v threads: -v final:C++/Kokkos -f 28

    \033[1;32m# Plot two different code versions for a single arch  \033[0m
    > {0} <sha1>/weaver <sha2>/weaver -v commit:version
""".format(os.path.basename(args[0])),
        description=description,
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )

    parser.add_argument("datafiles", nargs="+", help="Data file to plot")

    parser.add_argument("-d", "--dashed", action="append", default=[], help="Specify items associated with this value should be dashed")

    parser.add_argument("-D", "--dotted", action="append", default=[], help="Specify items associated with this value should be dotted")

    parser.add_argument("-n", "--no-legend", action="store_true", help="Do not produce legend")

    parser.add_argument("-f", "--font", type=int, default=14, help="Set font size")

    # The help text here used to be a copy of the --font help.
    parser.add_argument("-w", "--line-width", type=int, default=3, help="Set line width")

    parser.add_argument("-l", "--log-x", action="store_true", help="Label the X axis with log values")

    parser.add_argument("-L", "--log-y", action="store_true", help="Label the Y axis with log values")

    parser.add_argument("-t", "--tests", action="append", help="Only plot these tests")

    parser.add_argument("-s", "--save", help="Instead of showing the plot, save to this file")

    parser.add_argument("-v", "--varname-map", action="append", default=[], help="Override how script displays this varname, format=basename:newname , blank implies don't show it at all")

    return parser.parse_args(args[1:])

###############################################################################
def pad_lim(lim, pad=0.05, mult=False):
###############################################################################
    """
    Return a (low, high) tuple that widens the limits in lim by pad.

    With mult false, each end is moved outwards by pad times the span
    lim[1] - lim[0]. With mult true, the low end is scaled by (1 - pad) and
    the high end by (1 + pad) instead.
    """
    low, high = lim[0], lim[1]

    if not mult:
        margin = pad * (high - low)
        return low - margin, high + margin

    return low * (1 - pad), high * (1 + pad)

###############################################################################
def axis_tight_pad(pad=0.05, mult=False):
###############################################################################
    """
    Switch the current axes to 'tight' limits, then widen both axes.

    The X and Y limits are read back after tightening and each is replaced
    by pad_lim(limits, pad, mult). Returns whatever plt.ylim() returns for
    the padded Y limits.
    """
    plt.axis('tight')
    x_bounds, y_bounds = plt.xlim(), plt.ylim()

    padded_x = pad_lim(x_bounds, pad, mult)
    plt.xlim(padded_x)

    padded_y = pad_lim(y_bounds, pad, mult)
    return plt.ylim(padded_y)

###############################################################################
def subtle_grid():
###############################################################################
    """
    Draw a thin, light gray grid on both major and minor ticks and ask the
    current axes to draw it below the plotted data.

    Returns the result of set_axisbelow(True) on the current axes.
    """
    light_gray = (0.8, 0.8, 0.8)
    plt.grid(True, which='both', color=light_gray, ls='-', lw=0.5, zorder=-1)

    current_axes = plt.gca()
    return current_axes.set_axisbelow(True)

###############################################################################
def applies_to_line(match_list, line_label):
###############################################################################
    """
    Return True if any entry of match_list occurs inside line_label
    (using the "in" operator), False otherwise or when match_list is empty.
    """
    return any(candidate in line_label for candidate in match_list)

###############################################################################
class Plotter(object):
###############################################################################
    """
    Plot the results stored in one or more perf_analysis data files.

    Every data file is read twice: once at construction time, to check that
    the files are compatible and to work out which provenance settings are
    shared by all of them (those go in the plot title), and once more in
    plot_all(), to draw the lines themselves.
    """

    def __init__(self, datafiles, dashed, dotted, no_legend, font, line_width, log_x, log_y, tests, save, varname_map):
        """
        Validate the data files and prepare the plot title.

        datafiles   - paths of the data files to plot
        dashed      - substrings; a line whose label contains one is dashed
        dotted      - substrings; a line whose label contains one is dotted
        no_legend   - if true, no legend is drawn
        font        - font size for title and axis labels (tick labels and
                      the legend use font - 4)
        line_width  - width of the plotted lines
        log_x       - label the X axis with log values instead of raw values
        log_y       - plot log2 of the y values instead of y / 1000
        tests       - if non-empty, only tests with these names are plotted
        save        - file to save the plot to; None shows it interactively
        varname_map - "name:display" overrides for the display-name table,
                      an empty display name hides that item

        Problems with the arguments or the files are reported through
        expect().
        """
        self._datafiles         = datafiles
        self._dashed            = dashed
        self._dotted            = dotted
        self._scaling_var       = None
        self._scaling_factor    = None
        self._consistent        = OrderedDict()
        self._inconsistent      = []
        self._title             = None
        # Solid, dashed and dotted lines each consume colors from their own
        # copy of the palette.
        self._avail_colors      = ["b", "g", "r", "c", "m", "y", "k"]
        self._dashed_colors     = list(self._avail_colors)
        self._dotted_colors     = list(self._avail_colors)
        self._mach_color_map    = {}
        self._log_x             = log_x
        self._log_y             = log_y
        # Smallest / largest raw x value seen while plotting, used for ticks
        self._min_x             = sys.maxsize
        self._max_x             = 0
        self._tests             = tests
        self._font              = font
        self._line_width        = line_width
        self._save              = save
        self._no_legend         = no_legend

        # How to refer to things in the plot
        self._display_name_map  = {
            "ni" : "Number of columns per node",
            "nk" : "Number of vertical levels",
            "dt" : "Time step",
            "ts" : "Number of time steps",
            "threads" : "Threads",
            "ref" : "Fortran",
            "vanilla" : "Initial Kokkos",
            "final" : "Final Kokkos",
            "bowman" : "KNL",
            "blake" : "Skylake",
            "weaver" : "Nvidia_V100",
            "commit" : ""}

        for varname_display_chg in varname_map:
            try:
                orig_name, new_name = varname_display_chg.split(":")
            except ValueError:
                # Unpacking fails when there is not exactly one ":"
                expect(False, "Incorrect varname_map format, expect X:Y, got {}".format(varname_display_chg))

            expect(orig_name in self._display_name_map,
                   "{} is not something you can override, expected one of {}".format(orig_name, ", ".join(self._display_name_map.keys())))

            self._display_name_map[orig_name] = new_name

        for datafile in datafiles:
            with open(datafile, "r", encoding="utf-8") as fd:
                lines = [line.strip() for line in fd]

            # The first four lines are needed below; fail with a message
            # rather than an IndexError when the file is too short.
            expect(len(lines) >= 4,
                   "Bad file {}, expected a provenance line, a test line and at least two result lines".format(datafile))

            prov_line    = lines[0]
            test_line    = lines[1]
            result_line1 = lines[2]
            result_line2 = lines[3]

            # The scaling factor is the ratio between the x values of the
            # first two results.
            scaling_var = test_line.split()[1]
            scaling_factor = float(result_line2.split(", ")[0]) / float(result_line1.split(", ")[0])
            first_file = self._scaling_var is None
            if first_file:
                self._scaling_var    = scaling_var
                self._scaling_factor = scaling_factor
            else:
                expect(self._scaling_var == scaling_var, "Incompatible files, scaling vars don't match")
                expect(self._scaling_factor == scaling_factor, "Incompatible files, scaling factors don't match")

            for item in prov_line.split()[1:]:
                key, value = item.split("=")
                if key == "dt":
                    value += "sec"
                if first_file:
                    self._consistent[key] = value
                elif key not in self._inconsistent and self._consistent.get(key) != value:
                    # The value differs from the first file, or the first
                    # file did not have this key at all: either way it is
                    # not shared by every file.
                    self._consistent.pop(key, None)
                    self._inconsistent.append(key)

        # Settings shared by every file make up the title; items whose
        # display name is empty are hidden.
        title = []
        for key, value in self._consistent.items():
            display_name = self.get_display_name(key)
            if display_name:
                title.append("{}={}".format(display_name, value))

        self._title = "\n".join(title)

    def get_display_name(self, var):
        """Return the display name registered for var, or var itself if none."""
        return self._display_name_map.get(var, var)

    def get_line_format(self, line_label):
        """
        Return the matplotlib format string (color plus line style) for the
        line labelled line_label, consuming the next color of the matching
        palette. Reports through expect() when the label matches both the
        dashed and the dotted lists or when the palette is exhausted.
        """
        dash_line = applies_to_line(self._dashed, line_label)
        dot_line  = applies_to_line(self._dotted, line_label)
        try:
            if dash_line:
                color = self._dashed_colors.pop(0)
                expect(not dot_line, "Ambiguous line format for '{}', matched both dashed and dotted".format(line_label))
                return "{}--".format(color)
            elif dot_line:
                color = self._dotted_colors.pop(0)
                return "{}:".format(color)
            else:
                color = self._avail_colors.pop(0)
                return "{}-".format(color)
        except IndexError:
            expect(False, "Ran out of usable colors. Please use dashes or dots to differentiate categories of data")

    def plot_line(self, xs, ys, line_label, testname):
        """
        Plot one line unless a test filter is set and testname is not in it.
        The legend label is the display name of testname followed by
        line_label.
        """
        if not self._tests or testname in self._tests:
            line_label = "{} {}".format(self.get_display_name(testname), line_label)
            plt.plot(xs, ys, self.get_line_format(line_label), label=line_label, linewidth=self._line_width)

    def plot_all(self):
        """Plot every data file, then decorate and show or save the figure."""
        for datafile in self._datafiles:
            self.plotdata(datafile)

        self.finalize()

    def plotdata(self, datafile):
        """
        Plot one line for each test found in datafile.

        Lines containing "Provenance" set the label suffix shared by the
        tests that follow, lines containing "," are "x, y" results, and any
        other non-blank line is a "<testname> <scaling var>" header that
        starts a new test.
        """
        xs, ys = [], []
        testname = None
        line_label = ""

        with open(datafile, "r", encoding="utf-8") as fd:
            for line in fd:
                line = line.strip()
                if "Provenance" in line:
                    prov_dict = dict(item.split("=") for item in line.split()[1:] if not item.endswith("=None"))
                    if "machine" in self._inconsistent:
                        # The machine may be missing here (absent or None in
                        # this file) even though it differs across files.
                        machine = prov_dict.pop("machine", None)
                        if machine is not None:
                            line_label = "on {}".format(self.get_display_name(machine))

                    # Only settings that differ between files are worth
                    # putting in the label; hidden items are skipped.
                    first = True
                    for key, value in prov_dict.items():
                        if key not in self._inconsistent:
                            continue

                        display_name = self.get_display_name(key).lower()
                        if not display_name:
                            continue

                        if first:
                            line_label += ","
                            first = False

                        line_label += " {} {}".format(value, display_name)

                elif "," in line:
                    x, y = [float(item) for item in line.split(", ")]
                    self._min_x = min(x, self._min_x)
                    self._max_x = max(x, self._max_x)
                    # x is always plotted as log base scaling_factor; the
                    # log_x option only changes how the ticks are labelled.
                    xs.append(math.log(x, self._scaling_factor))
                    if self._log_y:
                        ys.append(math.log(y, 2))
                    else:
                        ys.append(y / 1000.0)

                elif line != "":
                    # A new test header: flush the previous test first
                    if xs:
                        self.plot_line(xs, ys, line_label, testname)
                        xs, ys = [], []

                    testname, xlabel = line.split()
                    expect(xlabel == self._scaling_var, "Bad file {}".format(datafile))

        # Flush the last test of the file
        self.plot_line(xs, ys, line_label, testname)

    def _log_x_label(self):
        """
        Return the X axis label used when the axis is labelled with log
        values. The base is wrapped in braces so that bases of more than one
        character (e.g. 10 or 2.50) are subscripted as a whole.
        """
        factor = self._scaling_factor
        log_base_str = str(int(factor)) if factor.is_integer() else "{:.2f}".format(factor)
        return r"$\log_{{{}}}$ {}".format(log_base_str, self.get_display_name(self._scaling_var))

    def finalize(self):
        """
        Add grid, title, axis labels, ticks and legend to the figure, then
        show it, or save it when a save file was given.
        """
        subtle_grid()
        if self._title:
            plt.title(self._title, fontsize=self._font)

        if self._log_x:
            plt.xlabel(self._log_x_label(), fontsize=self._font)
        else:
            # The data sits at log positions; put a tick on every integer
            # position between the smallest and largest x and label it with
            # the corresponding raw value.
            distance = int(round(math.log(self._max_x / self._min_x, self._scaling_factor)))
            start = int(round(math.log(self._min_x, self._scaling_factor)))
            tick_positions = list(range(start, start + distance + 1))
            plt.xticks(ticks=tick_positions,
                       labels=[str(int(round(pow(self._scaling_factor, item)))) for item in tick_positions],
                       rotation=45, fontsize=self._font - 4)
            plt.xlabel(self.get_display_name(self._scaling_var), fontsize=self._font)

        if self._log_y:
            plt.ylabel("Log2 of columns per second", fontsize=self._font)
        else:
            plt.ylabel("Thousands of columns per second", fontsize=self._font)
        plt.yticks(fontsize=self._font - 4)
        if not self._no_legend:
            plt.legend(loc=2, framealpha=0, fontsize=self._font - 4)

        if self._save is None:
            plt.show()
        else:
            plt.savefig(self._save)

###############################################################################
def _main_func(description):
###############################################################################
    """
    Script entry point: parse sys.argv, build a Plotter from the parsed
    options (passed as keyword arguments) and plot all the data files.
    """
    options = parse_command_line(sys.argv, description)
    Plotter(**vars(options)).plot_all()

###############################################################################

if __name__ == "__main__":
    # The module docstring doubles as the command-line description.
    _main_func(__doc__)
