from collections import OrderedDict
import sys, os
sys.path.append(os.getcwd())
from utils import yaml_load, bcolors,yaml_overwrite
import argparse
import numpy as np
import matplotlib.pyplot as plt
import matplotlib
from matplotlib.pyplot import cm
import os
import re
from scipy.stats import gmean
import math
import time
from loguru import logger
from plot_util import tlut_unique_component,cris_unique_component, cris_no_component, hatch_map, ls_arr, marker_arr, \
    sys_breakdown_map_sram_and_rest, carat_breakdown_map_sram_and_rest, usys_breakdown_map_sram_and_rest, cris_breakdown_map_sram_and_rest, \
    sys_breakdown_map, carat_breakdown_map, usys_breakdown_map, cris_breakdown_map, \
    hatch_map_cg, size_map, legend_vec_map, \
    color_map, opac_map, \
    yml_metric_map, reduce_map, single_multiple_map, \
    partition_layerwise_reduction_map, axis_map, \
    mlperf_name_map, legend_map, clip_map, default_map, \
    ignore_vec_map, default_ignore_vec, default_ignore_vec_area,default_for_opac, \
    my_dpi, fig_h, fig_h_short, fig_w, size_tuple, \
    y_ticks_map
# import plotly.express as px
# import plotly.io as pio
# pio.renderers.default = "notebook_connected"
# import plotly.graph_objects as go
# import plotly.figure_factory as ff

# Mute loguru records that originate from the script's own ``__main__`` module.
logger.disable("__main__")

# Key under which the peak throughput of an architecture is requested/stored.
PEAK = 'peak flops per sec'

# matplotlib settings: small fonts and hairline strokes for compact figures.
font = dict(serif='Helvetica Neue', size=6)
matplotlib.rc('font', **font)
matplotlib.rcParams.update({
    # stroke widths
    'hatch.linewidth': 0.3,
    'lines.linewidth': 0.3,
    'axes.linewidth': 0.3,
    'xtick.major.width': 0.3,
    'ytick.major.width': 0.3,
    # tick lengths
    'xtick.major.size': 2,
    'ytick.major.size': 2,
    # text sizes
    'axes.titlesize': 5,
    'xtick.labelsize': 5,
    'ytick.labelsize': 5,
    'legend.fontsize': 5,
    # legend layout ('legend.handlelength' is intentionally left at its default)
    'legend.labelspacing': 0.25,
})

def construct_argparser():
    """Build the command-line parser for the plotting script.

    Returns:
        argparse.ArgumentParser: parser exposing the run selection
        (``--point``, ``--dirname``, ``--mlperf``, ``--nval``), the metrics to
        plot (``--metric``), the query options (``--query``, ``--qm``) and the
        plot-style switches (``--line``, ``--pie``, ``--overlay``,
        ``--norescale``, ``--carat_only``, ``--postfix``,
        ``--[no-]use_universal_baseline``).
    """
    parser = argparse.ArgumentParser(description='plot')

    parser.add_argument('-p', '--point', nargs='+', help='config')

    # Paired on/off switches writing to the same destination; off by default.
    parser.add_argument('--use_universal_baseline',
                        dest='use_universal_baseline',
                        action='store_true')
    parser.add_argument('--no-use_universal_baseline',
                        dest='use_universal_baseline',
                        action='store_false')
    parser.set_defaults(use_universal_baseline=False)

    parser.add_argument('-d', '--dirname', help='dir name', default=None)

    parser.add_argument('--line',
                        help='draw line plot',
                        action='store_true',
                        default=False)

    # NOTE: 'UNet' and 'bert_base_uncased' are two separate networks. A missing
    # comma used to fuse them into the single bogus entry
    # 'UNetbert_base_uncased' through implicit string concatenation.
    default_networks = [
        'resnet50',
        'ssd300_vgg16',
        'UNet',
        'bert_base_uncased',
        'dlrm',
        'RNNT',
    ]
    parser.add_argument('-n', '--mlperf',
                        nargs='+',
                        help='networks in mlperf\n',
                        default=default_networks)

    parser.add_argument('--nval',
                        help='plot n=nval only',
                        choices=[
                            '1', '2', '4', '8', '16',
                            '32', '64', '128', '256', '512',
                            '1024', '2048', '4096',
                        ],
                        default='256')

    metric_choices = [
        'perf', 'perf_norm', 'util_impl',
        'rt_abs', 'rt_norm',
        'area_total', 'area_onchip', 'area_dram', 'area_sram', 'area_compute',
        'area_partition', 'area_breakdown', 'energy_partition', 'energy_breakdown',
        'area_breakdown_sram_and_rest', 'energy_breakdown_sram_and_rest',
        'power_breakdown', 'power_breakdown_nosram', 'power_breakdown_nosramfifo',
        'area_breakdown_nosram', 'area_breakdown_nosramfifo',
        'energy_breakdown_nosram', 'energy_breakdown_nosramfifo',
        'area_breakdown_array_level', 'energy_breakdown_array_level',
        'power_total', 'power_onchip', 'power_dram', 'power_sram', 'power_compute',
        'energy_total', 'energy_onchip', 'energy_dram', 'energy_sram', 'energy_compute',
        'throughput/area', 'throughput/energy', 'throughput/power',
        'flops/area', 'flops/energy', 'flops/power',
    ]
    parser.add_argument('-m', '--metric',
                        nargs='+',
                        help='perf to plot',
                        choices=metric_choices)

    parser.add_argument('-q', '--query', nargs='+', help='query name\n')
    parser.add_argument('-qm', '--qm',
                        nargs='+',
                        help='query metric\n',
                        default=['flops per sec', 'throughput/energy'])

    parser.add_argument('--overlay',
                        help='overlay benchmark on plot',
                        action='store_true',
                        default=False)

    parser.add_argument('--pie',
                        help='draw pie plot',
                        action='store_true',
                        default=False)

    parser.add_argument('--norescale',
                        help='do not rescale',
                        action='store_true',
                        default=False)

    parser.add_argument('--carat_only',
                        help='group by carat only',
                        action='store_true',
                        default=False)

    parser.add_argument('--postfix', required=False)
    return parser

def trim_item(item, ignore, trim_only=False):
    """Shorten a run/config path into a compact label.

    Only the last ``/``-separated component of ``item`` is used. Deployment
    tags (``_edge``, ``_s4``, ``_dram*``, ``_sram1``) are always removed; the
    remaining tags are removed or pretty-printed depending on ``ignore``.

    Args:
        item: run name or path, e.g. ``'some/dir/carat_fp8_t2_1024'``.
        ignore: sequence of at least five truthy/falsy flags:
            0: drop the ``_inter`` / ``_intra`` tag,
            1: drop the number-format tag (``_int8``, ``_fp8``, ...),
            2: drop the chunk tag (``_t1`` .. ``_t8``),
            3: drop the systolic-array dimension tag (``_16x16``, ...),
            4: drop the tlut dimension tag (``_32`` .. ``_4096``).
        trim_only: when true, only strip tags; do not rename architectures or
            pretty-print the tags that are kept (chunk tags are still renamed
            from ``_tN`` to ``_CN``).

    Returns:
        str: the trimmed label.
    """
    def drop(text, tokens):
        # Remove every token, in the given order.
        for token in tokens:
            text = text.replace(token, '')
        return text

    def swap(text, pairs):
        # Apply (old, new) substitutions, in the given order.
        for old, new in pairs:
            text = text.replace(old, new)
        return text

    pretty = not trim_only
    label = item.split('/')[-1]

    # Longer dram tags must go first: stripping '_dram2' before '_dram256' or
    # '_dram25' would leave a stray '56' / '5' in the label.
    label = drop(label, ('_edge', '_s4',
                         '_dram256', '_dram128', '_dram51', '_dram25',
                         '_dram2', '_dram3', '_dram4',
                         '_sram1'))

    if pretty:
        # 'uSystolic' has a capital S, so the 'systolic' rule does not touch it.
        label = swap(label, (('systolic', 'bSA'),
                             ('carat', 'Carat'),
                             ('uSystolic', 'uSA'),
                             ('cris', 'RIS')))

    # 0: inter / intra
    if ignore[0]:
        label = drop(label, ('_inter', '_intra'))
    elif pretty:
        label = swap(label, (('_inter', '-r'), ('_intra', '-a')))

    # 1: number format
    if ignore[1]:
        label = drop(label, ('_int8', '_int4', '_bf16', '_fp8',
                             '_acc16', '_acc24', '_int16'))
    elif pretty and ignore[4]:
        label = swap(label, (('_int8', '$_{int8}$'),
                             ('_bf16', '$_{bf16}$'),
                             # raw string: '\_' is an escaped underscore for
                             # mathtext, not a Python escape sequence
                             ('_acc16', r'$_{\_acc16}$'),
                             ('_int16', '$_{int16}$'),
                             ('_fp8', '$_{fp8}$'),
                             ('_int4', '$_{int4}$')))

    # 2: chunk size (the on-disk tag is 't'; labels use 'C' for chunk)
    if ignore[2]:
        label = drop(label, ('_t1', '_t2', '_t4', '_t8'))
    elif pretty and ignore[4]:
        label = swap(label, (('_t1', '$_{C1}$'),
                             ('_t2', '$_{C2}$'),
                             ('_t4', '$_{C4}$'),
                             ('_t8', '$_{C8}$')))
    else:
        label = swap(label, (('_t1', '_C1'),
                             ('_t2', '_C2'),
                             ('_t4', '_C4'),
                             ('_t8', '_C8')))

    # 3: systolic array dimension
    if ignore[3]:
        label = drop(label, ('_8x8', '_4x4', '_32x16', '_64x16', '_16x16',
                             '_32x32', '_128x16', '_256x16', '_64x32',
                             '_64x64'))
    elif pretty:
        label = swap(label, (('_8x8', ' (8)'),
                             ('_4x4', ' (4)'),
                             ('_16x16', ' (16)'),
                             ('_32x32', ' (32)'),
                             ('_64x64', ' (64)')))

    # 4: tlut dimension
    sizes = ('32', '64', '128', '256', '512', '1024', '2048', '4096')
    if ignore[4]:
        label = drop(label, tuple('_' + size for size in sizes))
    elif pretty:
        # every size renders as ' (N)'; the 1024 case used to emit ' (1024))'
        label = swap(label, tuple(('_' + size, ' (' + size + ')')
                                  for size in sizes))

    return label

