from collections import OrderedDict
import sys, os
sys.path.append(os.getcwd())
from utils import yaml_load, bcolors,yaml_overwrite
import argparse
import numpy as np
import matplotlib.pyplot as plt
import matplotlib
from matplotlib.pyplot import cm
import os
import re
from scipy.stats import gmean
import math
import time
from loguru import logger
from plot_util import tlut_unique_component, hatch_map, ls_arr, marker_arr, \
    usys_breakdown_map, sys_breakdown_map, carat_breakdown_map, cris_breakdown_map, hatch_map_cg, \
    color_map, opac_map, size_map, \
    yml_metric_map, reduce_map, single_multiple_map, \
    partition_layerwise_reduction_map, axis_map, \
    mlperf_name_map, legend_map, clip_map, default_map, all_ignore_vec, \
    ignore_vec_map, default_ignore_vec, default_ignore_vec_area,default_for_opac, \
    my_dpi, fig_h, fig_h_short, fig_w, size_tuple, \
    y_ticks_map,legend_vec_map,trim_item
    
# Key under which a run's peak throughput is stored in each network's result
# dict (written in read_perf_cost, requested in gen_read_sets).
PEAK = 'peak flops per sec'
# Turn off loguru output for records originating from the "__main__" module,
# i.e. for this file when it is executed as a script.
logger.disable("__main__")
# matplotlib settings
# Passed to matplotlib.rc('font', ...), so this sets rcParams 'font.serif' and
# 'font.size'.
font = {'serif':'Helvetica Neue', 'size': 6}
matplotlib.rc('font', **font)
# Thin hatch/line/axis strokes and short tick marks for small figures.
matplotlib.rcParams['hatch.linewidth'] = 0.3
matplotlib.rcParams['lines.linewidth'] = 0.3
matplotlib.rcParams['axes.linewidth'] = 0.5
matplotlib.rcParams['xtick.major.width'] = 0.5
matplotlib.rcParams['ytick.major.width'] = 0.5
matplotlib.rcParams['xtick.major.size'] = 2
matplotlib.rcParams['ytick.major.size'] = 2

def trim(point, ignore, trim_only=False):
    """Run ``trim_item`` over every entry of ``point``.

    Args:
        point: iterable of items to trim.
        ignore: forwarded unchanged to ``trim_item`` as ``ignore``.
        trim_only: forwarded unchanged to ``trim_item`` as ``trim_only``.

    Returns:
        A new list holding the ``trim_item`` result for each entry, in order.
    """
    return [
        trim_item(item=entry, ignore=ignore, trim_only=trim_only)
        for entry in point
    ]

def gen_read_sets(data):
    """Translate requested plot metrics into the result keys that must be read.

    Args:
        data: iterable of metric names (e.g. ``'perf'``, ``'area_total'``,
            ``'throughput/power'``). Names that are not recognised are
            ignored.

    Returns:
        A 4-tuple of sets ``(perf_read, cost_read, other, cost_layer_read)``:

        * ``perf_read`` -- '/'-separated key paths to look up in the
          performance summary (plus ``PEAK`` for ``'perf_norm'``).
        * ``cost_read`` -- '/'-separated key paths to look up in the cost
          summary.
        * ``other`` -- requested metrics that are derived from the values
          above rather than read directly.
        * ``cost_layer_read`` -- per-component key paths needed for the
          area/energy partition and breakdown metrics.

    Raises:
        SystemExit: if none of the requested metrics produced any key, i.e.
            there is nothing to read.
    """
    # loguru formats messages with str.format(); without the '{}' placeholder
    # the extra positional argument would be silently dropped.
    logger.info('*** metric: {}', data)

    scopes = ('total', 'onchip', 'sram', 'dram', 'compute')
    all_area_keys = ('area/onchip', 'area/sram', 'area/dram', 'area/compute')
    components = (
        'compu', 'ififo', 'wfifo', 'ofifo', 'oaccu', 'isram',
        'wsram', 'osram', 'waccu', 'itemp', 'osmux', 'bpipe',
    )

    # metric -> keys required from the performance summary
    perf_keys = {
        'perf': ('flops per sec',),
        'perf_norm': ('flops per sec', PEAK),
        'util_impl': ('utilization/impl',),
        'rt_abs': ('runtime/impl',),
        'rt_norm': ('runtime/impl',),
        'flops/energy': ('flops',),
        'flops/area': ('flops',),
        'flops/power': ('flops',),
        'throughput/energy': ('flops per sec',),
        'throughput/area': ('flops per sec',),
        'throughput/power': ('flops per sec',),
    }

    # metric -> keys required from the cost summary
    cost_keys = {
        'area_total': ('area/onchip', 'area/dram'),
        'area_onchip': ('area/onchip',),
        'area_sram': ('area/sram',),
        'area_dram': ('area/dram',),
        'area_compute': ('area/compute',),
        'area_partition': ('area/onchip',),
        'area_breakdown': ('area/onchip',),
        'energy_partition': ('energy/onchip/total',),
        'energy_breakdown': ('energy/onchip/total',),
    }
    for scope in scopes:
        cost_keys[f'power_{scope}'] = (f'power/{scope}/total',)
        cost_keys[f'energy_{scope}'] = (f'energy/{scope}/total',)
    for numerator in ('flops', 'throughput'):
        cost_keys[f'{numerator}/energy'] = ('energy/onchip/total',)
        cost_keys[f'{numerator}/area'] = all_area_keys
        cost_keys[f'{numerator}/power'] = ('power/onchip/total',)

    # metric -> per-component keys required from the layer-wise cost file
    area_layers = tuple(f'{comp}/area' for comp in components)
    energy_layers = tuple(f'{comp}/energy/total' for comp in components)
    layer_keys = {
        'area_partition': area_layers,
        'area_breakdown': area_layers,
        'energy_partition': energy_layers,
        'energy_breakdown': energy_layers,
    }
    # TODO: add more

    # metrics that are computed from other values after loading
    derived = {
        'perf_norm', 'area_total',
        'flops/energy', 'flops/area', 'flops/power',
        'throughput/energy', 'throughput/area', 'throughput/power',
    }

    perf_read = set()
    cost_read = set()
    other = set()
    cost_layer_read = set()

    for metric in data:
        perf_read.update(perf_keys.get(metric, ()))
        cost_read.update(cost_keys.get(metric, ()))
        cost_layer_read.update(layer_keys.get(metric, ()))
        if metric in derived:
            other.add(metric)

    # Nothing to load: stop the program. sys.exit() is used because the bare
    # exit() helper is only injected by the site module.
    if not (perf_read or cost_read or other or cost_layer_read):
        sys.exit()
    return perf_read, cost_read, other, cost_layer_read

