from collections import OrderedDict
import sys, os
sys.path.append(os.getcwd())
from utils import yaml_load, bcolors,yaml_overwrite
import argparse
import numpy as np
import matplotlib.pyplot as plt
import matplotlib
from matplotlib.pyplot import cm
import os
import re
from scipy.stats import gmean
import math
import time
from plot_util import tlut_unique_component, hatch_map, ls_arr, marker_arr, \
    sys_breakdown_map, carat_breakdown_map, usys_breakdown_map, \
    hatch_map_cg, size_map, \
    color_map, opac_map, \
    yml_metric_map, reduce_map, single_multiple_map, \
    partition_layerwise_reduction_map, axis_map, \
    mlperf_name_map, legend_map, clip_map, default_map, \
    ignore_vec_map, default_ignore_vec, default_ignore_vec_area,default_for_opac, \
    my_dpi, fig_h, fig_h_short, fig_w, size_tuple, \
    y_ticks_map

########### motivation: data reuse vs value reuse #############

# Shared matplotlib styling for the motivation figures: 5-pt text, thin
# 0.3-pt strokes (hatches, lines, axes spines, major ticks) and short 2-pt
# major ticks.
# NOTE(review): ``font.serif`` is only consulted when ``font.family`` is a
# serif family; with matplotlib's default family (sans-serif) the
# 'Helvetica Neue' entry is most likely ignored -- confirm intent before
# changing, as fixing it would alter the rendered figures.
font = dict(serif='Helvetica Neue', size=5)
matplotlib.rcParams.update(
    {f'font.{name}': value for name, value in font.items()})
matplotlib.rcParams.update({
    # Thin strokes.
    'hatch.linewidth': 0.3,
    'lines.linewidth': 0.3,
    'axes.linewidth': 0.3,
    'xtick.major.width': 0.3,
    'ytick.major.width': 0.3,
    # Small text for titles, tick labels and legend entries.
    'axes.titlesize': 5,
    'xtick.labelsize': 5,
    'ytick.labelsize': 5,
    'legend.fontsize': 5,
    # Short major tick marks.
    'xtick.major.size': 2,
    'ytick.major.size': 2,
    # Tighter vertical gap between legend entries.
    'legend.labelspacing': 0.25,
})

def _plot_efficiency_figure(x, curves, ylab, fig_name, *, figsize,
                            show_legend):
    """Draw one qualitative 'efficiency vs. data precision' figure and save it.

    The first curve is labelled 'Data reuse' (solid gray) and the second
    'Value reuse' (dashed orange). Curves are drawn in order, so the second
    one is rendered on top of the first.

    Args:
        x: x-axis samples shared by every curve; its first and last values
            get the 'High' / 'Low' precision tick labels.
        curves: sequence of y-value sequences, one per curve (at most two are
            drawn; extras are ignored by ``zip``).
        ylab: y-axis label text.
        fig_name: output path of the saved figure.
        figsize: figure size in inches, passed to ``plt.subplots``.
        show_legend: whether to draw the legend in the upper-left corner.
    """
    labels = ['Data reuse', 'Value reuse']
    colors = [bcolors.gray, bcolors.orange]
    # Local name on purpose: do not clobber ``ls_arr`` imported from plot_util.
    line_styles = ['-', '--']

    fig, ax = plt.subplots(figsize=figsize, dpi=my_dpi)
    for y, color, line_style, label in zip(curves, colors, line_styles, labels):
        ax.plot(
            x, y,
            color=color,
            alpha=1,
            linewidth=1.5,
            linestyle=line_style,
            label=label,
        )

    # Only the two ends of the precision axis are labelled, without tick marks.
    ax.set_xticks([x[0], x[-1]])
    ax.set_xticklabels(['High', 'Low'])
    ax.tick_params(axis='x', length=0)
    # The y axis is qualitative, so hide its ticks and tick labels entirely.
    ax.set_yticks([])
    ax.set_yticklabels([])
    ax.set_ylabel(ylab)
    ax.set_xlabel('Data precision')

    if show_legend:
        ax.legend(
            bbox_to_anchor=(0, 0.9, 0.5, 0.1), loc="upper left",
            mode="expand", borderaxespad=0, ncol=1)

    fig.tight_layout()
    fig.savefig(fig_name, bbox_inches='tight', dpi=my_dpi, pad_inches=0.02)
    # Release the figure; otherwise every rendered figure stays open.
    plt.close(fig)
    print(bcolors.OKGREEN + f'Saved fig as {fig_name}' + bcolors.ENDC)


def main():
    """Render the compute- and memory-efficiency motivation figures.

    Both PDFs are written under ``plot/motiv/`` (created if missing): the
    compute figure contrasts a quadratic ('Data reuse') with an exponential
    ('Value reuse') curve; the memory figure draws both curves as ``y = x``.
    """
    size_tuple_local = (1.65575, 0.8)
    os.makedirs('plot/motiv', exist_ok=True)

    x = np.linspace(4, 8, 20)
    square = [(x_ * x_) for x_ in x]
    exp = [(2 ** x_) for x_ in x]
    print(square)
    print(exp)

    _plot_efficiency_figure(
        x, [square, exp], 'Compute\nefficiency',
        'plot/motiv/compute_cost.pdf',
        figsize=size_tuple_local, show_legend=True)

    # Both curves are identical here (y = x), so the dashed 'Value reuse'
    # line overlays the solid one; the legend is intentionally omitted.
    _plot_efficiency_figure(
        x, [x, x], 'Memory\nefficiency',
        'plot/motiv/memory_cost.pdf',
        figsize=size_tuple_local, show_legend=False)


if __name__ == "__main__":
    main()