def gen_read_sets(data):
    """Translate requested plot metrics into the result keys that must be read.

    Args:
        data: iterable of metric names (the ``--metric`` choices). Names that
            have no entry in the table below are ignored.

    Returns:
        tuple: ``(perf_read, cost_read, other, cost_layer_read)``, four sets
        of ``str``:
            perf_read: keys to read from the performance summary,
            cost_read: keys to read from the cost summary,
            other: metrics that are derived from the values read above,
            cost_layer_read: per-component keys of the form
                ``<component>/area``, ``<component>/energy/total`` or
                ``<component>/power/total``.

    Raises:
        SystemExit: when none of the metrics requires anything to be read.
    """
    # loguru fills '{}' placeholders from the positional arguments; without a
    # placeholder the metric list never made it into the message.
    logger.debug('*** metric: {}', data)

    all_parts = ('compu', 'ififo', 'wfifo', 'ofifo', 'oaccu',
                 'isram', 'wsram', 'osram',
                 'waccu', 'itemp', 'osmux', 'bpipe')
    nosram_parts = tuple(p for p in all_parts if 'sram' not in p)
    nosramfifo_parts = tuple(p for p in nosram_parts if 'fifo' not in p)

    def per_part(parts, suffix):
        # e.g. ('compu', 'oaccu'), 'area' -> ('compu/area', 'oaccu/area')
        return tuple(f'{part}/{suffix}' for part in parts)

    area_all = ('area/onchip', 'area/sram', 'area/dram', 'area/compute')
    area_on = ('area/onchip',)
    energy_on = ('energy/onchip/total',)
    power_on = ('power/onchip/total',)
    throughput = ('flops per sec',)
    flops = ('flops',)

    # metric -> (perf keys, cost keys, per-layer cost keys, is a derived metric)
    plan = {
        'perf': (throughput, (), (), False),
        'perf_norm': (throughput + (PEAK,), (), (), True),
        'util_impl': (('utilization/impl',), (), (), False),
        'rt_abs': (('runtime/impl',), (), (), False),
        'rt_norm': (('runtime/impl',), (), (), False),

        'area_total': ((), ('area/onchip', 'area/dram'), (), True),
        'area_onchip': ((), ('area/onchip',), (), False),
        'area_sram': ((), ('area/sram',), (), False),
        'area_dram': ((), ('area/dram',), (), False),
        'area_compute': ((), ('area/compute',), (), False),

        'flops/energy': (flops, energy_on, (), True),
        'flops/area': (flops, area_all, (), True),
        'flops/power': (flops, power_on, (), True),

        'throughput/energy': (throughput, energy_on + ('energy/sram/total',), (), True),
        'throughput/area': (throughput, area_all, (), True),
        'throughput/power': (throughput, power_on, (), True),

        'area_breakdown_nosram': ((), area_on, per_part(nosram_parts, 'area'), False),
        'area_breakdown_array_level': ((), area_on, per_part(nosram_parts, 'area'), False),
        'area_breakdown_nosramfifo': ((), area_on, per_part(nosramfifo_parts, 'area'), False),

        'energy_breakdown_nosram': ((), energy_on, per_part(nosram_parts, 'energy/total'), False),
        'energy_breakdown_array_level': ((), energy_on, per_part(nosram_parts, 'energy/total'), False),
        'energy_breakdown_nosramfifo': ((), energy_on, per_part(nosramfifo_parts, 'energy/total'), False),

        'power_breakdown': ((), power_on, per_part(all_parts, 'power/total'), False),
    }
    # power_* / energy_* totals all follow '<kind>/<scope>/total'
    for kind in ('power', 'energy'):
        for scope in ('total', 'onchip', 'sram', 'dram', 'compute'):
            plan[f'{kind}_{scope}'] = ((), (f'{kind}/{scope}/total',), (), False)
    # full breakdowns (sram included)
    for metric in ('area_partition', 'area_breakdown', 'area_breakdown_sram_and_rest'):
        plan[metric] = ((), ('area/onchip', 'area/sram'), per_part(all_parts, 'area'), False)
    for metric in ('energy_partition', 'energy_breakdown', 'energy_breakdown_sram_and_rest'):
        plan[metric] = ((), energy_on, per_part(all_parts, 'energy/total'), False)

    perf_read = set()
    cost_read = set()
    other = set()
    cost_layer_read = set()

    for metric in data:
        entry = plan.get(metric)
        if entry is None:
            # nothing is known about this metric; nothing to read for it
            continue
        perf_keys, cost_keys, layer_keys, derived = entry
        perf_read.update(perf_keys)
        cost_read.update(cost_keys)
        cost_layer_read.update(layer_keys)
        if derived:
            other.add(metric)

    # Nothing to read means nothing to plot. Use sys.exit(): the bare exit()
    # builtin is injected by the site module and meant for interactive use.
    if not (perf_read or cost_read or other or cost_layer_read):
        sys.exit()
    return perf_read, cost_read, other, cost_layer_read