def read_perf_cost(run_set, perf_read, cost_read, other, cost_layer_read, ref_arch_name, use_universal_baseline):
    """Load performance/cost results from disk into ``run_set`` and normalise them.

    ``run_set`` is traversed as ``{n: {arch_name: {net_name: net_dict}}}`` and
    each ``net_dict`` must contain a ``'dir'`` entry naming the run directory
    under ``runs/``. Every value that is read or derived is stored back into
    ``net_dict`` under its key, and a ``key + '/norm'`` entry is added that
    holds the value divided by the same key of the reference architecture.

    Args:
        run_set: nested dict described above; mutated in place.
        perf_read: keys to read from ``workloadperf.summary.yaml`` (``PEAK``
            is computed from the input files instead).
        cost_read: keys to read from ``workloadcost.summary.yaml``.
        other: derived metrics computed from already loaded values.
        cost_layer_read: per-component keys read from ``workloadcost.yaml``.
        ref_arch_name: name of the architecture used as normalisation baseline.
        use_universal_baseline: when true, only the first occurrence of the
            reference architecture is used as baseline; otherwise the
            baseline is replaced each time the reference architecture is
            encountered.

    Returns:
        The same ``run_set`` object.

    NOTE(review): ``ref_arch`` is only bound once an architecture named
    ``ref_arch_name`` has been visited, so the reference architecture
    presumably has to come first in iteration order -- confirm with callers.
    """
    first_ref_arch_seen = False
    for n in run_set.values():
        for arch_tuple in n.items():
            arch = arch_tuple[1]
            # Track the baseline architecture used for all '/norm' values.
            if arch_tuple[0] == ref_arch_name:
                if use_universal_baseline == False or first_ref_arch_seen == False:
                    ref_arch = arch_tuple[1]
                if not first_ref_arch_seen:
                    first_ref_arch_seen = True
                
            for net_tuple in arch.items():
                net_name = net_tuple[0]
                net_dict = net_tuple[1]
                dir = net_dict['dir']
                if len(perf_read) != 0:
                    # process performance
                    perf_yml_file = f'runs/{dir}'+'/output/performance/workloadperf.summary.yaml'
                    perf_dict = yaml_load(perf_yml_file)
                    for perf in perf_read:
                        if perf == PEAK:
                            # Peak throughput is not in the summary; it is
                            # computed from the run's input files by helpers
                            # that are defined outside this section.
                            arch_yml_file = f'runs/{dir}'+'/input/architecture.yaml'
                            arch_dict = yaml_load(arch_yml_file)
                            wl_yml_file = f'runs/{dir}'+'/input/workload.yaml'
                            wl_dict = yaml_load(wl_yml_file)['workload']
                            # Only the 'cycle' of the first workload layer is used.
                            for layer in wl_dict.values():
                                cycle = layer['cycle']
                                break
                            # The design family is inferred from substrings of
                            # the run directory name.
                            if 'ystolic' in dir: peak = get_sys_peak_flops_per_sec(arch_dict, cycle)
                            elif 'carat' in dir and 'i1' in dir:
                                chunk_size = math.log2(arch_dict['architecture']['compu']['num_instances'][1]) + 1
                                if 'fp8' in dir:
                                    data_length = 4
                                else: assert False
                                peak = get_tlut_peak_flops_per_sec(arch_dict, chunk_size, data_length)
                            # NOTE(review): this branch performs the same
                            # computation as the 'carat' + 'i1' branch above.
                            elif 'carat' in dir: 
                                chunk_size = math.log2(arch_dict['architecture']['compu']['num_instances'][1])+1
                                if 'fp8' in dir:
                                    data_length = 4
                                else: assert False
                                peak = get_tlut_peak_flops_per_sec(arch_dict, chunk_size, data_length)
                            elif 'cris' in dir:
                                # This helper receives the whole workload file,
                                # not just its 'workload' section.
                                wl_dict = yaml_load(wl_yml_file)
                                peak = get_cris_peak_flops_per_sec(arch_dict, wl_dict, cycle)
                            else: assert False, f'{dir}'
                            net_dict[perf] = peak
                            continue

                        # Walk the '/'-separated key path below 'overall'.
                        level = perf.split('/')
                        thing = perf_dict['overall']
                        for l in level:
                            thing = thing[l]
                        net_dict[perf] = thing
                        net_dict[perf + '/norm'] = thing / ref_arch[net_name][perf]
                
                if len(cost_read) != 0:
                    # process cost
                    cost_yml_file = f'runs/{dir}'+'/output/cost/workloadcost.summary.yaml'
                    cost_dict = yaml_load(cost_yml_file)
                    for cost in cost_read:
                        # Walk the '/'-separated key path below 'overall'.
                        level = cost.split('/')
                        thing = cost_dict['overall']
                        for l in level:
                            thing = thing[l]
                        net_dict[cost] = thing
                        net_dict[cost + '/norm'] = thing / ref_arch[net_name][cost]
                
                if len(cost_layer_read) != 0:
                    # process cost partition
                    cost_layer_yml_file = f'runs/{dir}'+'/output/cost/workloadcost.yaml'
                    cost_layer_dict = yaml_load(cost_layer_yml_file)
                    # Leave dict_ bound to the first entry that is not one of
                    # the bookkeeping keys; the area branch below reads it.
                    for dict_ in cost_layer_dict.items():
                        if dict_[0] in ['technology', 'frequency', 'overall']:
                            continue
                        else: break

                    # init breakdown record
                    for cost in cost_layer_read:
                        area_or_energy = cost.split('/')[1]
                        for breakdown in hatch_map_cg.keys():
                            cost_bd = breakdown + f'/{area_or_energy}'
                            if area_or_energy == 'energy':
                                cost_bd += '/total'
                            net_dict[cost_bd] = 0

                    for cost in cost_layer_read:
                        level = cost.split('/')
                        partition = level[0]
                        area_or_energy = level[1]
                        
                        # Components listed in tlut_unique_component are
                        # skipped for runs whose directory contains 'ystolic'.
                        if partition in tlut_unique_component and 'ystolic' in dir:
                            continue

                        if area_or_energy == 'area':
                            # Area is taken from a single entry (dict_).
                            # NOTE(review): the energy branch rebinds dict_ to
                            # the last entry of the file, so which entry is
                            # used here depends on set iteration order when
                            # both kinds are requested -- confirm area is
                            # identical across entries.
                            thing = dict_[1]
                            total = cost_layer_dict['overall'][area_or_energy]['onchip']
                            for l in level:
                                thing = thing[l]
                            net_dict[cost] = thing
                            net_dict[cost + '/percent'] = thing / total
                            net_dict[cost + '/norm_percent'] = net_dict[cost + '/percent'] * net_dict[f'{area_or_energy}/onchip/norm']
                            # Accumulate the component into its coarse-grained
                            # group, chosen per design family from the
                            # directory name.
                            if 'uSys' in dir: cost_cg = usys_breakdown_map[partition] + f'/{area_or_energy}'
                            elif 'sys' in dir: cost_cg = sys_breakdown_map[partition] + f'/{area_or_energy}'
                            elif 'carat' in dir: cost_cg = carat_breakdown_map[partition] + f'/{area_or_energy}'
                            elif 'cris' in dir: cost_cg = cris_breakdown_map[partition] + f'/{area_or_energy}'
                            else: assert False
                            net_dict[cost_cg] += thing
                            net_dict[cost_cg + '/percent'] = net_dict[cost_cg] / total
                            net_dict[cost_cg + '/norm_percent'] = net_dict[cost_cg + '/percent'] * net_dict[f'{area_or_energy}/onchip/norm']
                        
                        elif area_or_energy == 'energy':
                            # Energy is summed over every entry of the file
                            # except the bookkeeping keys.
                            net_dict[cost] = 0
                            for dict_ in cost_layer_dict.items():
                                if dict_[0] in ['technology', 'frequency', 'overall']:
                                    continue
                                else: 
                                    thing = dict_[1]
                                    for l in level:
                                        thing = thing[l]
                                    net_dict[cost] += thing
                            net_dict[cost + '/percent'] = net_dict[cost] / cost_layer_dict['overall'][area_or_energy]['onchip']['total']
                            net_dict[cost + '/norm_percent'] = net_dict[cost + '/percent'] * net_dict[f'{area_or_energy}/onchip/total/norm']
                            
                            # Accumulate into the coarse-grained group, as for area.
                            if 'uSys' in dir: cost_cg = usys_breakdown_map[partition] + f'/{area_or_energy}/total'
                            elif 'sys' in dir: cost_cg = sys_breakdown_map[partition] + f'/{area_or_energy}/total'
                            elif 'carat' in dir: cost_cg = carat_breakdown_map[partition] + f'/{area_or_energy}/total'
                            elif 'cris' in dir: cost_cg = cris_breakdown_map[partition] + f'/{area_or_energy}/total'
                            else: assert False
                            net_dict[cost_cg] += net_dict[cost]
                            net_dict[cost_cg + '/percent'] = net_dict[cost_cg] / cost_layer_dict['overall'][area_or_energy]['onchip']['total']
                            net_dict[cost_cg + '/norm_percent'] = net_dict[cost_cg + '/percent'] * net_dict[f'{area_or_energy}/onchip/total/norm']

                
                if len(other) != 0:
                    # process other metrics
                    for m in other:
                        # Achieved throughput as a percentage of peak.
                        if m == 'perf_norm':
                            net_dict[m] = net_dict['flops per sec'] / net_dict[PEAK] * 100

                        if m == 'area_total': 
                            net_dict[m] = net_dict['area/onchip'] + net_dict['area/dram']

                        # NOTE: efficiency is in terms of onchip cost
                        if m == 'throughput/area': 
                            net_dict[m] = net_dict['flops per sec'] / \
                                (net_dict['area/onchip'])
                        if m == 'throughput/power': 
                            net_dict[m] = net_dict['flops per sec'] / net_dict['power/onchip/total']
                        if m == 'throughput/energy': 
                            net_dict[m] = net_dict['flops per sec'] / net_dict['energy/onchip/total']
                        
                        if m == 'flops/area': 
                            net_dict[m] = net_dict['flops'] / \
                                (net_dict['area/onchip'])
                        if m == 'flops/power': 
                            net_dict[m] = net_dict['flops'] / net_dict['power/onchip/total']
                        if m == 'flops/energy': 
                            net_dict[m] = net_dict['flops'] / net_dict['energy/onchip/total']

                        net_dict[m + '/norm'] = net_dict[m] / ref_arch[net_name][m]
                        
    return run_set

def rescale(run_set):
    """Re-normalise the first ``'/norm'`` metric of every network dict.

    ``run_set`` is walked as ``{batch: {arch: {net: {metric: value}}}}``. For
    each network dict, the first key containing ``'/norm'`` is overwritten
    with its un-normalised counterpart (the key with ``'/norm'`` removed)
    divided by one shared reference value. That reference is captured once,
    from the first such metric met for the first network in discovery order.
    Only that first matching metric per network dict is touched.

    Args:
        run_set: nested result dict; modified in place.

    Returns:
        The same ``run_set`` object.
    """
    logger.info("Rescaling...")

    # Networks in the order they are first encountered.
    net_arr = []
    for archs in run_set.values():
        for nets in archs.values():
            for net in nets:
                if net not in net_arr:
                    net_arr.append(net)

    found_ref_val = False
    for batch, archs in run_set.items():
        for arch, nets in archs.items():
            for net, metrics in nets.items():
                for metric_name, metric_val in metrics.items():
                    if '/norm' not in metric_name:
                        continue
                    base_name = metric_name.replace('/norm', '')
                    ind = net_arr.index(net)
                    if ind == 0 and not found_ref_val:
                        # The reference value is captured exactly once.
                        ref_val = metrics[base_name]
                        found_ref_val = True
                    scaled_val = metrics[base_name] / ref_val
                    run_set[batch][arch][net][metric_name] = scaled_val
                    logger.debug(f'ind={ind}, metric={metric_name}, val={metric_val}, ref={ref_val}, new={scaled_val}, {run_set[batch][arch][net][metric_name]}')
                    # Stop after the first '/norm' metric of this network.
                    break
    return run_set

def construct_argparser():
    """Build the command-line parser for the plotting script.

    Returns:
        argparse.ArgumentParser: parser exposing the design points
        (``-p/--point``), output naming (``-f/--file``, ``-d/--dirname``),
        plot style switches (``--group_by``, ``--single_n_slice``, ``--line``,
        ``--per_model``, ``--overlay``), baseline selection
        (``--use_universal_baseline`` / ``--no-use_universal_baseline``), the
        n selection (``--nval``, ``--nvalid``), the networks (``-n/--mlperf``)
        and the metrics to plot (``-m/--metric``).
    """
    # Metric groups, kept in the order they appear in the help output.
    perf_metrics = ['perf', 'perf_norm', 'util_impl', 'rt_abs', 'rt_norm']
    area_metrics = ['area_total', 'area_onchip', 'area_dram', 'area_sram', 'area_compute']
    breakdown_metrics = ['area_partition', 'area_breakdown', 'energy_partition', 'energy_breakdown']
    power_metrics = ['power_total', 'power_onchip', 'power_dram', 'power_sram', 'power_compute']
    energy_metrics = ['energy_total', 'energy_onchip', 'energy_dram', 'energy_sram', 'energy_compute']
    efficiency_metrics = [
        'throughput/area', 'throughput/energy', 'throughput/power',
        'flops/area', 'flops/energy', 'flops/power',
    ]
    metric_choices = (perf_metrics + area_metrics + breakdown_metrics
                      + power_metrics + energy_metrics + efficiency_metrics)
    # Partition/breakdown metrics are opt-in and therefore not in the default.
    metric_default = (perf_metrics + area_metrics
                      + power_metrics + energy_metrics + efficiency_metrics)

    parser = argparse.ArgumentParser(description='plot')
    parser.add_argument('-p',
                        '--point',
                        nargs='+',
                        help='config',
                        )
    parser.add_argument('-f',
                        '--file',
                        help='file name',
                        default=None)

    parser.add_argument(
                        '--group_by',
                        help='subgroup',
                        choices=['n', 'design'],
                        default='n')

    parser.add_argument('--single_n_slice',
                        help='plot n=1024 only',
                        action='store_true',
                        default=False)

    parser.add_argument('--line',
                        help='draw line plot',
                        action='store_true',
                        default=False)

    parser.add_argument('--per_model',
                        help='show result per model',
                        action='store_true',
                        default=False)

    parser.add_argument('--overlay',
                        help='overlay benchmark on plot',
                        action='store_true',
                        default=False)

    parser.add_argument('--use_universal_baseline', dest='use_universal_baseline', action='store_true')
    parser.add_argument('--no-use_universal_baseline', dest='use_universal_baseline', action='store_false')
    parser.set_defaults(use_universal_baseline=True)

    # Parsed as int so that an explicit --nval has the same type as the
    # default (1024) and as the entries of --nvalid.
    parser.add_argument('--nval',
                        help='plot n=nval only',
                        type=int,
                        choices=[
                            1, 2, 4, 8, 16,
                            32, 64, 128, 256, 512,
                            1024, 2048, 4096,
                        ],
                        default=1024)

    parser.add_argument('-d',
                        '--dirname',
                        help='dir name',
                        default=None)

    # 'UNet' and 'bert_base_uncased' are separate networks; a missing comma
    # used to concatenate them into a single bogus name.
    parser.add_argument('-n',
                        '--mlperf',
                        nargs='+',
                        help='networks in mlperf\n',
                        default=['resnet50', 'ssd300_vgg16', 'UNet', 'bert_base_uncased', 'dlrm', 'RNNT'],
                        )

    parser.add_argument('--nvalid',
                        nargs='+',
                        help='valid n values\n',
                        type=int,
                        default=[1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024],
                        )

    parser.add_argument('-m',
                        '--metric',
                        nargs='+',
                        help='perf to plot',
                        choices=metric_choices,
                        default=metric_default,
                        )
    return parser

def get_perf_across_one_arch_from_all_n(arch_name, n_list, metric, reduce_method, run_set):
    '''
    Collect ``metric`` for one architecture at every n in ``n_list`` and
    reduce it across that architecture's networks.

    ``reduce_method`` is one of 'geomean', 'sum' or 'none' ('none' keeps the
    per-network list). An n at which ``arch_name`` is absent contributes no
    entry. Any other ``reduce_method`` trips an assertion.
    '''
    reduced_per_n = []
    for n in n_list:
        for candidate_name, nets in run_set[n].items():
            if candidate_name != arch_name:
                continue
            values = [net_dict[metric] for net_dict in nets.values()]
            if reduce_method == 'geomean':
                reduced = gmean(values)
            elif reduce_method == 'sum':
                reduced = sum(values)
            elif reduce_method == 'none':
                reduced = values
            else:
                assert False
            reduced_per_n.append(reduced)

    return reduced_per_n


def get_perf_across_one_arch_from_all_n_no_reduction(arch_name, n_list, metric, run_set):
    '''
    Collect ``metric`` for one architecture at every n in ``n_list`` without
    reducing: the result holds one per-network list of values for each n at
    which ``arch_name`` is present.
    '''
    per_n_values = []
    for n in n_list:
        for candidate_name, nets in run_set[n].items():
            if candidate_name == arch_name:
                per_n_values.append(
                    [net_dict[metric] for net_dict in nets.values()]
                )

    return per_n_values

def get_perf_one_arch_nval_across_all_nets(arch_name, metric, run_set, nval):
    '''
    Return ``metric`` for every network of ``arch_name`` at n == ``nval``,
    followed by one extra trailing entry: the geometric mean of those values.
    '''
    values = []
    for candidate_name, nets in run_set[nval].items():
        if candidate_name == arch_name:
            values.extend(net_dict[metric] for net_dict in nets.values())
    # The last element summarises the per-network values.
    values.append(gmean(values))
    return values