def read_perf_cost(run_set, perf_read, cost_read, other, cost_layer_read, ref_arch_name, use_universal_baseline, 
                   sys_bmap, usys_bmap, carat_bmap, cris_bmap):
    """Load performance/cost results for every run and store them in ``run_set``.

    ``run_set`` is walked as ``{group: {arch_name: {net_name: net_dict}}}`` and
    each ``net_dict`` must carry a ``'dir'`` entry naming its run directory
    under ``runs/``. Every value read or derived is written back into that
    ``net_dict`` (the structure is mutated in place).

    Args:
        run_set: nested mapping described above.
        perf_read: slash-separated keys looked up under ``'overall'`` of
            ``output/performance/workloadperf.summary.yaml``. The special key
            ``PEAK`` is computed from the run's input YAML files instead.
        cost_read: slash-separated keys looked up under ``'overall'`` of
            ``output/cost/workloadcost.summary.yaml``.
        other: derived metrics computed from the values above
            (``perf_norm``, ``area_total``, ``throughput/*``, ``flops/*``).
        cost_layer_read: per-component keys ``<partition>/area``,
            ``<partition>/energy/total`` or ``<partition>/power/total`` read
            from ``output/cost/workloadcost.yaml``.
        ref_arch_name: not referenced anywhere in this function.
        use_universal_baseline: if true, the first architecture of the first
            group is the normalisation baseline for all groups; otherwise the
            first architecture of each group is the baseline of that group.
        sys_bmap, usys_bmap, carat_bmap, cris_bmap: partition -> coarse-grained
            breakdown name, selected by which architecture name occurs in the
            run directory string.

    Returns:
        The same ``run_set`` object, with results added to every ``net_dict``:
        ``key`` and ``key + '/norm'`` for plain reads, plus ``'/percent'`` and
        ``'/norm_percent'`` entries for the per-component breakdowns.

    Raises:
        AssertionError: for a run directory whose architecture is not
            recognised, or a carat run that is not ``fp8``.
    """
    first_arch_in_all = True
    for group_dict in run_set.values():
        first_arch_in_group = True
        for arch_tuple in group_dict.items():
            # logger.debug(arch_tuple)
            arch = arch_tuple[1]
            # the very first architecture seen is the universal baseline
            if first_arch_in_all == True:
                first_arch_in_all = False
                ref_arch = arch
            # optional update if not universalbaseline
            if first_arch_in_group == True and use_universal_baseline == False: 
                first_arch_in_group = False
                ref_arch = arch
            
                
            for net_tuple in arch.items():
                net_name = net_tuple[0]
                net_dict = net_tuple[1]
                dir = net_dict['dir']
                logger.info(dir)
                if len(perf_read) != 0:
                    # process performance
                    perf_yml_file = f'runs/{dir}'+'/output/performance/workloadperf.summary.yaml'
                    perf_dict = yaml_load(perf_yml_file)
                    for perf in perf_read:
                        if perf == PEAK:
                            # peak throughput is not in the summary; it is
                            # derived from the architecture/workload inputs
                            arch_yml_file = f'runs/{dir}'+'/input/architecture.yaml'
                            arch_dict = yaml_load(arch_yml_file)
                            wl_yml_file = f'runs/{dir}'+'/input/workload.yaml'
                            wl_dict = yaml_load(wl_yml_file)['workload']
                            # take 'cycle' from the first layer only
                            for layer in wl_dict.values():
                                cycle = layer['cycle']
                                break
                            # 'ystolic' matches both 'systolic' and 'uSystolic'
                            if 'ystolic' in dir: peak = get_sys_peak_flops_per_sec(arch_dict, cycle)
                            elif 'carat' in dir and 'i1' in dir:
                                chunk_size = math.log2(arch_dict['architecture']['compu']['num_instances'][1]) + 1
                                # logger.debug(chunk_size)
                                if 'fp8' in dir:
                                    data_length = 4
                                else: assert False
                                peak = get_tlut_peak_flops_per_sec(arch_dict, chunk_size, data_length)
                            elif 'carat' in dir and 'idle8' in dir: 
                                # NOTE: no '+ 1' on chunk_size in this variant
                                chunk_size = math.log2(arch_dict['architecture']['compu']['num_instances'][1])
                                if 'fp8' in dir:
                                    data_length = 4
                                else: assert False
                                peak = get_tlut_peak_flops_per_sec(arch_dict, chunk_size, data_length)
                            elif 'carat' in dir: 
                                chunk_size = math.log2(arch_dict['architecture']['compu']['num_instances'][1])+1
                                if 'fp8' in dir:
                                    data_length = 4
                                else: assert False
                                peak = get_tlut_peak_flops_per_sec(arch_dict, chunk_size, data_length)
                            elif 'cris' in dir:
                                wl_dict_ = yaml_load(wl_yml_file)
                                peak = get_cris_peak_flops_per_sec(arch_dict, wl_dict_, cycle)
                            else: assert False, f'{dir}'
                            net_dict[perf] = peak
                            # no '/norm' entry is stored for the peak
                            continue
                        # walk the slash-separated path below 'overall'
                        level = perf.split('/')
                        thing = perf_dict['overall']
                        for l in level:
                            thing = thing[l]
                        net_dict[perf] = thing
                        net_dict[perf + '/norm'] = thing / ref_arch[net_name][perf]
                
                if len(cost_read) != 0:
                    # process cost
                    cost_yml_file = f'runs/{dir}'+'/output/cost/workloadcost.summary.yaml'
                    cost_dict = yaml_load(cost_yml_file)
                    for cost in cost_read:
                        level = cost.split('/')
                        thing = cost_dict['overall']
                        for l in level:
                            thing = thing[l]
                        # for uSystolic the sram share is removed from the
                        # on-chip area/energy/power totals
                        if (cost == 'area/onchip') and 'uSystolic' in arch_tuple[0]:
                            logger.debug(bcolors.FAIL + "usys " + cost + bcolors.ENDC)
                            thing -= cost_dict['overall']['area']['sram']
                        if (cost == 'energy/onchip/total') and 'uSystolic' in arch_tuple[0]:
                            logger.debug(bcolors.FAIL + "usys " + cost + bcolors.ENDC)
                            thing -= cost_dict['overall']['energy']['sram']['total']
                        if (cost == 'power/onchip/total') and 'uSystolic' in arch_tuple[0]:
                            logger.debug(bcolors.FAIL + "usys " + cost + bcolors.ENDC)
                            thing -= cost_dict['overall']['power']['sram']['total']
                        net_dict[cost] = thing
                        net_dict[cost + '/norm'] = thing / ref_arch[net_name][cost]
                
                if len(cost_layer_read) != 0:
                    # process cost partition
                    cost_layer_yml_file = f'runs/{dir}'+'/output/cost/workloadcost.yaml'
                    cost_layer_dict = yaml_load(cost_layer_yml_file)
                    # leave dict_ bound to the first entry that is not
                    # technology/frequency/overall (used by the 'area' branch)
                    for dict_ in cost_layer_dict.items():
                        if dict_[0] in ['technology', 'frequency', 'overall']:
                            continue
                        else: break

                    # init breakdown record
                    for cost in cost_layer_read:
                        area_or_energy = cost.split('/')[1]
                        for breakdown in hatch_map_cg.keys():
                            cost_bd = breakdown + f'/{area_or_energy}'
                            if area_or_energy == 'energy':
                                cost_bd += '/total'
                            if area_or_energy == 'power':
                                cost_bd += '/total'
                            net_dict[cost_bd] = 0

                    for cost in cost_layer_read:
                        if 'uSystolic' in dir and 'sram' in cost: 
                            # logger.warning('uSystolic sram cost is not included in the total cost')
                            continue
                        

                        level = cost.split('/')
                        partition = level[0]
                        # logger.debug(partition)
                        # logger.debug(dir)
                        area_or_energy = level[1]
                        
                        # skip components that do not exist in this architecture
                        if partition in tlut_unique_component and 'carat' not in dir:
                            logger.debug(f'{partition} not in {dir}')
                            continue
                        if partition in cris_unique_component and 'cris' not in dir:
                            logger.debug(f'{partition} not in {dir}')
                            continue
                        if partition in cris_no_component and 'cris' in dir:
                            logger.debug(f'{partition} not in {dir}')
                            continue

                        if area_or_energy == 'area':
                            # area is taken from a single entry (dict_), not summed.
                            # NOTE(review): the energy/power branches below rebind
                            # dict_ in their own loops; since cost_layer_read is a
                            # set, an area key processed after an energy/power key
                            # would read whichever entry that loop ended on.
                            thing = dict_[1]
                            # logger.info(thing)
                            total = cost_layer_dict['overall'][area_or_energy]['onchip']
                            if 'uSystolic' in dir: 
                                total -= cost_layer_dict['overall'][area_or_energy]['sram']
                            for l in level:
                                thing = thing[l]
                            net_dict[cost] = thing
                            net_dict[cost + '/percent'] = thing / total
                            # requires 'area/onchip' to have been read via cost_read
                            net_dict[cost + '/norm_percent'] = net_dict[cost + '/percent'] * net_dict[f'{area_or_energy}/onchip/norm']
                            if 'uSystolic' in dir: 
                                logger.debug(cost,thing,total,net_dict[f'{area_or_energy}/onchip/norm'],net_dict[cost + '/norm_percent'])
                            # map the partition to its coarse-grained group
                            if 'systolic' in dir: 
                                cost_cg = sys_bmap[partition] + f'/{area_or_energy}'
                            elif 'uSystolic' in dir: 
                                if usys_bmap[partition] != '':
                                    cost_cg = usys_bmap[partition] + f'/{area_or_energy}'
                                else: assert False
                            elif 'carat' in dir: 
                                cost_cg = carat_bmap[partition] + f'/{area_or_energy}'
                            elif 'cris' in dir:
                                cost_cg = cris_bmap[partition] + f'/{area_or_energy}'
                            else: assert False, logger.debug(dir)
                            # logger.debug(cost,cost_cg);exit()
                            if cost_cg not in net_dict.keys(): 
                                net_dict[cost_cg] = 0
                                logger.warning(f'cost_cg {cost_cg} not in net_dict.keys()')
                            # accumulate into the group and refresh its shares
                            net_dict[cost_cg] += thing
                            net_dict[cost_cg + '/percent'] = net_dict[cost_cg] / total
                            net_dict[cost_cg + '/norm_percent'] = net_dict[cost_cg + '/percent'] * net_dict[f'{area_or_energy}/onchip/norm']
                            # if 'uSystolic' in dir: logger.debug(net_dict);exit()
                        
                        elif area_or_energy == 'energy':
                            # energy is summed over every entry except
                            # technology/frequency/overall
                            net_dict[cost] = 0
                            for dict_ in cost_layer_dict.items():
                                if dict_[0] in ['technology', 'frequency', 'overall']:
                                    continue
                                else: 
                                    thing = dict_[1]
                                    for l in level:
                                        thing = thing[l]
                                    net_dict[cost] += thing
                            # requires 'energy/onchip/total' to have been read via cost_read
                            tot = net_dict[f'{area_or_energy}/onchip/total']
                            net_dict[cost + '/percent'] = net_dict[cost] / tot
                            net_dict[cost + '/norm_percent'] = net_dict[cost + '/percent'] * net_dict[f'{area_or_energy}/onchip/total/norm']
                            
                            if 'systolic' in dir: 
                                cost_cg = sys_bmap[partition] + f'/{area_or_energy}'
                            elif 'uSystolic' in dir: 
                                if usys_bmap[partition] != '':
                                    cost_cg = usys_bmap[partition] + f'/{area_or_energy}'
                                else:
                                    # partition has no group for uSystolic: skip it
                                    logger.debug('gothere!!!')
                                    continue
                            elif 'carat' in dir: 
                                cost_cg = carat_bmap[partition] + f'/{area_or_energy}'
                            elif 'cris' in dir:
                                cost_cg = cris_bmap[partition] + f'/{area_or_energy}'
                            else: assert False, logger.debug(dir)

                            # always true inside this branch (area_or_energy == 'energy')
                            if area_or_energy == 'energy' or area_or_energy == 'power':
                                cost_cg += '/total'
                            logger.debug(cost_cg, cost)

                            if cost_cg not in net_dict.keys(): 
                                net_dict[cost_cg] = 0
                                logger.warning(f'cost_cg {cost_cg} not in net_dict.keys()')
                            # logger.debug(net_dict);exit()

                            net_dict[cost_cg] += net_dict[cost]
                            net_dict[cost_cg + '/percent'] = net_dict[cost_cg] / net_dict[f'{area_or_energy}/onchip/total']
                            net_dict[cost_cg + '/norm_percent'] = net_dict[cost_cg + '/percent'] * net_dict[f'{area_or_energy}/onchip/total/norm']
            
                        elif area_or_energy == 'power':
                            # same procedure as the energy branch, on power values
                            net_dict[cost] = 0
                            for dict_ in cost_layer_dict.items():
                                if dict_[0] in ['technology', 'frequency', 'overall']:
                                    continue
                                else: 
                                    thing = dict_[1]
                                    for l in level:
                                        thing = thing[l]
                                    net_dict[cost] += thing
                            tot = net_dict[f'{area_or_energy}/onchip/total']
                            net_dict[cost + '/percent'] = net_dict[cost] / tot
                            net_dict[cost + '/norm_percent'] = net_dict[cost + '/percent'] * net_dict[f'{area_or_energy}/onchip/total/norm']
                            
                            if 'systolic' in dir: 
                                cost_cg = sys_bmap[partition] + f'/{area_or_energy}'
                            elif 'uSystolic' in dir: 
                                if usys_bmap[partition] != '':
                                    cost_cg = usys_bmap[partition] + f'/{area_or_energy}'
                                else:
                                    logger.debug('gothere!!!')
                                    continue
                            elif 'carat' in dir: 
                                cost_cg = carat_bmap[partition] + f'/{area_or_energy}'
                            elif 'cris' in dir:
                                cost_cg = cris_bmap[partition] + f'/{area_or_energy}'
                            else: assert False, logger.debug(dir)

                            cost_cg += '/total'
                            logger.debug(cost_cg, cost)
                            # logger.debug(net_dict);exit()

                            # NOTE(review): unlike the area/energy branches there is
                            # no 'cost_cg not in net_dict' guard here; this relies on
                            # the init loop above having created the key.
                            net_dict[cost_cg] += net_dict[cost]
                            net_dict[cost_cg + '/percent'] = net_dict[cost_cg] / net_dict[f'{area_or_energy}/onchip/total']
                            net_dict[cost_cg + '/norm_percent'] = net_dict[cost_cg + '/percent'] * net_dict[f'{area_or_energy}/onchip/total/norm']
            
                if len(other) != 0:
                    # process other metrics
                    for m in other:
                        if m == 'perf_norm':
                            # achieved throughput as a percentage of the peak
                            net_dict[m] = net_dict['flops per sec'] / net_dict[PEAK] * 100

                        if m == 'area_total': 
                            net_dict[m] = net_dict['area/onchip'] + net_dict['area/dram']

                        # NOTE: efficiency is in terms of onchip cost
                        if m == 'throughput/area': 
                            net_dict[m] = net_dict['flops per sec'] / \
                                (net_dict['area/onchip'])
                        if m == 'throughput/power': 
                            net_dict[m] = net_dict['flops per sec'] / net_dict['power/onchip/total']
                        if m == 'throughput/energy': 
                            net_dict[m] = net_dict['flops per sec'] / net_dict['energy/onchip/total']
                        
                        if m == 'flops/area': 
                            net_dict[m] = net_dict['flops'] / \
                                (net_dict['area/onchip'])
                        if m == 'flops/power': 
                            net_dict[m] = net_dict['flops'] / net_dict['power/onchip/total']
                        if m == 'flops/energy': 
                            net_dict[m] = net_dict['flops'] / net_dict['energy/onchip/total']

                        # normalise against the same metric of the baseline arch
                        net_dict[m + '/norm'] = net_dict[m] / ref_arch[net_name][m]
                        
    return run_set

def plot_format_func(metric, compond_group_list, group2_list, run_set, dir_name,postfix):
    """Save the axis scaffold of a grouped figure for ``metric`` as a PDF.

    For every yml key that ``yml_metric_map`` associates with ``metric`` the
    per-architecture geometric mean over all networks is collected.  The bar
    drawing itself is currently disabled, so the saved figure only carries the
    x tick labels, the y label and (when enabled for ``dir_name``) an empty
    legend.

    Args:
        metric: key into ``yml_metric_map`` and ``axis_map``.
        compond_group_list: group names; each is looked up in ``run_set`` and
            also used (after shortening) as an x tick label.
        group2_list: unused; kept so existing callers keep working.
        run_set: nested mapping ``group -> arch -> net -> {yml key: value}``.
        dir_name: sub-directory below ``plot/mlperf``; ``None`` writes the
            file directly into ``plot/mlperf``.
        postfix: optional suffix appended to the file name when truthy.

    Raises:
        AssertionError: if ``metric`` is not a key of ``yml_metric_map``.
    """
    if metric not in yml_metric_map:
        # explicit raise keeps the exception type of the former
        # ``assert False`` but is not stripped under ``python -O``
        raise AssertionError(f'unknown metric: {metric}')
    metric_in_yml_list = yml_metric_map[metric]
    if not isinstance(metric_in_yml_list, list):
        metric_in_yml_list = [metric_in_yml_list]
    logger.debug(metric_in_yml_list)

    has_legend = metric in legend_map.get(dir_name, default_map)
    size_tuple_local = size_tuple if has_legend else (fig_w, fig_h_short)
    fig, ax = plt.subplots(figsize=size_tuple_local, dpi=my_dpi)
    use_str = np.arange(len(compond_group_list))

    for metric_in_yml in metric_in_yml_list:
        logger.debug(metric_in_yml)
        plot_arr = []
        name_arr = []
        for group_name in compond_group_list:
            group_dict = run_set.get(group_name)
            arch_arr = []
            name_arr_ = []
            for arch_name, arch_dict in group_dict.items():
                net_arr = [net_dict.get(metric_in_yml) for net_dict in arch_dict.values()]
                logger.debug(net_arr)
                # a metric that is missing for every network counts as 0
                if all(perf is None for perf in net_arr):
                    gm = 0
                else:
                    gm = gmean(net_arr)
                arch_arr.append(gm)
                name_arr_.append(arch_name)
            plot_arr.append(arch_arr)
            name_arr.append(name_arr_)
        width = 1.0/(len(plot_arr[0])+1)

        # loguru treats extra positional arguments as ``str.format`` inputs
        # for the message, so everything is rendered into one string here
        logger.debug(f'{plot_arr}, {name_arr}')

    # NOTE: no bars are drawn at the moment, so there are no legend handles;
    # ``width`` and ``plot_arr`` below come from the last metric processed.
    color_bar_arr = []
    lab_arch_arr = []
    size_vec = (0,1.02,1,0.1)

    if has_legend:
        logger.debug('legend')
        ax.legend(
            color_bar_arr, lab_arch_arr,
            bbox_to_anchor=size_vec, loc="lower left",
            mode="expand", borderaxespad=0, ncol=5)
    else:
        logger.info(f'no legend for {dir_name} {metric}')

    # shorten the group names for display; every replacement is applied to
    # the result of the previous one
    replacements = (
        ('carat', 'Carat'),
        ('cris', 'RIS'),
        ('systolic', 'SA'),
        ('_edge', ''),
        ('_128x16', ''),
        ('_256x16', ''),
        ('_int8', ''),
        ('_bf16', ''),
        ('_fp8', ''),
    )
    display_group_list = list(compond_group_list)
    for old, new in replacements:
        display_group_list = [c.replace(old, new) for c in display_group_list]
    tick = use_str+width*float(len(plot_arr[0]))/2
    ax.set_xticks(tick)
    ax.set_xticklabels(display_group_list, rotation=0)

    ylab = axis_map[metric]
    if '(flop/s/nJ)' in ylab:
        ylab = 'Norm. energy\nefficiency\n(flop/s/nJ)'
    ax.set_ylabel(ylab)

    os.makedirs('plot/mlperf', exist_ok=True)
    metric_name = metric.replace('/', '_').replace(' ', '_')
    if dir_name is None:
        fig_name = f'plot/mlperf/{metric_name}'
    else:
        os.makedirs(f'plot/mlperf/{dir_name}', exist_ok=True)
        fig_name = f'plot/mlperf/{dir_name}/{metric_name}'
    if postfix:
        fig_name += f'_{postfix}'
    fig_name += '.pdf'

    fig.tight_layout()
    plt.savefig(fig_name, bbox_inches='tight', dpi=my_dpi, pad_inches=0.02)
    logger.debug(bcolors.OKGREEN + f'Saved fig as {fig_name}' + bcolors.ENDC)

def plot_format_func_line(metric, compond_group_list, group2_list, run_set, dir_name, overlay,postfix):
    """Draw one line per group for ``metric`` and save the figure as a PDF.

    For every yml key mapped from ``metric`` and every group, one value per
    architecture entry is collected: the first network's value when the key
    contains ``'area'``, otherwise the geometric mean over all networks.
    Each group must provide exactly three values; non-finite entries
    (e.g. a missing value) are skipped when the line is drawn.  Colors and
    labels are assigned by group position from fixed three-entry lists.

    Args:
        metric: key into ``yml_metric_map`` and ``axis_map``.
        compond_group_list: group names, each looked up in ``run_set``.
        group2_list: unused; kept so existing callers keep working.
        run_set: nested mapping ``group -> arch -> net -> {yml key: value}``.
        dir_name: sub-directory below ``plot/mlperf`` (also selects legend,
            size and tick overrides); ``None`` writes into ``plot/mlperf``.
        overlay: must be falsy; the overlay mode is not implemented.
        postfix: optional suffix appended to the file name when truthy.

    Raises:
        AssertionError: unknown ``metric``, truthy ``overlay``, or a group
            that does not contain exactly three entries.
    """
    if metric not in yml_metric_map:
        raise AssertionError(f'unknown metric: {metric}')
    if overlay:
        raise AssertionError('overlay is not supported for line plots')
    metric_in_yml_list = yml_metric_map[metric]
    if not isinstance(metric_in_yml_list, list):
        metric_in_yml_list = [metric_in_yml_list]
    logger.debug(metric_in_yml_list)

    has_legend = metric in legend_map.get(dir_name, default_map)
    size_tuple_local = (3.3115, 1.1) if has_legend else (3.3115, 0.9)
    if size_map.get(dir_name, {}).get(metric):
        size_tuple_local = size_map[dir_name][metric]

    fig, ax = plt.subplots(figsize=size_tuple_local, dpi=my_dpi)

    num_datapoint = 3
    color_arr = [bcolors.gray, bcolors.cactus, bcolors.orange]
    label_arr = ['bSA', 'RIS', 'Carat']

    for metric_in_yml in metric_in_yml_list:
        logger.debug(metric_in_yml)
        is_area = 'area' in metric_in_yml
        plot_arr = []
        for group_name in compond_group_list:
            logger.debug(group_name)
            group_dict = run_set.get(group_name)
            arch_arr = []
            for arch_dict in group_dict.values():
                net_arr = []
                for net_name, net_dict in arch_dict.items():
                    perf = net_dict.get(metric_in_yml)
                    logger.debug(f' net name={net_name}, perf={perf}')
                    net_arr.append(perf)
                    # for area keys only the first network is consulted
                    if is_area:
                        break
                arch_arr.append(net_arr[0] if is_area else gmean(net_arr))
            plot_arr.append(arch_arr)

        for index, parr in enumerate(plot_arr):
            if len(parr) != num_datapoint:
                raise AssertionError(
                    f'group {index} has {len(parr)} datapoints, expected {num_datapoint}')
            logger.debug(f"[{index}] all {num_datapoint}: {parr}")

        for num, values in enumerate(plot_arr):
            # None becomes NaN in the float conversion and is masked out
            arr = np.array(values).astype(np.double)
            mask = np.isfinite(arr)
            logger.debug(arr)
            ax.plot(
                np.arange(len(arr))[mask], arr[mask],
                color=color_arr[num],
                alpha=1,
                marker=marker_arr[0],
                markersize=3,
                linewidth=1.5,
                linestyle=ls_arr[0],
                label=label_arr[num],
            )

    size_vec = (0,1.02,1,0.1)
    # plain string: a trailing comma here would turn it into a 1-tuple,
    # which is not a valid legend location
    loc = "lower left"

    if legend_vec_map.get(dir_name, {}).get(metric):
        size_vec = legend_vec_map[dir_name][metric][0]
        loc = legend_vec_map[dir_name][metric][1]
        logger.info('relocating legend wedget')

    if has_legend:
        ax.legend(
            bbox_to_anchor=size_vec, loc=loc,
            mode="expand", borderaxespad=0, ncol=1)
    else:
        logger.info(f'no legend for {dir_name} {metric}')

    display_group_list = ['4/16/32', '8/64/128', '16/256/512']
    ax.set_xticks(range(len(display_group_list)))
    ax.set_xticklabels(display_group_list, rotation=0)

    # display yticks based on y_ticks_map in util
    y_ticks_dict = y_ticks_map.get(dir_name, default_map)
    yt = y_ticks_dict.get(metric, default_map)
    if y_ticks_dict != default_map and \
        yt != default_map:
        ax.set_yticks(yt)
        ax.set_yticklabels(yt)

    ylab = axis_map[metric]
    if '(flop/s/nJ)' in ylab:
        ylab = 'Norm. energy\nefficiency\n(flop/s/nJ)'
    ax.set_ylabel(ylab)

    os.makedirs('plot/mlperf', exist_ok=True)
    metric_name = metric.replace('/', '_').replace(' ', '_')
    if dir_name is None:
        fig_name = f'plot/mlperf/{metric_name}'
    else:
        os.makedirs(f'plot/mlperf/{dir_name}', exist_ok=True)
        fig_name = f'plot/mlperf/{dir_name}/{metric_name}'
    if postfix:
        fig_name += f'_{postfix}'
    fig_name += '.pdf'

    fig.tight_layout()
    plt.savefig(fig_name, bbox_inches='tight', dpi=my_dpi, pad_inches=0.02)
    logger.debug(bcolors.OKGREEN + f'Saved fig as {fig_name}' + bcolors.ENDC)