def plot_all_n(point, metric, run_set, n_list, dir_name, single_n_slice, nval, mlperf_arr, universal_baseline):
    '''
    Draw a grouped bar plot of one metric for all design points and save it
    as a PDF under ``plot/mlperf``.

    Args:
        point: list of design-point strings; the part before the first '/'
            selects the colour/opacity family and the part after the last
            '/' is the architecture name looked up in ``run_set``.
        metric: metric name; must be a key of ``yml_metric_map``,
            ``reduce_map`` and ``axis_map``.
        run_set: nested result dict passed through to the getter helpers.
        n_list: n values shown on the x axis when ``single_n_slice`` is false.
        dir_name: optional sub-directory / configuration name used to look up
            per-figure settings (legend, size, clipping, ticks); may be None.
        single_n_slice: if true, plot per-network values at ``nval`` instead
            of one reduced value per n.
        nval: n value used when ``single_n_slice`` is true.
        mlperf_arr: x tick labels used when ``single_n_slice`` is true.
        universal_baseline: if true, the file name gets a ``_uni`` suffix.

    Returns:
        None. The figure is written to disk.
    '''
    # metric map
    if metric in yml_metric_map.keys():
        metric_in_yml = yml_metric_map[metric]
    else:
        assert False

    # reduce map
    if metric in reduce_map.keys():
        reduce_method = reduce_map[metric]
    else:
        assert False

    # One bar slot per design point plus one empty slot between groups.
    width=1.0/(len(point)+1)
    n_str = np.arange(len(n_list))
    net_str = np.arange(len(mlperf_arr))

    # Figures without a legend use the shorter height.
    if metric in legend_map.get(dir_name, default_map): size_tuple_local = size_tuple
    else: size_tuple_local = (fig_w, fig_h_short)

    # if 'chunk' in dir_name and metric == 'perf': 
    #     size_tuple_local = (fig_w, fig_h_long)

    # An explicit per-directory, per-metric size overrides the choice above.
    if size_map.get(dir_name, {}).get(metric): 
        size_tuple_local = size_map[dir_name][metric]

    # ---------plot----------
    fig, ax = plt.subplots(figsize=size_tuple_local, dpi=my_dpi)
    # ax.grid(color='grey', linestyle='--', linewidth=0.2, axis='both')
    
    # if ('chunk' in dir_name or 'batch' in dir_name) and metric == 'perf': 
    #     ax.set_xlabel('Batch size')

    # Clipping threshold for the y axis; 100 means "no configured limit".
    ylim = (clip_map.get(dir_name, default_map)).get(metric, 100)
    y_max = 1

    color_bar = []
    # NOTE(review): lab_arr is not used anywhere in this function.
    lab_arr = []

    for i in np.arange(len(point)):
        name = point[i].split('/')[-1]
        name_color = name.split('_dram')[0]
        arch_type = point[i].split('/')[0]
        if single_n_slice == True:
            arr = get_perf_one_arch_nval_across_all_nets(name, metric_in_yml, run_set, nval)
        else:
            arr = get_perf_across_one_arch_from_all_n(name, n_list, metric_in_yml, reduce_method, run_set)
        arr_to_plot = arr
        # Track the largest plotted value to decide on clipping later.
        if max(arr) > y_max: y_max = max(arr)

        # Key used to look up the bar opacity.
        if dir_name != None and 'dim' in dir_name:
            name_opac = trim_item(
                item=point[i],
                ignore=default_for_opac,#[False,True,True,False,False,True]
                trim_only=True
                )
        else:
            name_opac = name.split('_dram')[0]

        if 'area_' not in metric and 'energy_' not in metric:
            if single_n_slice == False:
                logger.debug(arch_type,name_color)
                color_bar_ = plt.bar(n_str+width*i, np.array(arr_to_plot), width, \
                    color=color_map[arch_type][name_color], \
                    alpha=opac_map[arch_type][name_opac], \
                    label=trim_item(
                        item=point[i],
                        ignore=ignore_vec_map.get(dir_name, default_map).get(metric, default_ignore_vec)
                        )
                    )
                plt.xticks(n_str+width*float(i)/2, n_list, rotation=20)
            else:
                logger.debug(arr_to_plot)
                color_bar_ = plt.bar(net_str+width*i, np.array(arr_to_plot), width, \
                    color=color_map[arch_type][name_color], \
                    alpha=opac_map[arch_type][name_opac], \
                    label=trim_item(
                        item=point[i],
                        ignore=ignore_vec_map.get(dir_name, default_map).get(metric, default_ignore_vec)
                        )
                    )
                plt.xticks(net_str+width*float(i)/2, mlperf_arr, rotation=20)
            # Write the numeric value above bars that reach the clip limit.
            if metric != 'rt_abs':
                if single_n_slice == False:
                    for j in np.arange(len(n_str)):
                        if arr_to_plot[j] >= ylim:
                            ax.annotate('%.1f' % (arr_to_plot[j]), xy=(n_str[j]+width*i-0.3*width, ylim+0.15), textcoords='data', rotation=0, annotation_clip=False)
                else:
                    for j in np.arange(len(net_str)):
                        if arr_to_plot[j] >= ylim:
                            ax.annotate('%.1f' % (arr_to_plot[j]), xy=(net_str[j]+width*i-0.3*width, ylim+0.15), textcoords='data', rotation=0, annotation_clip=False)
        else: 
            # only one set of n is enough as area is constant across n
            # NOTE(review): this branch is also taken for 'energy_*' metrics;
            # confirm that plotting only the first entry is intended there.
            if single_n_slice == False:
                color_bar_ = plt.bar(n_str[0]+width*i, np.array(arr_to_plot[0]), width, \
                    color=color_map[arch_type][name_color], \
                    alpha=opac_map[arch_type][name_opac], \
                    label=trim_item(point[i],[False,True,False,False,True]))
                if arr_to_plot[0]>= ylim:
                    ax.annotate('%.1f' % (arr_to_plot[0]), xy=(n_str[0]+width*i-0.3*width, ylim+0.15), textcoords='data', rotation=0, annotation_clip=False)
            else:
                color_bar_ = plt.bar(net_str[0]+width*i, np.array(arr_to_plot[0]), width, \
                    color=color_map[arch_type][name_color], \
                    alpha=opac_map[arch_type][name_opac], \
                    label=trim_item(point[i],[False,True,False,False,True]))
                if arr_to_plot[0]>= ylim:
                    ax.annotate('%.1f' % (arr_to_plot[0]), xy=(net_str[0]+width*i-0.3*width, ylim+0.15), textcoords='data', rotation=0, annotation_clip=False)
            # print(bcolors.WARNING + 'area or energy'+bcolors.ENDC)
            # One tick per design point, labelled with the trimmed point name.
            plt.xticks(
                width*np.arange(len(point)),\
                trim(
                    point=point,
                    ignore=ignore_vec_map.get(dir_name, default_map).get(metric, default_ignore_vec)
                ), \
                rotation=0)
    
        color_bar += color_bar_

    # Clip the y axis when some value exceeds the limit.
    if metric != 'rt_abs' and y_max > ylim:
        if ylim != 100:
            final_ylim = ylim
        else:
            final_ylim = min(ylim,y_max)
        ax.set_ylim(top=final_ylim)
    
    # display yticks based on y_ticks_map in util
    y_ticks_dict = y_ticks_map.get(dir_name, default_map)
    yt = y_ticks_dict.get(metric, default_map)
    if y_ticks_dict != default_map and \
        yt != default_map:
        ax.set_yticks(yt)
        ax.set_yticklabels(yt)

    ax.set_ylabel(axis_map[metric])
    
    if metric in legend_map.get(dir_name, default_map):
        ncol = 3
        # NOTE(review): this would raise TypeError if dir_name is None and
        # the branch is reached -- confirm default_map never lists metrics.
        if 'chunk' in dir_name: ncol = 4
        ax.legend(
            # color_bar,lab_arr, \
            bbox_to_anchor=(0,1.02,1,0.1), loc="lower left",\
            mode="expand", borderaxespad=0, ncol=ncol)
    # NOTE(review): the 'line675' text in this message is a fixed string, not
    # the actual line number.
    else: logger.warning(f'line675: no legend for {dir_name} {metric}')

    os.makedirs('plot/mlperf', exist_ok=True)
    # Make the metric name safe to use as a file name.
    metric_name = metric.replace('/', '_')
    metric_name = metric_name.replace(' ', '_')

    # Output path: plot/mlperf/[dir_name/]<metric>[_n<nval>][_uni].pdf
    if dir_name == None:
        if single_n_slice == True:
            fig_name = f'plot/mlperf/{metric_name}_n{nval}'
        else:
            fig_name = f'plot/mlperf/{metric_name}'
    else:
        os.makedirs(f'plot/mlperf/{dir_name}', exist_ok=True)
        if single_n_slice == True:
            fig_name = f'plot/mlperf/{dir_name}/{metric_name}_n{nval}'
        else:
            fig_name = f'plot/mlperf/{dir_name}/{metric_name}'
    
    if universal_baseline:
        fig_name += '_uni.pdf'
    else:
        fig_name += '.pdf'

    fig.tight_layout()
    plt.savefig(fig_name, bbox_inches='tight', dpi=my_dpi, pad_inches=0.02)
    # print(bcolors.OKGREEN + f'Saved fig as {fig_name}' + bcolors.ENDC)

def plot_all_n_line(point, metric, run_set, n_list, dir_name, single_n_slice, nval, mlperf_arr, universal_baseline, overlay):
    '''
    plot single set line plot

    Draws one line (or one marker for area/energy metrics) per entry of
    `point` and saves the figure as a PDF under plot/mlperf[/<dir_name>]/.

    Args:
        point: sequence of strings of the form '<arch_type>/.../<name>'. The
            first path component indexes color_map/opac_map, the last one is
            handed to the data getters.
        metric: key into yml_metric_map, reduce_map and axis_map.
        run_set: forwarded unchanged to the data getters.
        n_list: x tick labels of the all-n plot (one x position per entry).
        dir_name: key into the per-directory config maps and output
            sub-directory; may be None.
        single_n_slice: if truthy, data is fetched for `nval` only. Nothing is
            drawn for non area/energy metrics in that case.
        nval: the n value used when `single_n_slice` is set (also part of the
            output file name).
        mlperf_arr: network names; used for the overlay line labels.
        universal_baseline: if truthy, '_uni' is appended to the file name.
        overlay: if truthy, the reduce method is forced to 'none' and one line
            per network is drawn for every architecture.
    '''
    # metric map: plot metric -> metric key used by the data getters
    assert metric in yml_metric_map, f'unknown metric: {metric}'
    metric_in_yml = yml_metric_map[metric]

    # reduce map: overlay plots keep the individual series, so no reduction
    assert metric in reduce_map, f'no reduce method for metric: {metric}'
    reduce_method = 'none' if overlay else reduce_map[metric]

    width = 1.0/(len(point)+1)
    n_str = np.arange(len(n_list))
    net_str = np.arange(len(mlperf_arr))

    # figures that carry a legend use the taller default size
    has_legend = metric in legend_map.get(dir_name, default_map)
    size_tuple_local = size_tuple if has_legend else (fig_w, fig_h_short)
    if size_map.get(dir_name, {}).get(metric):
        size_tuple_local = size_map[dir_name][metric]

    # ---------plot----------
    fig, ax = plt.subplots(figsize=size_tuple_local, dpi=my_dpi)

    ylim = (clip_map.get(dir_name, default_map)).get(metric, 100)
    is_line_metric = 'area_' not in metric and 'energy_' not in metric
    use_dim_opacity = dir_name is not None and 'dim' in dir_name

    for i, item in enumerate(point):
        name = item.split('/')[-1]
        name_color = name.split('_dram')[0]
        arch_type = item.split('/')[0]
        if single_n_slice:
            arr_to_plot = get_perf_one_arch_nval_across_all_nets(name, metric_in_yml, run_set, nval)
        else:
            arr_to_plot = get_perf_across_one_arch_from_all_n(name, n_list, metric_in_yml, reduce_method, run_set)

        if use_dim_opacity:
            name_opac = trim_item(
                item=item,
                ignore=default_for_opac,
                trim_only=True
                )
        else:
            name_opac = name_color

        if is_line_metric:
            if single_n_slice:
                # nothing is drawn for a single-n slice of a line metric
                continue

            logger.debug('{} {}', arch_type, name_color)
            if overlay:
                # arr_to_plot is indexed [n][network]: one line per network
                for num in range(len(arr_to_plot[0])):
                    series = [per_n[num] for per_n in arr_to_plot]
                    logger.debug(series)
                    plt.plot(n_str, np.array(series),
                        color=color_map[arch_type][name_color],
                        # each further network is drawn 0.2 more transparent
                        alpha=opac_map[arch_type][name_opac]-num*0.2,
                        marker=marker_arr[num],
                        markersize=3,
                        linewidth=1.5,
                        linestyle=ls_arr[num],
                        label=trim_item(
                            item=item,
                            ignore=[True, True, True, True, True]
                            )+f'-{mlperf_arr[num]}')
            else:
                plt.plot(n_str, np.array(arr_to_plot),
                    color=color_map[arch_type][name_color],
                    alpha=opac_map[arch_type][name_opac],
                    marker='o',
                    markersize=3,
                    linewidth=1.5,
                    label=trim_item(
                        item=item,
                        ignore=ignore_vec_map.get(dir_name, default_map).get(metric, default_ignore_vec)
                        ))
            plt.xticks(n_str, n_list, rotation=20)
        else:
            # only one set of n is enough as area is constant across n
            value = arr_to_plot[0]
            x_base = net_str[0] if single_n_slice else n_str[0]
            # every architecture gets its own slot, matching the tick positions
            # set below and the annotation position
            x_pos = x_base + width*i
            plt.plot(x_pos, np.array(value),
                color=color_map[arch_type][name_color],
                alpha=opac_map[arch_type][name_opac],
                marker='o',
                markersize=3,
                linewidth=1.5,
                label=trim_item(item,[False,True,False,False,True]))
            if value >= ylim:
                # value at/above the clip limit: print it just above the limit
                ax.annotate('%.1f' % (value), xy=(x_pos-0.3*width, ylim+0.15), textcoords='data', rotation=0, annotation_clip=False)
            plt.xticks(
                width*np.arange(len(point)),
                trim(
                    point=point,
                    ignore=ignore_vec_map.get(dir_name, default_map).get(metric, default_ignore_vec)
                ),
                rotation=0)

    # display yticks based on y_ticks_map in util
    y_ticks_dict = y_ticks_map.get(dir_name, default_map)
    yt = y_ticks_dict.get(metric, default_map)
    if y_ticks_dict != default_map and \
        yt != default_map:
        ax.set_yticks(yt)
        ax.set_yticklabels(yt)

    ax.set_ylabel(axis_map[metric])

    if has_legend:
        # dir_name may be None, so guard the substring test
        ncol = 4 if dir_name is not None and 'chunk' in dir_name else 3
        ax.legend(
            bbox_to_anchor=(0,1.02,1,0.1), loc="lower left",
            mode="expand", borderaxespad=0, ncol=ncol)
    else: logger.debug(f'line838: no legend for {dir_name} {metric}')

    # ---------save----------
    os.makedirs('plot/mlperf', exist_ok=True)
    metric_name = metric.replace('/', '_').replace(' ', '_')

    if dir_name is None:
        fig_name = f'plot/mlperf/{metric_name}'
    else:
        os.makedirs(f'plot/mlperf/{dir_name}', exist_ok=True)
        fig_name = f'plot/mlperf/{dir_name}/{metric_name}'
    if single_n_slice:
        fig_name += f'_n{nval}'
    fig_name += '_uni.pdf' if universal_baseline else '.pdf'

    fig.tight_layout()
    plt.savefig(fig_name, bbox_inches='tight', dpi=my_dpi, pad_inches=0.02)