def plot_stackedbar_area(metric, compond_group_list, group2_list, run_set, dir_name, overlay,postfix):
    """Draw stacked, grouped bars for ``metric`` and save the figure as a PDF.

    Every yml key mapped from ``metric`` contributes one layer of the stack.
    Within a layer each group supplies exactly three values (first network's
    value for keys containing ``'area'``, geometric mean over networks
    otherwise; missing values count as 0).  The x position of a bar is the
    value's index inside its group, shifted by the group index times the bar
    width; color and label are chosen by group index.  Bars whose stacked
    height reaches the clip limit are annotated with their total.

    Args:
        metric: key into ``yml_metric_map`` and ``axis_map``.
        compond_group_list: group names, each looked up in ``run_set``.
        group2_list: unused; kept so existing callers keep working.
        run_set: nested mapping ``group -> arch -> net -> {yml key: value}``.
        dir_name: sub-directory below ``plot/mlperf`` (also selects clip,
            legend, size and tick overrides); ``None`` writes into
            ``plot/mlperf``.
        overlay: must be falsy; the overlay mode is not implemented.
        postfix: optional suffix appended to the file name when truthy.

    Raises:
        AssertionError: unknown ``metric``, truthy ``overlay``, or a group
            that does not contain exactly three entries.
    """
    if metric not in yml_metric_map:
        raise AssertionError(f'unknown metric: {metric}')
    if overlay:
        raise AssertionError('overlay is not supported for stacked bars')
    metric_in_yml_list = yml_metric_map[metric]
    if not isinstance(metric_in_yml_list, list):
        metric_in_yml_list = [metric_in_yml_list]
    logger.debug(metric_in_yml_list)

    # 100 doubles as the "no clipping configured" sentinel further down
    ylim = (clip_map.get(dir_name, default_map)).get(metric, 100)
    size_tuple_local = (3.3115, 1.3)
    if size_map.get(dir_name, {}).get(metric):
        size_tuple_local = size_map[dir_name][metric]
        logger.info(f'customized size {size_tuple_local}')

    fig, ax = plt.subplots(figsize=size_tuple_local, dpi=my_dpi)

    num_datapoint = 3
    color_arr = [bcolors.gray, bcolors.cactus, bcolors.orange]
    label_arr = ['bSA', 'RIS', 'Carat']
    # each further layer of the stack is drawn a bit more transparent
    step = 1.0/(len(metric_in_yml_list)+1)
    # group index -> running stack height, used as ``bottom`` of the next layer
    prev_list = dict()

    for t, metric_in_yml in enumerate(metric_in_yml_list):
        logger.debug(metric_in_yml)
        is_area = 'area' in metric_in_yml
        plot_arr = []
        for group_name in compond_group_list:
            group_dict = run_set.get(group_name)
            arch_arr = []
            for arch_dict in group_dict.values():
                net_arr = []
                for net_dict in arch_dict.values():
                    net_arr.append(net_dict.get(metric_in_yml))
                    # for area keys only the first network is consulted
                    if is_area:
                        break
                arch_arr.append(net_arr[0] if is_area else gmean(net_arr))
            plot_arr.append(arch_arr)

        for index, parr in enumerate(plot_arr):
            if len(parr) != num_datapoint:
                raise AssertionError(
                    f'group {index} has {len(parr)} datapoints, expected {num_datapoint}')
            logger.debug(f"[{index}] all {num_datapoint}: {parr}")

        width = 1.0/(len(plot_arr[0])+1)
        for num, values in enumerate(plot_arr):
            # any None replace with 0
            arr = np.array([0 if v is None else v for v in values])
            prev = prev_list.get(num, np.zeros(len(arr)))
            xpos = np.arange(len(arr))+width*num

            logger.info(f"x={xpos},y={arr}")

            ax.bar(
                xpos, arr,
                width,
                color=color_arr[num],
                alpha=1-t*step,
                bottom=prev,
                label=label_arr[num]+'/'+ metric_in_yml.split('/')[0],
            )

            for j in np.arange(len(arr)):
                total = arr[j]+prev[j]
                if total >= ylim:
                    logger.debug(f'annotate {total}x')
                    ax.annotate('%.1f' % total+'x',
                        xy=(xpos[j]-0.3*width, ylim+0.07),
                        textcoords='data',
                        rotation=30,
                        color='grey',
                        fontsize=4,
                        annotation_clip=False)

            # out-of-place sum: always an ndarray and safe for mixed int/float
            prev_list[num] = prev + arr

    size_vec = (0,1.02,1,0.1)
    loc='lower left'
    ncol=3
    if legend_vec_map.get(dir_name, {}).get(metric):
        size_vec = legend_vec_map[dir_name][metric][0]
        loc = legend_vec_map[dir_name][metric][1]
        ncol = legend_vec_map[dir_name][metric][2]
        logger.info('relocating legend wedget')
    if metric in legend_map.get(dir_name, default_map):
        ax.legend(
            bbox_to_anchor=size_vec, loc=loc,
            mode="expand", borderaxespad=0, ncol=ncol)
    else:
        logger.info(f'no legend for {dir_name} {metric}')

    # center each tick under its cluster: half the offset of the last group
    last_group = len(plot_arr) - 1
    display_group_list = ['4/16/32', '8/64/128', '16/256/512']
    ax.set_xticks(np.arange(len(display_group_list))+width*float(last_group)/2)
    ax.set_xticklabels(display_group_list, rotation=0)

    if ylim != 100:
        logger.debug(f"clipped at {ylim}")
        ax.set_ylim(top=ylim)

    # display yticks based on y_ticks_map in util
    y_ticks_dict = y_ticks_map.get(dir_name, default_map)
    yt = y_ticks_dict.get(metric, default_map)
    if y_ticks_dict != default_map and \
        yt != default_map:
        ax.set_yticks(yt)
        ax.set_yticklabels(yt)

    ylab = axis_map[metric]
    if '(flop/s/nJ)' in ylab:
        ylab = 'Norm. energy\nefficiency\n(flop/s/nJ)'
    ax.set_ylabel(ylab)

    os.makedirs('plot/mlperf', exist_ok=True)
    metric_name = metric.replace('/', '_').replace(' ', '_')
    if dir_name is None:
        fig_name = f'plot/mlperf/{metric_name}'
    else:
        os.makedirs(f'plot/mlperf/{dir_name}', exist_ok=True)
        fig_name = f'plot/mlperf/{dir_name}/{metric_name}'
    if postfix:
        fig_name += f'_{postfix}'
    fig_name += '.pdf'

    fig.tight_layout()
    plt.savefig(fig_name, bbox_inches='tight', dpi=my_dpi, pad_inches=0.02)
    logger.debug(bcolors.OKGREEN + f'Saved fig as {fig_name}' + bcolors.ENDC)

def my_autopct(pct):
    """Format a pie percentage as a whole-number label; blank at 3% or less."""
    if pct <= 3:
        return ''
    return '%1.0f%%' % pct

def my_autopct_makesmall(pct):
    """Format a pie percentage as a whole-number label; blank at 5% or less."""
    if pct <= 5:
        return ''
    return '%1.0f%%' % pct