def plot_all_n_line_permodel(point, metric, run_set, n_list, dir_name, single_n_slice, nval, mlperf_arr, universal_baseline, overlay):
    '''
    plot single set line plot per model

    For every architecture in `point` the un-reduced data is fetched, a
    geometric mean over the networks is appended to each n entry, and one line
    segment per network (plus the geomean) is drawn side by side on a shared x
    axis. A secondary top axis carries the network names. The figure is saved
    as '<metric>[_uni]_permodel.pdf' under plot/mlperf[/<dir_name>]/.

    Args:
        point: sequence of strings of the form '<arch_type>/.../<name>'.
        metric: key into yml_metric_map, reduce_map and axis_map.
        run_set: forwarded unchanged to the data getter.
        n_list: list of n labels; repeated once per network group on the x axis.
        dir_name: key into the per-directory config maps and output
            sub-directory; may be None.
        single_n_slice: must be falsy (asserted for every architecture).
        nval: only used in the output file name if `single_n_slice` is set.
        mlperf_arr: network names; its last entry is replaced by 'geomean' for
            the top axis labels.
        universal_baseline: if truthy, '_uni' is added to the file name.
        overlay: if truthy, the process currently exits with status 1 (see the
            note in the overlay branch).
    '''
    # metric map: plot metric -> metric key used by the data getter
    assert metric in yml_metric_map, f'unknown metric: {metric}'
    metric_in_yml = yml_metric_map[metric]

    # no reduction is applied in this plot, but the metric is still required
    # to be registered in reduce_map
    assert metric in reduce_map, f'no reduce method for metric: {metric}'

    width = 1.0/(len(point)+1)
    n_str = np.arange(len(n_list))
    net_str = np.arange(len(mlperf_arr))

    # figures that carry a legend use the taller default size
    has_legend = metric in legend_map.get(dir_name, default_map)
    size_tuple_local = size_tuple if has_legend else (fig_w, fig_h_short)
    if size_map.get(dir_name, {}).get(metric):
        size_tuple_local = size_map[dir_name][metric]

    # ---------plot----------
    fig, ax = plt.subplots(figsize=size_tuple_local, dpi=my_dpi)

    ylim = (clip_map.get(dir_name, default_map)).get(metric, 100)
    is_line_metric = 'area_' not in metric and 'energy_' not in metric
    use_dim_opacity = dir_name is not None and 'dim' in dir_name

    for i, item in enumerate(point):
        name = item.split('/')[-1]
        name_color = name.split('_dram')[0]
        arch_type = item.split('/')[0]
        assert not single_n_slice, 'per-model plot does not support single_n_slice'
        # arr_to_plot is indexed [n][network]; the geomean over the networks is
        # appended in place as an extra "network" for every n
        arr_to_plot = get_perf_across_one_arch_from_all_n_no_reduction(name, n_list, metric_in_yml, run_set)
        for per_n in arr_to_plot:
            per_n.append(gmean(per_n))
        logger.debug(f'{name}, {arr_to_plot}, [n,net]=[{len(arr_to_plot)}, {len(arr_to_plot[0])}]')

        if use_dim_opacity:
            name_opac = trim_item(
                item=item,
                ignore=default_for_opac,
                trim_only=True
                )
        else:
            name_opac = name_color

        if is_line_metric:
            logger.debug('{} {}', arch_type, name_color)
            if overlay:
                logger.info('overlay')
                for num in range(len(arr_to_plot[0])):
                    series = [per_n[num] for per_n in arr_to_plot]
                    logger.debug(series)
                    # NOTE(review): the original code aborts the whole process
                    # here, so the overlay plot below is never drawn. This looks
                    # like a leftover debugging stop - confirm before removing.
                    sys.exit(1)
                    plt.plot(n_str, np.array(series),
                        color=color_map[arch_type][name_color],
                        alpha=opac_map[arch_type][name_opac]-num*0.2,
                        marker=marker_arr[num],
                        markersize=3,
                        linewidth=1.5,
                        linestyle=ls_arr[num],
                        label=trim_item(
                            item=item,
                            ignore=[True, True, True, True, True]
                            )+f'-{mlperf_arr[num]}')
            else:
                logger.info('no overlay')
                n_groups = len(arr_to_plot[0])
                group_len = len(n_list) + 1
                # one row of consecutive x positions per network group; each
                # row has one slot more than there are n values
                x_groups = np.arange(group_len*n_groups).reshape(n_groups, group_len)
                per_net = np.transpose(np.array(arr_to_plot))

                lab_trim = trim_item(
                    item=name_color,
                    ignore=ignore_vec_map.get(dir_name, default_map).get(metric, all_ignore_vec),
                    trim_only=False
                    )
                marker = 'o'
                line_fmt = '-'
                opacity = opac_map[arch_type][name_opac]
                if 'carat' in arch_type and '64' in name_color:
                    marker = '^'
                    line_fmt = ':'
                    opacity = 0.5
                annotate_last = 'carat' in arch_type and metric == "perf" and '128' in name_color

                for ind, (xs, ys) in enumerate(zip(x_groups, per_net)):
                    # pad y with None so it matches the len(n_list)+1 x
                    # positions of the group (the extra slot has an empty tick
                    # label below)
                    ys = ys.tolist()
                    ys.append(None)
                    plt.plot(xs, ys,
                        line_fmt,
                        color=color_map[arch_type][name_color],
                        alpha=opacity,
                        marker=marker,
                        markersize=3,
                        linewidth=1.5,
                        # label only the first segment so the legend has one
                        # entry per architecture
                        label=lab_trim if ind == 0 else None
                    )
                    if annotate_last:
                        # ys[-1] is the None pad, so ys[-2] is the value at the
                        # last n of this group
                        ax.annotate(f'{round(ys[-2],1)}x',
                                    xy=(xs[-2]-0.5, ys[-2]*2),
                                    annotation_clip=False,
                                    color='grey')

                if metric == "perf":
                    logger.warning('log scale')
                    ax.set_yscale("log")
                plt.grid(axis = 'y',linestyle = 'dotted', linewidth = 0.5, color = 'grey')
                newtick = np.array([n_list+['']]*n_groups).flatten()
                origtick = list(range(0, group_len*n_groups))
                plt.xticks(origtick, newtick, rotation=90)
        else:
            logger.error('not implemented')
            # NOTE(review): arr_to_plot[0] is a list here (one value per
            # network plus the geomean), so the comparison against ylim below
            # cannot succeed; the branch is kept as-is and is expected to raise.
            # only one set of n is enough as area is constant across n
            value = arr_to_plot[0]
            if single_n_slice:
                x_base = net_str[0]
                x_pos = x_base + width*i
            else:
                x_base = n_str[0]
                x_pos = x_base
            plt.plot(x_pos, np.array(value),
                color=color_map[arch_type][name_color],
                alpha=opac_map[arch_type][name_opac],
                marker='o',
                markersize=3,
                linewidth=1.5,
                label=trim_item(item,[False,True,False,False,True]))
            if value >= ylim:
                ax.annotate('%.1f' % (value), xy=(x_base+width*i-0.3*width, ylim+0.15), textcoords='data', rotation=0, annotation_clip=False)
            plt.xticks(
                width*np.arange(len(point)),
                trim(
                    point=point,
                    ignore=ignore_vec_map.get(dir_name, default_map).get(metric, default_ignore_vec)
                ),
                rotation=0)

    # secondary x axis on top carrying one label per network group
    ax2 = ax.twiny()
    ax2.set_xlim(ax.get_xlim())
    # new tick locations are mid points of n_str
    new_tick_locations = [g*(len(n_str)+1)+(len(n_str))/2 for g in range(len(mlperf_arr))]
    ax2.set_xticks(new_tick_locations)
    # the last entry of mlperf_arr is replaced by the geomean group
    allnet = ['ResNet50' if net == 'Resnet50' else net for net in mlperf_arr[:-1]]
    allnet.append('geomean')
    logger.debug(allnet)

    ax2.set_xticklabels(allnet)
    ax2.tick_params(axis='both', which='both', length=0)

    # display yticks based on y_ticks_map in util
    y_ticks_dict = y_ticks_map.get(dir_name, default_map)
    yt = y_ticks_dict.get(metric, default_map)
    if y_ticks_dict != default_map and \
        yt != default_map:
        ax.set_yticks(yt)
        ax.set_yticklabels(yt)

    ax.set_ylabel(axis_map[metric])
    ax.tick_params(axis='both', which='both', length=0)

    legend_cfg = legend_vec_map.get(dir_name, {}).get(metric)
    if legend_cfg:
        # explicit legend placement: (bbox_to_anchor, loc, ncol)
        size_vec = legend_cfg[0]
        loc = legend_cfg[1]
        ncol = legend_cfg[2]
        logger.info('relocating legend wedget')
        ax.legend(
                bbox_to_anchor=size_vec, loc=loc,
                mode="expand", borderaxespad=0, ncol=ncol)
    elif has_legend:
        # dir_name may be None, so guard the substring test
        ncol = 4 if dir_name is not None and 'chunk' in dir_name else 3
        ax.legend(
            bbox_to_anchor=(0,1.02,1,0.1), loc="lower left",
            mode="expand", borderaxespad=0, ncol=ncol)
    else: logger.info(f'no legend for {dir_name} {metric}')

    # ---------save----------
    os.makedirs('plot/mlperf', exist_ok=True)
    metric_name = metric.replace('/', '_').replace(' ', '_')

    if dir_name is None:
        fig_name = f'plot/mlperf/{metric_name}'
    else:
        os.makedirs(f'plot/mlperf/{dir_name}', exist_ok=True)
        fig_name = f'plot/mlperf/{dir_name}/{metric_name}'
        if single_n_slice:
            fig_name += f'_n{nval}'
    fig_name += '_uni_permodel.pdf' if universal_baseline else '_permodel.pdf'

    fig.tight_layout()
    plt.savefig(fig_name, bbox_inches='tight', dpi=my_dpi, pad_inches=0.02)