def plot_pie_area(metric, compond_group_list, group2_list, run_set, dir_name, overlay,postfix):
    """Draw one pie per design showing its per-component breakdown.

    Every yml key mapped from ``metric`` is one pie slice ("component"); the
    slice label is the part of the key before the first ``'/'``.  Every
    architecture entry of every group is one pie ("design").  A slice value
    is the first network's value when the key contains ``'area'``, otherwise
    the geometric mean over all networks; missing values count as 0.  The
    slice total is printed above each pie and the name returned by
    ``trim_item`` below it.

    Args:
        metric: key into ``yml_metric_map``.
        compond_group_list: group names, each looked up in ``run_set``.
        group2_list: unused; kept so existing callers keep working.
        run_set: nested mapping ``group -> arch -> net -> {yml key: value}``.
        dir_name: sub-directory below ``plot/mlperf`` (also selects legend
            and size overrides); ``None`` writes into ``plot/mlperf``.
        overlay: must be falsy; the overlay mode is not supported.
        postfix: optional suffix appended to the file name when truthy.

    Raises:
        AssertionError: unknown ``metric`` or truthy ``overlay``.
    """
    import matplotlib.ticker as ticker
    from plot_util import pie_color_map

    if metric not in yml_metric_map:
        raise AssertionError(f'unknown metric: {metric}')
    if overlay:
        raise AssertionError('overlay is not supported for pie charts')
    metric_in_yml_list = yml_metric_map[metric]
    if not isinstance(metric_in_yml_list, list):
        metric_in_yml_list = [metric_in_yml_list]
    logger.debug(metric_in_yml_list)

    has_legend = metric in legend_map.get(dir_name, default_map)
    if size_map.get(dir_name, {}).get(metric):
        size_tuple_local = size_map[dir_name][metric]
    elif has_legend:
        size_tuple_local = (3.3115, 1.2)
    else:
        size_tuple_local = (3.3115, 0.95)

    # both lists are indexed [component][design]
    plot_all_component = []
    name_all_component = []
    for metric_in_yml in metric_in_yml_list:
        logger.debug(metric_in_yml)
        is_area = 'area' in metric_in_yml
        values = []
        names = []
        for group_name in compond_group_list:
            group_dict = run_set.get(group_name)
            for arch_name, arch_dict in group_dict.items():
                net_arr = []
                for net_dict in arch_dict.values():
                    net_arr.append(net_dict.get(metric_in_yml))
                    # for area keys only the first network is consulted
                    if is_area:
                        break
                gm = net_arr[0] if is_area else gmean(net_arr)
                values.append(0 if gm is None else gm)
                names.append(arch_name)
        plot_all_component.append(values)
        name_all_component.append(names)
        logger.debug(f'adding {values},{names}')
    logger.debug(f'plot_all_component={plot_all_component}')

    # transpose to [design][component]; a plain transpose works for any
    # number of components and designs
    plot_all_component_reshaped = np.array(plot_all_component).transpose()
    name_all_component_reshaped = np.array(name_all_component).transpose()
    num_designs = len(plot_all_component_reshaped)
    logger.debug(f'num_designs={num_designs}')
    logger.debug(f'[archtype*shape,component]=[{num_designs},{len(plot_all_component_reshaped[0])}]')

    # squeeze=False keeps ``axes`` indexable even when there is a single pie
    fig, axes = plt.subplots(figsize=size_tuple_local, dpi=my_dpi, nrows=1,
                             ncols=num_designs, squeeze=False)
    axes = axes[0]

    many_pies = num_designs > 3
    lab = [key.split('/')[0] for key in metric_in_yml_list]
    colors = [pie_color_map[name] for name in lab]
    # the total's unit label is chosen from the last yml key
    last_metric_in_yml = metric_in_yml_list[-1]

    for i in range(num_designs):
        logger.debug(f'adding {i}th pie')
        logger.debug(lab)
        shares = plot_all_component_reshaped[i]

        # with autopct set, pie() returns (wedges, texts, autotexts)
        wedges = axes[i].pie(shares,
                    autopct=my_autopct_makesmall if many_pies else my_autopct,
                    colors=colors,
                    pctdistance=1.1 if many_pies else 0.6,
        )[0]
        name_short = trim_item(name_all_component_reshaped[i][0], [True, True, True, False, False], False)
        logger.debug(name_short)
        axes[i].set_xlabel(f'{name_short}')
        axes[i].xaxis.labelpad = -3

        # a tick-less, spine-less secondary axis only serves as a holder for
        # the total printed above the pie
        secax = axes[i].secondary_xaxis('top')
        sum_ = sum(shares)
        if 'area' not in last_metric_in_yml:
            if sum_ > pow(10,3):
                secax.set_xlabel(f'{round(sum_/pow(10,3),2)} uJ', fontsize=5, color='grey')
            else:
                secax.set_xlabel(f'{round(sum_,2)} nJ', fontsize=5, color='grey')
        else:
            secax.set_xlabel(f'{round(sum_,2)} '+r"$\mathregular{mm^{2}}$", fontsize=5, color='grey')
        secax.xaxis.labelpad = -2
        secax.xaxis.set_major_locator(ticker.NullLocator())
        secax.yaxis.set_major_locator(ticker.NullLocator())
        secax.tick_params('both', length=0, width=0, which='major')
        secax.tick_params('both', length=0, width=0, which='minor')
        for side in ['top','right','bottom','left']:
            secax.spines[side].set_visible(False)

    if has_legend:
        # handles and labels are both passed by keyword; every pie uses the
        # same colors, so the wedges of the last pie serve as handles
        if many_pies:
            axes[0].legend(handles=wedges, labels=lab,
                        bbox_to_anchor=(0, 1.1, num_designs,0.1),
                        loc='lower left',
                        ncol=len(lab),
            )
        else:
            axes[-1].legend(handles=wedges, labels=lab,
                        bbox_to_anchor=(1, 0, 0.1, 1),
                        loc='lower left',
                        ncol=1,
            )
    else:
        logger.info(f'no legend for {dir_name} {metric}')

    os.makedirs('plot/mlperf', exist_ok=True)
    metric_name = metric.replace('/', '_').replace(' ', '_')
    if dir_name is None:
        fig_name = f'plot/mlperf/{metric_name}'
    else:
        os.makedirs(f'plot/mlperf/{dir_name}', exist_ok=True)
        fig_name = f'plot/mlperf/{dir_name}/{metric_name}'
    if postfix:
        fig_name += f'_{postfix}'
    fig_name += '.pdf'

    fig.tight_layout()
    plt.savefig(fig_name, bbox_inches='tight', dpi=my_dpi, pad_inches=0.02)
    logger.debug(bcolors.OKGREEN + f'Saved fig as {fig_name}' + bcolors.ENDC)


def plot_pie_energy(metric, compond_group_list, group2_list, run_set, dir_name, overlay,postfix):
    """Draw one pie per design showing its per-component breakdown.

    Every yml key mapped from ``metric`` is one pie slice ("component"); the
    slice label is the part of the key before the first ``'/'``.  Every
    architecture entry of every group is one pie ("design").  A slice value
    is the geometric mean of the key over all networks of that design; a key
    that is missing for every network counts as 0.  The slice total is
    printed above each pie and the name returned by ``trim_item`` below it.

    Args:
        metric: key into ``yml_metric_map``.
        compond_group_list: group names, each looked up in ``run_set``.
        group2_list: unused; kept so existing callers keep working.
        run_set: nested mapping ``group -> arch -> net -> {yml key: value}``.
        dir_name: sub-directory below ``plot/mlperf`` (also selects legend
            and size overrides); ``None`` writes into ``plot/mlperf``.
        overlay: when truthy the per-network list is stored instead of the
            mean.  NOTE(review): the pie drawing below expects one scalar per
            slice, so this mode presumably never worked - confirm with callers.
        postfix: optional suffix appended to the file name when truthy.

    Raises:
        AssertionError: if ``metric`` is not a key of ``yml_metric_map``.
    """
    import matplotlib.ticker as ticker
    from plot_util import pie_color_map

    if metric not in yml_metric_map:
        raise AssertionError(f'unknown metric: {metric}')
    metric_in_yml_list = yml_metric_map[metric]
    if not isinstance(metric_in_yml_list, list):
        metric_in_yml_list = [metric_in_yml_list]
    logger.debug(metric_in_yml_list)

    has_legend = metric in legend_map.get(dir_name, default_map)
    if size_map.get(dir_name, {}).get(metric):
        size_tuple_local = size_map[dir_name][metric]
    elif has_legend:
        size_tuple_local = (3.3115, 1.2)
    else:
        size_tuple_local = (3.3115, 0.95)

    # both lists are indexed [component][design]
    plot_all_component = []
    name_all_component = []
    for metric_in_yml in metric_in_yml_list:
        logger.info(metric_in_yml)
        values = []
        names = []
        for group_name in compond_group_list:
            logger.info(f"group name={group_name}")
            group_dict = run_set.get(group_name)
            for arch_name, arch_dict in group_dict.items():
                net_arr = [net_dict.get(metric_in_yml) for net_dict in arch_dict.values()]
                if net_arr and all(perf is None for perf in net_arr):
                    # key absent for every network, whatever their number:
                    # treat as zeros, whose geometric mean is 0
                    net_arr = [0] * len(net_arr)
                    gm = 0.0
                else:
                    logger.debug(net_arr)
                    gm = gmean(net_arr)
                values.append(net_arr if overlay else gm)
                names.append(arch_name)
        plot_all_component.append(values)
        name_all_component.append(names)
        logger.debug(f'adding {values},{names}')
    logger.debug(f'plot_all_component={plot_all_component}')

    # transpose to [design][component]; a plain transpose works for any
    # number of components and designs
    plot_all_component_reshaped = np.array(plot_all_component).transpose()
    name_all_component_reshaped = np.array(name_all_component).transpose()
    num_designs = len(plot_all_component_reshaped)
    logger.debug(f'num_designs={num_designs}')
    logger.debug(f'[archtype*shape,component]=[{num_designs},{len(plot_all_component_reshaped[0])}]')

    # squeeze=False keeps ``axes`` indexable even when there is a single pie
    fig, axes = plt.subplots(figsize=size_tuple_local, dpi=my_dpi, nrows=1,
                             ncols=num_designs, squeeze=False)
    axes = axes[0]

    many_pies = num_designs > 3
    lab = [key.split('/')[0] for key in metric_in_yml_list]
    colors = [pie_color_map[name] for name in lab]
    # the total's unit label is chosen from the last yml key
    last_metric_in_yml = metric_in_yml_list[-1]

    for i in range(num_designs):
        logger.debug(f'adding {i}th pie')
        logger.debug(lab)
        shares = plot_all_component_reshaped[i]

        # with autopct set, pie() returns (wedges, texts, autotexts)
        wedges = axes[i].pie(shares,
                    autopct=my_autopct_makesmall if many_pies else my_autopct,
                    colors=colors,
                    pctdistance=1.1 if many_pies else 0.6,
        )[0]
        name_short = trim_item(name_all_component_reshaped[i][0], [True, True, True, False, False], False)
        logger.debug(name_short)
        axes[i].set_xlabel(f'{name_short}')
        axes[i].xaxis.labelpad = -3

        # a tick-less, spine-less secondary axis only serves as a holder for
        # the total printed above the pie
        secax = axes[i].secondary_xaxis('top')
        sum_ = sum(shares)
        if 'area' not in last_metric_in_yml:
            if sum_ > pow(10,3):
                secax.set_xlabel(f'{round(sum_/pow(10,3),2)} uJ', fontsize=5, color='grey')
            else:
                secax.set_xlabel(f'{round(sum_,2)} nJ', fontsize=5, color='grey')
        else:
            secax.set_xlabel(f'{round(sum_,2)} '+r"$\mathregular{mm^{2}}$", fontsize=5, color='grey')
        secax.xaxis.labelpad = -2
        secax.xaxis.set_major_locator(ticker.NullLocator())
        secax.yaxis.set_major_locator(ticker.NullLocator())
        secax.tick_params('both', length=0, width=0, which='major')
        secax.tick_params('both', length=0, width=0, which='minor')
        for side in ['top','right','bottom','left']:
            secax.spines[side].set_visible(False)

    if has_legend:
        # handles and labels are both passed by keyword; every pie uses the
        # same colors, so the wedges of the last pie serve as handles
        if many_pies:
            axes[0].legend(handles=wedges, labels=lab,
                        bbox_to_anchor=(0, 1, num_designs,0.1),
                        loc='lower left',
                        ncol=len(lab),
            )
        else:
            axes[-1].legend(handles=wedges, labels=lab,
                        bbox_to_anchor=(1, 0, 0.1, 1),
                        loc='lower left',
                        ncol=1,
            )
    else:
        logger.info(f'no legend for {dir_name} {metric}')

    os.makedirs('plot/mlperf', exist_ok=True)
    metric_name = metric.replace('/', '_').replace(' ', '_')
    if dir_name is None:
        fig_name = f'plot/mlperf/{metric_name}'
    else:
        os.makedirs(f'plot/mlperf/{dir_name}', exist_ok=True)
        fig_name = f'plot/mlperf/{dir_name}/{metric_name}'
    if postfix:
        fig_name += f'_{postfix}'
    fig_name += '.pdf'

    fig.tight_layout()
    plt.savefig(fig_name, bbox_inches='tight', dpi=my_dpi, pad_inches=0.02)
    logger.debug(bcolors.OKGREEN + f'Saved fig as {fig_name}' + bcolors.ENDC)