def plot_stacked_bar(point, metric, run_set, dir_name, nval):
    '''
    plot multiple set of bar plot

    Draws one stacked bar per entry of `point`: every entry of
    yml_metric_map[metric] contributes one hatched segment. The figure is
    saved as plot/mlperf[/<dir_name>]/<metric>.pdf.

    Args:
        point: sequence of strings of the form '<arch_type>/.../<name>'.
        metric: key into yml_metric_map (expected to map to a list of
            '<component>/...' metric keys) and axis_map.
        run_set: forwarded unchanged to the data getter.
        dir_name: key into the per-directory config maps and output
            sub-directory; may be None.
        nval: the n value the data is fetched for.
    '''
    # metric map: plot metric -> list of per-component metric keys
    assert metric in yml_metric_map, f'unknown metric: {metric}'
    metric_in_yml_list = yml_metric_map[metric]

    width = 1.0/(len(point)+1)
    #if there is legend, longer image
    has_legend = metric in legend_map.get(dir_name, default_map)
    size_tuple_local = size_tuple if has_legend else (fig_w, fig_h_short)
    if size_map.get(dir_name, {}).get(metric):
        size_tuple_local = size_map[dir_name][metric]

    # ---------plot----------
    fig, ax = plt.subplots(figsize=size_tuple_local, dpi=my_dpi)

    ax.set_xlabel('')

    ylim = (clip_map.get(dir_name, default_map)).get(metric, 100)
    y_max = 1
    # hatch artists/labels for the legend, taken from the first 'carat' bar
    red_hatch = []
    lab_hatch_arr = []
    first_tlut_i_index = -1

    # dir_name may be None, so guard the substring test
    is_dim_dir = dir_name is not None and 'dim' in dir_name
    width_space = width * 1.35 if is_dim_dir else width

    for i, item in enumerate(point):
        if first_tlut_i_index == -1 and 'carat' in item:
            first_tlut_i_index = i

        name = item.split('/')[-1]
        name_color = name.split('_dram')[0]
        arch_type = item.split('/')[0]

        if is_dim_dir:
            name_opac = trim_item(
                item=item,
                ignore=default_for_opac,
                trim_only=True
                )
        else:
            name_opac = name_color

        x_pos = width_space*i
        # running top of the stack for this bar; None until the first segment
        prev = None
        for metric_in_yml in metric_in_yml_list:
            lab_hatch = metric_in_yml.split('/')[0]
            # components listed in tlut_unique_component are not drawn for
            # systolic architectures
            if lab_hatch in tlut_unique_component and 'systolic' in arch_type:
                continue

            arr = get_perf_one_arch_nval_across_all_nets(name, metric_in_yml, run_set, nval)
            height = np.array(arr[0])

            color_ = color_map[arch_type][name_color]
            if 'inter' in name_color: color_ = bcolors.blue

            # three overlaid bars: fill colour, red hatch, black outline
            plt.bar(x_pos, height, width,
                color=color_,
                alpha=opac_map[arch_type][name_opac], bottom=prev)
            red_hatch_ = plt.bar(x_pos, height, width, color='none', hatch=hatch_map[lab_hatch],
                edgecolor=bcolors.red, bottom=prev, linewidth=0.3)
            plt.bar(x_pos, height, width, color='none', edgecolor='k', linewidth=0.3, bottom=prev)

            prev = height if prev is None else prev + height

            # adding hatch artist and legend
            if i == first_tlut_i_index:
                red_hatch.extend(red_hatch_)
                lab_hatch_arr.append(lab_hatch)

        if prev is None:
            # every component was skipped: nothing drawn for this architecture
            continue

        # if onchip overflows clip val annotate data
        if prev > ylim:
            ax.annotate('%.1f' % (prev), xy=(x_pos-0.3*width, ylim+0.15), textcoords='data', rotation=0, annotation_clip=False)
        if y_max < prev:
            y_max = prev

    plt.xticks(
        width_space*np.arange(len(point)),
        trim(
            point=point,
            ignore=ignore_vec_map.get(dir_name, default_map).get(metric, default_ignore_vec_area)
        ),
        rotation=0)

    # display yticks based on y_ticks_map in util
    y_ticks_dict = y_ticks_map.get(dir_name, default_map)
    yt = y_ticks_dict.get(metric, default_map)
    if y_ticks_dict != default_map and \
        yt != default_map:
        ax.set_yticks(yt)
        ax.set_yticklabels(yt)

    ax.set_ylabel(axis_map[metric])

    if has_legend:
        # a narrower, centred legend box is used when there are few entries
        if len(lab_hatch_arr) > 3:
            anchor = (0,1.02,1,0.05)
        else:
            anchor = (0.15,1.02,0.7,0.05)
        plt.legend(red_hatch, lab_hatch_arr,
            bbox_to_anchor=anchor, loc="lower left",
            mode="expand", borderaxespad=0, ncol=3)
    else: logger.debug(f'line766: no legend for {dir_name} {metric}')

    # ---------save----------
    os.makedirs('plot/mlperf', exist_ok=True)
    metric_name = metric.replace('/', '_').replace(' ', '_')

    if dir_name is None:
        fig_name = f'plot/mlperf/{metric_name}.pdf'
    else:
        os.makedirs(f'plot/mlperf/{dir_name}', exist_ok=True)
        fig_name = f'plot/mlperf/{dir_name}/{metric_name}.pdf'

    # an explicit clip value wins; with the default (100) the axis is capped at
    # the tallest bar if that is lower
    final_ylim = ylim if ylim != 100 else min(ylim, y_max)
    ax.set_ylim(top=final_ylim)
    fig.tight_layout()
    plt.savefig(fig_name, bbox_inches='tight', dpi=my_dpi, pad_inches=0.02)