def plot_stackedbar_energy(metric, compond_group_list, group2_list, run_set, dir_name, overlay, postfix):
    """Draw a grouped, stacked bar chart for a breakdown metric and save it as a PDF.

    ``yml_metric_map[metric]`` names one or more result keys.  For every key,
    every group in ``compond_group_list`` and every architecture of that
    group, the geometric mean over the architecture's networks is computed.
    Each group becomes one bar series; successive keys are stacked on top of
    each other with decreasing opacity.  Bars reaching the clip value from
    ``clip_map`` are annotated with their total.

    Args:
        metric: Key into ``yml_metric_map`` and ``axis_map``; also used
            (with '/' and ' ' replaced by '_') as the output file name.
        compond_group_list: Group names looked up in ``run_set``.  Colours
            and labels are only defined for the first three groups.
        group2_list: Unused; kept for signature compatibility with the
            other plot functions.
        run_set: Nested mapping ``group -> arch -> net -> {key: value}``.
        dir_name: Optional sub-directory of ``plot/mlperf`` and key into
            ``size_map``, ``clip_map``, ``legend_map`` and ``y_ticks_map``.
        overlay: Must be falsy; overlay mode is not implemented here.
        postfix: Optional suffix appended to the file name.

    Raises:
        KeyError: ``metric`` has no entry in ``yml_metric_map``.
        NotImplementedError: ``overlay`` is truthy.
        ValueError: ``compond_group_list`` is empty, or a group does not
            contain exactly three architectures.

    These conditions were previously reported through ``assert``, which is
    skipped when Python runs with ``-O``.
    """
    if metric not in yml_metric_map:
        raise KeyError(f'metric {metric!r} has no entry in yml_metric_map')
    metric_in_yml_list = yml_metric_map[metric]
    if not isinstance(metric_in_yml_list, list):
        metric_in_yml_list = [metric_in_yml_list]
    logger.debug(metric_in_yml_list)

    if overlay:
        raise NotImplementedError('plot_stackedbar_energy does not support overlay mode')
    if not compond_group_list:
        raise ValueError('compond_group_list is empty, nothing to plot')

    # Default figure size, optionally overridden per (dir_name, metric).
    size_tuple_local = (3.3115, 1.1)
    custom_size = size_map.get(dir_name, {}).get(metric)
    if custom_size:
        size_tuple_local = custom_size
        logger.info(f'customized size {size_tuple_local}')

    fig, ax = plt.subplots(figsize=size_tuple_local, dpi=my_dpi)
    # 100 doubles as the "no clipping configured" sentinel (see set_ylim below).
    ylim = clip_map.get(dir_name, default_map).get(metric, 100)

    num_datapoint = 3  # architectures expected in every group
    # NOTE: only three groups have a colour/label; a fourth raises IndexError.
    color_arr = [bcolors.gray, bcolors.cactus, bcolors.orange]
    label_arr = ['bSA', 'RIS', 'Carat']
    width = 1.0 / (num_datapoint + 1)
    step = 1.0 / (len(metric_in_yml_list) + 1)
    x_pos = np.arange(num_datapoint)
    last_group = len(compond_group_list) - 1
    bottoms = {}  # group index -> running top of that group's stacked bars

    for t, metric_in_yml in enumerate(metric_in_yml_list):
        logger.info(metric_in_yml)

        # plot_arr[group][arch] = geometric mean over that arch's networks.
        plot_arr = []
        for group_name in compond_group_list:
            logger.info(f"group name={group_name}")
            group_dict = run_set.get(group_name)
            arch_arr = []
            for arch_dict in group_dict.values():
                net_arr = [net_dict.get(metric_in_yml) for net_dict in arch_dict.values()]
                # A key missing from every network counts as zero, whatever
                # the number of networks (was hard-coded to exactly six).
                if all(value is None for value in net_arr):
                    net_arr = [0] * len(net_arr)
                logger.debug(net_arr)
                arch_arr.append(gmean(net_arr))
            plot_arr.append(arch_arr)

        for index, parr in enumerate(plot_arr):
            if len(parr) != num_datapoint:
                raise ValueError(
                    f'group {compond_group_list[index]!r} has {len(parr)} '
                    f'architectures, expected {num_datapoint}'
                )
            logger.debug(f"[{index}] all 3: {parr}")

        for num, arr in enumerate(plot_arr):
            heights = np.array([0 if value is None else value for value in arr])
            prev = bottoms.get(num, np.zeros(len(arr)))
            ax.bar(
                x_pos + width * num, heights,
                width,
                color=color_arr[num],
                alpha=1 - t * step,
                bottom=prev,
                label=label_arr[num] + '/' + metric_in_yml.split('/')[0],
            )
            totals = heights + prev
            for j, total in enumerate(totals):
                if total >= ylim:
                    # The bar is cut off at ylim, so print its real height above it.
                    logger.debug(f'annotate {total}x')
                    ax.annotate('%.1f' % total + 'x',
                        xy=(x_pos[j] + width * num - 0.3 * width, ylim + 0.07),
                        textcoords='data',
                        rotation=30,
                        color='grey',
                        fontsize=4,
                        annotation_clip=False)
            bottoms[num] = totals

    size_vec = (0, 1.02, 1, 0.1)
    if metric in legend_map.get(dir_name, default_map):
        logger.debug('legend')
        ax.legend(
            bbox_to_anchor=size_vec, loc="lower left",
            mode="expand", borderaxespad=0, ncol=3)
    else:
        logger.info(f'no legend for {dir_name} {metric}')

    # Centre each tick under its cluster of per-group bars.
    display_group_list = ['4/16/32', '8/64/128', '16/256/512']
    ax.set_xticks(np.arange(len(display_group_list)) + width * float(last_group) / 2)
    ax.set_xticklabels(display_group_list, rotation=0)

    # display yticks based on y_ticks_map in util
    y_ticks_dict = y_ticks_map.get(dir_name, default_map)
    yt = y_ticks_dict.get(metric, default_map)
    if y_ticks_dict != default_map and yt != default_map:
        ax.set_yticks(yt)
        ax.set_yticklabels(yt)

    ylab = axis_map[metric]
    if '(flop/s/nJ)' in ylab:
        ylab = 'Norm. energy\nefficiency\n(flop/s/nJ)'
    ax.set_ylabel(ylab)

    if ylim != 100:
        logger.debug(f"clipped at {ylim}")
        ax.set_ylim(top=ylim)

    os.makedirs('plot/mlperf', exist_ok=True)
    metric_name = metric.replace('/', '_').replace(' ', '_')

    if dir_name is None:
        fig_name = f'plot/mlperf/{metric_name}'
    else:
        os.makedirs(f'plot/mlperf/{dir_name}', exist_ok=True)
        fig_name = f'plot/mlperf/{dir_name}/{metric_name}'

    if postfix:
        fig_name += f'_{postfix}'
    fig_name += '.pdf'

    fig.tight_layout()
    fig.savefig(fig_name, bbox_inches='tight', dpi=my_dpi, pad_inches=0.02)
    # pyplot keeps every figure alive until it is closed; release this one so
    # plotting many metrics in one run does not accumulate open figures.
    plt.close(fig)
    logger.debug(bcolors.OKGREEN + f'Saved fig as {fig_name}' + bcolors.ENDC)


def read_run_set(run_set, metric_str, arch_q_name):
    """Return the geometric mean of one metric for one architecture.

    Searches ``run_set`` (``group -> arch -> net -> {metric: value}``) for
    the first architecture named ``arch_q_name`` and returns the geometric
    mean of ``metric_str`` over that architecture's networks.  Only networks
    listed in the module-level ``args.mlperf`` that actually carry the metric
    contribute.

    Args:
        run_set: Nested result mapping described above.
        metric_str: Exact metric key to collect.
        arch_q_name: Architecture name to look up.

    Returns:
        ``gmean`` of the collected values.

    Raises:
        SystemExit: With status 1 when no architecture is named
            ``arch_q_name``.
    """
    for level1_dict in run_set.values():
        for arch_name, level2_dict in level1_dict.items():
            if arch_name != arch_q_name:
                continue
            # Compute the mean only for the queried architecture; the other
            # architectures' means were never used.
            value_arr = [
                level3_dict[metric_str]
                for net_name, level3_dict in level2_dict.items()
                if net_name in args.mlperf and metric_str in level3_dict
            ]
            return gmean(value_arr)

    logger.debug(bcolors.FAIL + f'no such query name: {arch_q_name}' + bcolors.ENDC)
    # sys.exit instead of the site-provided exit() helper, with a non-zero
    # status so a failed query is visible to the calling shell.
    sys.exit(1)