def plot_all_n_stacked_bar(point, metric, run_set, n_list, dir_name, single_n_slice, nval, mlperf_arr, universal_baseline):
    '''
    plot multiple stacked bar for all n

    For every architecture in `point` one stacked bar is drawn per x position
    (per network if `single_n_slice`, otherwise per n). Each segment is the
    product of the component metric and the matching
    '<var>/onchip/total/norm' metric. The figure is saved as a PDF under
    plot/mlperf[/<dir_name>]/.

    Args:
        point: sequence of strings of the form '<arch_type>/.../<name>'.
        metric: key into yml_metric_map (expected to map to a list of
            '<component>/<var>/...' metric keys), reduce_map and axis_map.
        run_set: forwarded unchanged to the data getters.
        n_list: n values; x tick labels when more than one n is plotted.
        dir_name: key into the per-directory config maps and output
            sub-directory; may be None.
        single_n_slice: if truthy, plot the networks at `nval` instead of
            the reduced values across `n_list`.
        nval: the n value used when `single_n_slice` is set.
        mlperf_arr: network names; x tick labels when `single_n_slice`.
        universal_baseline: if truthy, '_uni' is appended to the file name.
    '''
    # metric map: plot metric -> list of per-component metric keys
    assert metric in yml_metric_map, f'unknown metric: {metric}'
    metric_in_yml_list = yml_metric_map[metric]

    # reduce map
    assert metric in reduce_map, f'no reduce method for metric: {metric}'
    reduce_method = reduce_map[metric]

    width = 1.0/(len(point)+1)
    n_str = np.arange(len(n_list))
    net_str = np.arange(len(mlperf_arr))

    if single_n_slice:
        use_str = net_str
        xticks = mlperf_arr
    else:
        use_str = n_str
        if len(n_str) == 1:
            # a single n: label every bar with its architecture instead
            xticks = trim(
                point=point,
                ignore=ignore_vec_map.get(dir_name, default_map).get(metric, default_ignore_vec_area))
        else:
            xticks = n_list

    logger.debug('xticks={}', xticks)
    logger.debug('use_str={}', use_str)

    has_legend = metric in legend_map.get(dir_name, default_map)
    size_tuple_local = size_tuple if has_legend else (fig_w, fig_h_short)
    if size_map.get(dir_name, {}).get(metric):
        size_tuple_local = size_map[dir_name][metric]

    # ---------plot----------
    fig, ax = plt.subplots(figsize=size_tuple_local, dpi=my_dpi)

    logger.debug(metric_in_yml_list)

    # hatch artists/labels for the legend, taken from the first 'carat' entry
    red_hatch = []
    lab_hatch_arr = []
    first_tlut_i_index = -1
    # dir_name may be None, so guard the substring tests
    is_format_dir = dir_name is not None and 'format' in dir_name
    use_dim_opacity = dir_name is not None and 'dim' in dir_name

    for i, item in enumerate(point):
        if first_tlut_i_index == -1 and 'carat' in item:
            first_tlut_i_index = i
        name = item.split('/')[-1]
        name_color = name.split('_dram')[0]
        arch_type = item.split('/')[0]

        if use_dim_opacity:
            name_opac = trim_item(
                item=item,
                ignore=default_for_opac,
                trim_only=True
                )
        else:
            name_opac = name_color

        x_pos = use_str + width*i
        if is_format_dir:
            # add some gap after every third architecture
            x_pos = x_pos + 0.05*(i//3)

        # running top of the stacks for this architecture
        bottom = None
        for metric_in_yml in metric_in_yml_list:
            lab_hatch = metric_in_yml.split('/')[0]
            var = metric_in_yml.split('/')[1]
            onchip = f'{var}/onchip/total/norm'
            # components listed in tlut_unique_component are not drawn for
            # systolic architectures
            if lab_hatch in tlut_unique_component and 'systolic' in arch_type:
                continue

            if single_n_slice:
                arr_percent = get_perf_one_arch_nval_across_all_nets(name, metric_in_yml, run_set, nval)
                arr_onchip = get_perf_one_arch_nval_across_all_nets(name, onchip, run_set, nval)
            else:
                arr_percent = get_perf_across_one_arch_from_all_n(name, n_list, metric_in_yml, reduce_method, run_set)
                arr_onchip = get_perf_across_one_arch_from_all_n(name, n_list, onchip, reduce_method, run_set)

            # scale the component metric by the onchip total
            heights = np.array(arr_percent) * np.array(arr_onchip)

            # three overlaid bars: fill colour, red hatch, black outline
            plt.bar(x_pos, heights, width, color=color_map[arch_type][name_color], alpha=opac_map[arch_type][name_opac], bottom=bottom)
            red_hatch_ = plt.bar(x_pos, heights, width, color='none', hatch=hatch_map[lab_hatch], edgecolor=bcolors.red, bottom=bottom, linewidth=0.3)
            plt.bar(x_pos, heights, width, color='none', edgecolor='k', linewidth=0.3, bottom=bottom)

            bottom = heights if bottom is None else bottom + heights

            # adding hatch artist and legend
            if i == first_tlut_i_index:
                red_hatch.append(red_hatch_[0])
                lab_hatch_arr.append(lab_hatch)

    # index of the last architecture; group ticks are centred over the bars
    last_idx = len(point) - 1
    if single_n_slice:
        plt.xticks(use_str+width*float(last_idx)/2, xticks, rotation=0)
    else:
        if len(use_str) == 1:
            if is_format_dir:
                x_ticks_pos = [use_str[0]+width*float(k)+0.05*(k//3) for k in range(len(point))]
                # one label per group of three bars, placed at the middle bar;
                # this hard-codes six groups (at least 17 entries in point)
                x_ticks_pos = [x_ticks_pos[k] for k in (1, 4, 7, 10, 13, 16)]
                xticks = ['INT4', 'INT8', 'FP8', 'FP8_ACC16', 'BF16', 'BF16_ACC16']
                rot = 0
            else:
                x_ticks_pos = [use_str[0]+width*float(k) for k in range(len(point))]
                rot = 25
        else:
            x_ticks_pos = use_str+width*float(last_idx)/2
            rot = 0

        logger.debug('x_ticks_pos={}, xticks={}', x_ticks_pos, xticks)
        plt.xticks(x_ticks_pos, xticks, rotation=rot)

    # display yticks based on y_ticks_map in util
    y_ticks_dict = y_ticks_map.get(dir_name, default_map)
    yt = y_ticks_dict.get(metric, default_map)
    if y_ticks_dict != default_map and \
        yt != default_map:
        ax.set_yticks(yt)
        ax.set_yticklabels(yt)

    ax.set_ylabel(axis_map[metric])

    if has_legend:
        ax.legend(
            red_hatch, lab_hatch_arr,
            bbox_to_anchor=(0,1.02,1,0.1), loc="lower left",
            mode="expand", borderaxespad=0, ncol=3)
    else: logger.debug(f'line966: no legend for {dir_name} {metric}')

    # ---------save----------
    os.makedirs('plot/mlperf', exist_ok=True)
    metric_name = metric.replace('/', '_').replace(' ', '_')

    if dir_name is None:
        fig_name = f'plot/mlperf/{metric_name}'
    else:
        os.makedirs(f'plot/mlperf/{dir_name}', exist_ok=True)
        fig_name = f'plot/mlperf/{dir_name}/{metric_name}'
    if single_n_slice:
        fig_name += f'_n{nval}'
    fig_name += '_uni.pdf' if universal_baseline else '.pdf'

    fig.tight_layout()
    plt.savefig(fig_name, bbox_inches='tight', dpi=my_dpi, pad_inches=0.02)

def _split_point(point):
    """Split a run point such as ``'some/dir/arch'`` into ``('some/dir/', 'arch')``.

    The architecture name is the text after the last ``'/'``; the directory
    part is everything before it (trailing slash included, possibly empty).
    """
    arch_name = point.split('/')[-1]
    # Slice the suffix off instead of using str.replace(): replace() would
    # also remove occurrences of the architecture name inside the directory
    # part (e.g. 'foo/foo' would become '/').
    dir_append = point[:len(point) - len(arch_name)]
    return dir_append, arch_name


def _list_run_subdirs(rootdir):
    """Return the names of the immediate sub-directories of ``rootdir``.

    Raises:
        FileNotFoundError: if ``rootdir`` cannot be walked (os.walk yields
            nothing for a missing directory, which would otherwise surface
            as a bare StopIteration).
    """
    try:
        return next(os.walk(rootdir))[1]
    except StopIteration:
        raise FileNotFoundError(f'run directory not found: {rootdir}') from None


def _collect_runs(point, nets, n):
    """Find the run directory of every network in ``nets`` for one point.

    A directory under ``runs/<dir part of point>/`` is accepted when its name
    is ``<arch>_<net>_c<1-3 digits>_n<n>`` (anchored at the end so that
    ``_n1`` does not also match ``_n16``). If several directories match the
    same network, the last one listed wins.

    Args:
        point: run point, optionally prefixed by a directory ('dir/arch').
        nets: iterable of network names (inserted into the regex unescaped).
        n: value expected after ``_n`` in the directory name.

    Returns:
        ``(arch_name, group)`` where ``group`` maps each matched network to
        an OrderedDict holding its ``'dir'`` entry.
    """
    dir_append, arch_name = _split_point(point)
    # The directory listing does not depend on the network: read it once.
    subdirs = _list_run_subdirs(f'runs/{dir_append}/') if nets else []
    suffix = re.escape(f'_n{n}') + r'$'
    group = OrderedDict()
    for net in nets:
        reg_name = re.compile(re.escape(arch_name) + '_' + net + r'_c\d{1,3}' + suffix)
        for dirname in subdirs:
            if reg_name.match(dirname):
                run = OrderedDict()
                run['dir'] = f'{dir_append}/{dirname}'
                group[net] = run
    return arch_name, group


if __name__ == "__main__":
    # Script entry point: parse the CLI, discover the run directories for the
    # requested points, read their results and emit one plot per metric.
    parser = construct_argparser()
    args = parser.parse_args()

    # --- argument validation (explicit errors; asserts vanish under -O) ---
    if ('area_partition' in args.metric or 'area_breakdown' in args.metric) \
            and not args.single_n_slice:
        parser.error('area partition requires single_n_slice flag.')
    if args.single_n_slice:
        if args.nval is None:
            parser.error('single_n_slice requires nval to be set.')
        args.nval = int(args.nval)
        logger.info(f'*** single n slice at: {args.nval}')

    # loguru formats with str.format; passing the value as an extra positional
    # argument without a '{}' placeholder would silently drop it.
    logger.info(f'*** mlperf: {args.mlperf}')
    logger.info(f'*** dir: {args.dirname}')
    os.makedirs('plot/log', exist_ok=True)
    timestr = time.strftime("%Y%m%d-%H%M%S")
    if args.dirname is not None:
        os.makedirs(f'plot/log/{args.dirname}', exist_ok=True)
        yml_name = f'./plot/log/{args.dirname}/plot_{timestr}.yml'
    else:
        yml_name = f'./plot/log/plot_{timestr}.yml'

    # === peak perf functions ===
    # NOTE(review): these are defined at module level under the main guard;
    # presumably other functions in this file look them up as globals, so
    # their names and parameter names are kept unchanged -- confirm.
    def get_sys_peak_flops_per_sec(arch_dict, cycle):
        """Return num_instances[0] * num_instances[1] * 2 / cycle * frequency * 1e6."""
        compu = arch_dict['architecture']['compu']['num_instances']
        return (compu[0] * compu[1] * 2.0) / cycle * \
            arch_dict['architecture']['frequency'] * 10**6
    # Disabled variant kept for reference:
    # get_sys_peak_flops_per_sec = lambda arch_dict, cycle:(arch_dict['architecture']['compu']['num_instances'][0] * \
    #     arch_dict['architecture']['compu']['num_instances'][1] * 2.0 + arch_dict['architecture']['compu']['num_instances'][1]) / cycle * \
    #     arch_dict['architecture']['frequency'] * 10**6

    def get_tlut_peak_flops_per_sec(arch_dict, chunk_size, data_length):
        """Return 2 * num_instances[0] * (chunk_size / data_length) * frequency * 1e6."""
        return 2 * \
            arch_dict['architecture']['compu']['num_instances'][0] * \
            (float(chunk_size) / data_length) * arch_dict['architecture']['frequency'] * 10**6

    def get_cris_scaled_similarity(wl_dict):
        """Return (1 + (N - 1) * (1 - similarity)) / N for the first workload entry."""
        workload = wl_dict['workload']
        first = workload[list(workload.keys())[0]]
        return (1 + (first['N'] - 1) * (1 - first['similarity'])) / first['N']

    def get_cris_peak_flops_per_sec(arch_dict, wl_dict, cycle):
        """Return the same product as the sys peak, divided by the scaled similarity."""
        compu = arch_dict['architecture']['compu']['num_instances']
        return (compu[0] * compu[1] * 2.0) / cycle * \
            arch_dict['architecture']['frequency'] * 10**6 / get_cris_scaled_similarity(wl_dict)
    # === peak perf functions ===

    if args.file is None and not args.point:
        parser.error('no point specified')
    if args.file is not None:  # override args.point
        # One run name per line; blank lines are skipped because an empty
        # point would match every run directory.
        with open(args.file) as point_file:
            args.point = [name for name in (line.rstrip() for line in point_file) if name]
        if not args.point:
            parser.error('no point specified')
    logger.info(f'*** points: {args.point}')

    if args.use_universal_baseline:
        logger.info('*** using universal baseline (default)')
    else:
        logger.info('*** normalizing in terms of group')

    # === construct run set ===
    # parse 1: get n
    n_set = set()
    if args.single_n_slice:
        n_max = args.nval
        n_set.add(args.nval)
    else:
        n_max = 0
        n_in_name = re.compile(r'_n(\d+)')
        for point in args.point:
            dir_append, prefix = _split_point(point)
            reg_name = re.compile(re.escape(prefix) + r'\S*_c\d{1,3}_n\d{1,4}')
            for dirname in _list_run_subdirs(f'runs/{dir_append}/'):
                if reg_name.match(dirname):
                    # reg_name already requires an '_n<digits>' part, so this
                    # search cannot fail.
                    n = int(n_in_name.search(dirname).group(1))
                    n_max = max(n_max, n)
                    if n in args.nvalid:
                        n_set.add(n)
    n_list = sorted(n_set)
    logger.info(f'*** n: {n_list}')

    # parse 2: get group
    if args.group_by == 'n':
        group_list = n_list
    elif args.group_by == 'format':
        # First matching token decides the group; order matters ('acc16' must
        # be tested before 'bf16').
        format_tokens = (
            ('int8', 'int8'),
            ('int16', 'int16'),
            ('acc16', 'bf16_acc16'),
            ('bf16', 'bf16'),
        )
        group_set = set()
        for point in args.point:
            for token, label in format_tokens:
                if token in point:
                    group_set.add(label)
                    break
        group_list = sorted(group_set, reverse=True)
        logger.info(f'*** group by {group_list}')
    elif args.group_by == 'design':
        group_list = args.point
    else:
        raise ValueError(f'unsupported group_by: {args.group_by!r}')

    # run_set layout: group -> architecture name -> network -> {'dir': ...}
    run_set = OrderedDict()
    if args.group_by == 'n':
        # index 1: n, index 2: arch, index 3: network
        for n in group_list:
            n_dict = OrderedDict()
            for point in args.point:
                arch_name, group = _collect_runs(point, args.mlperf, n)
                n_dict[arch_name] = group
            run_set[n] = n_dict

    elif args.group_by == 'format':
        # index 1: format, index 2: arch, index 3: network
        for format_ in group_list:
            format_chunks = ['_' + c for c in format_.split('_')]
            format_dict = OrderedDict()
            for point in args.point:
                # every '_<chunk>' of the format must appear in the point
                if any(c not in point for c in format_chunks):
                    continue
                if format_ == 'bf16' and '_acc16' in point:
                    continue
                arch_name, group = _collect_runs(point, args.mlperf, args.nval)
                format_dict[arch_name] = group
            run_set[format_] = format_dict

    else:  # args.group_by == 'design' (validated above)
        # index 1: point, index 2: arch, index 3: network
        for point in group_list:
            dir_append, prefix = _split_point(point)
            logger.debug(prefix)
            logger.debug(f'runs/{dir_append}/')
            point_dict = OrderedDict()
            for n in n_list:
                # NOTE(review): the key does not depend on n, so each n
                # overwrites the previous one and only the last n in n_list
                # survives -- behaviour kept as in the original; confirm intent.
                arch_name, group = _collect_runs(point, args.mlperf, n)
                point_dict[arch_name] = group
            run_set[point] = point_dict

    # === register needed results ===
    perf_read, cost_read, other, cost_layer_read = gen_read_sets(args.metric)

    # === fill in perf results ===
    ref_arch_name = args.point[0].split('/')[-1]
    run_set = read_perf_cost(run_set, perf_read, cost_read, other, cost_layer_read, ref_arch_name, args.use_universal_baseline)

    if args.overlay:
        # rescale metric/norm
        run_set = rescale(run_set)

    yaml_overwrite(yml_name, run_set)  # TODO: remove for speed and space

    # === plot ===
    use_line = args.line
    for metric in args.metric:
        plot_kind = single_multiple_map[metric]
        if plot_kind == 'single':
            # Arguments shared by all three single-metric plotters; the label
            # list is rebuilt per call so plotters never share one list.
            shared = dict(
                point=args.point,
                metric=metric,
                run_set=run_set,
                n_list=n_list,
                dir_name=args.dirname,
                single_n_slice=args.single_n_slice,
                mlperf_arr=([mlperf_name_map[name] for name in args.mlperf] + ['geomean']),
                nval=args.nval,
                universal_baseline=args.use_universal_baseline,
            )
            if use_line:
                if args.per_model:
                    logger.info('using plot_all_n_line per model')
                    plot_all_n_line_permodel(overlay=args.overlay, **shared)
                else:
                    logger.info('using plot_all_n_line!')
                    plot_all_n_line(overlay=args.overlay, **shared)
            else:
                logger.info('using plot_all_n!')
                plot_all_n(**shared)
        elif plot_kind == 'multiple_single_n':
            logger.info('using plot_stacked_bar!')
            if args.group_by == 'format':
                raise NotImplementedError(
                    'plot_stacked_bar does not support group_by=format')
            plot_stacked_bar(
                point=args.point,
                metric=metric,
                run_set=run_set,
                dir_name=args.dirname,
                nval=args.nval
            )
        else:
            logger.info('using plot_all_n_stacked_bar!')
            plot_all_n_stacked_bar(
                point=args.point,
                metric=metric,
                run_set=run_set,
                n_list=group_list,
                dir_name=args.dirname,
                single_n_slice=args.single_n_slice,
                mlperf_arr=([mlperf_name_map[name] for name in args.mlperf] + ['geomean']),
                nval=args.nval,
                universal_baseline=args.use_universal_baseline
            )