def rescale(run_set):
    """Re-normalise every '/norm' metric in ``run_set`` against one reference.

    ``run_set`` is walked as ``batch -> arch -> net -> {metric: value}``.
    Each metric whose name contains '/norm' is overwritten with the value
    stored under the same name without '/norm', divided by a reference value.
    The reference is captured exactly once: from the first '/norm' metric met
    on the first-seen network, and it is then shared by all metrics.

    The mapping is modified in place and also returned.
    """
    # Network names in the order they are first encountered.
    net_arr = []
    for archs in run_set.values():
        for nets in archs.values():
            for net in nets:
                if net not in net_arr:
                    net_arr.append(net)

    found_ref_val = False
    for batch, archs in run_set.items():
        for arch, nets in archs.items():
            for net, metrics in nets.items():
                for metric_name, metric_val in metrics.items():
                    if '/norm' not in metric_name:
                        continue
                    raw_name = metric_name.replace('/norm', '')
                    ind = net_arr.index(net)
                    if ind == 0 and not found_ref_val:
                        # First normalised metric of the first network: this
                        # becomes the reference for everything else.
                        ref_val = metrics[raw_name]
                        found_ref_val = True
                    scaled_val = metrics[raw_name] / ref_val
                    run_set[batch][arch][net][metric_name] = scaled_val
                    logger.debug(f'ind={ind}, metric={metric_name}, val={metric_val}, ref={ref_val}, new={scaled_val}, {run_set[batch][arch][net][metric_name]}')
    return run_set

if __name__ == "__main__":
    # Script entry point: parse the CLI, build the run set from the result
    # directories under runs/, read and optionally rescale the results, dump
    # them to a timestamped yml file, then draw one plot per requested metric.
    #
    # NOTE: the names bound here are module globals on purpose; read_run_set
    # reads `args` from module scope, so this block must stay at module level.
    parser = construct_argparser()
    args = parser.parse_args()

    if "area_breakdown" in args.metric and "area_breakdown_sram_and_rest" in args.metric:
        parser.error("cannot have both area_breakdown and area_breakdown_sram_and_rest")

    # loguru only interpolates positional arguments into '{}' placeholders;
    # without one the value is dropped, so format the message explicitly.
    logger.debug(f'*** mlperf: {args.mlperf}')
    logger.debug(f'*** dir: {args.dirname}')
    os.makedirs('plot/log', exist_ok=True)
    timestr = time.strftime("%Y%m%d-%H%M%S")
    if args.dirname is not None:
        os.makedirs(f'plot/log/{args.dirname}', exist_ok=True)
        yml_name = f'./plot/log/{args.dirname}/plot_{timestr}.yml'
    else:
        yml_name = f'./plot/log/plot_{timestr}.yml'

    # === peak perf helper functions ===
    def get_sys_peak_flops_per_sec(arch_dict, cycle):
        """Return num_instances[0] * num_instances[1] * 2 / cycle * frequency * 1e6."""
        arch = arch_dict['architecture']
        num_instances = arch['compu']['num_instances']
        return (num_instances[0] * num_instances[1] * 2.0) / cycle * \
            arch['frequency'] * 10**6

    def get_tlut_peak_flops_per_sec(arch_dict, chunk_size, data_length):
        """Return 2 * num_instances[0] * (chunk_size / data_length) * frequency * 1e6."""
        arch = arch_dict['architecture']
        return 2 * arch['compu']['num_instances'][0] * \
            (float(chunk_size) / data_length) * arch['frequency'] * 10**6

    def get_cris_scaled_similarity(wl_dict):
        """Return (1 + (N - 1) * (1 - similarity)) / N for the first workload entry."""
        workload = wl_dict['workload']
        first_entry = workload[list(workload.keys())[0]]
        n_val = first_entry['N']
        return (1 + (n_val - 1) * (1 - first_entry['similarity'])) / n_val

    def get_cris_peak_flops_per_sec(arch_dict, wl_dict, cycle):
        """Return the systolic peak divided by the scaled similarity of ``wl_dict``."""
        return get_sys_peak_flops_per_sec(arch_dict, cycle) / get_cris_scaled_similarity(wl_dict)
    # === peak perf helper functions ===

    logger.debug(f'*** points: {args.point}')

    if args.use_universal_baseline:
        logger.debug('*** using universal baseline')
    else:
        logger.debug('*** normalizing in terms of group (default)')

    # === construct run set ===
    logger.debug(bcolors.OKCYAN + 'Constructing run set...' + bcolors.ENDC)

    group_list = ['carat'] if args.carat_only else ['systolic', 'cris', 'carat']
    # The carat variants get their own group as soon as any point uses them.
    if any("i1" in p and "carat" in p for p in args.point):
        group_list.append('carat_i1')
    if any("idle8" in p and "carat" in p for p in args.point):
        group_list.append('carat_idle8')

    group2_list = ['fp8']
    logger.info(f'*** group by {group_list}')
    compond_group_list = []

    run_set = OrderedDict()

    # index 1: format
    for group2 in group2_list:
        for group_ in group_list:
            # The data format is derived from the group name and overrides
            # the outer loop variable.
            group2 = 'int8' if 'uSys' in group_ else 'fp8'
            group_dict = OrderedDict()
            # index2: arch
            for point in args.point:
                if group_ not in point or group2 not in point:
                    continue
                # Plain 'carat' must not swallow the i1 / idle8 variants,
                # whose point names also contain 'carat'.
                if group_ == "carat" and ("i1" in point or "idle8" in point):
                    continue
                arch_name = point.split('/')[-1]
                arch_name_ = arch_name + '_' + point.split('/')[1]
                dir_append = point.replace(arch_name, '')
                myrootdir = f'runs/{dir_append}/'
                logger.debug(myrootdir)
                # Immediate sub-directories of the run directory; listed once
                # per point because the listing does not depend on the network.
                dir_name = next(os.walk(myrootdir))[1]
                group = OrderedDict()
                # index3: network
                for net in args.mlperf:
                    reg_name = re.compile(re.escape(arch_name) + '_' + net + r'_c\d{1,3}_n'+str(args.nval))
                    for dirnames in dir_name:
                        if reg_name.match(dirnames):
                            # If several directories match, the last one wins.
                            dict_ = OrderedDict()
                            dict_['dir'] = f'{dir_append}/{dirnames}'
                            group[net] = dict_
                group_dict[arch_name_] = group
            compond_name = group_ + '_' + group2
            compond_group_list.append(compond_name)
            run_set[compond_name] = group_dict

    # === register needed results ===
    perf_read, cost_read, other, cost_layer_read = gen_read_sets(args.metric)

    # === fill in perf results ===
    logger.debug(bcolors.OKCYAN + 'Reading yml...' + bcolors.ENDC)
    ref_arch_name = args.point[0].split('/')[-1]

    # Pick the breakdown maps: the full maps for pie charts and for the plain
    # breakdown metrics, otherwise the "sram and rest" maps.
    wants_sram_and_rest = "area_breakdown_sram_and_rest" in args.metric or \
        "energy_breakdown_sram_and_rest" in args.metric
    wants_full_breakdown = any(name in args.metric for name in (
        "area_breakdown", "energy_breakdown",
        "area_breakdown_nosram", "energy_breakdown_nosram",
        "area_breakdown_array_level", "energy_breakdown_array_level",
    ))
    if args.pie or (not wants_sram_and_rest and wants_full_breakdown):
        sys_bmap = sys_breakdown_map
        usys_bmap = usys_breakdown_map
        carat_bmap = carat_breakdown_map
        cris_bmap = cris_breakdown_map
    else:
        if not wants_sram_and_rest:
            logger.warning("check breakdown map, now using sram and rest ver.")
        sys_bmap = sys_breakdown_map_sram_and_rest
        usys_bmap = usys_breakdown_map_sram_and_rest
        carat_bmap = carat_breakdown_map_sram_and_rest
        cris_bmap = cris_breakdown_map_sram_and_rest

    run_set = read_perf_cost(run_set, perf_read, cost_read, other, cost_layer_read,
                             ref_arch_name, args.use_universal_baseline,
                             sys_bmap, usys_bmap, carat_bmap, cris_bmap)

    # rescale metric/norm
    if not args.norescale:
        logger.debug(bcolors.WARNING + 'rescale metric/norm' + bcolors.ENDC)
        run_set = rescale(run_set)

    yaml_overwrite(yml_name, run_set) # TODO: remove for speed and space

    # output number for table filling
    if args.query and args.qm:
        for q in args.query:
            for m in args.qm:
                res = read_run_set(run_set, m, q)
                logger.debug(bcolors.HEADER + f'{m}-{q}: {res}' + bcolors.ENDC)

    logger.debug(bcolors.OKCYAN + 'ploting...' + bcolors.ENDC)
    for metric in args.metric:
        if args.line:
            plot_format_func_line(metric, compond_group_list, group2_list, run_set, args.dirname, args.overlay, args.postfix)
        elif 'area_breakdown' in metric:
            if args.pie:
                plot_pie_area(metric, compond_group_list, group2_list, run_set, args.dirname, args.overlay, args.postfix)
            else:
                plot_stackedbar_area(metric, compond_group_list, group2_list, run_set, args.dirname, args.overlay, args.postfix)
        elif 'energy_breakdown' in metric:
            if args.pie:
                plot_pie_energy(metric, compond_group_list, group2_list, run_set, args.dirname, args.overlay, args.postfix)
            else:
                plot_stackedbar_energy(metric, compond_group_list, group2_list, run_set, args.dirname, args.overlay, args.postfix)
        elif 'power_breakdown' in metric:
            plot_stackedbar_energy(metric, compond_group_list, group2_list, run_set, args.dirname, args.overlay, args.postfix)
        else:
            plot_format_func(metric, compond_group_list, group2_list, run_set, args.dirname, args.postfix)