from collections import OrderedDict
import sys, os
sys.path.append(os.getcwd())
from utils import yaml_load, bcolors,yaml_overwrite
import argparse
import numpy as np
import matplotlib.pyplot as plt
import matplotlib
from matplotlib.pyplot import cm
import os
import re
from scipy.stats import gmean
import math
import time
from loguru import logger
from plot_util import tlut_unique_component, hatch_map, ls_arr, marker_arr, \
    sys_breakdown_map, carat_breakdown_map, usys_breakdown_map, \
    hatch_map_cg, size_map, \
    color_map, opac_map, \
    yml_metric_map, reduce_map, single_multiple_map, \
    partition_layerwise_reduction_map, axis_map, \
    mlperf_name_map, legend_map, clip_map, default_map, \
    ignore_vec_map, default_ignore_vec, default_ignore_vec_area,default_for_opac, \
    my_dpi, fig_h, fig_h_short, fig_w, size_tuple, \
    y_ticks_map
# import plotly.express as px
# import plotly.io as pio
# pio.renderers.default = "notebook_connected"
# import plotly.graph_objects as go
# import plotly.figure_factory as ff

# Key under which read_perf_cost stores a network's computed peak flop/s.
PEAK = 'peak flops per sec'

# Global matplotlib styling: very small fonts and hairline strokes so the
# figures fit compact (paper-column sized) layouts.
# NOTE(review): the 'serif' entry only sets the serif font list; unless
# font.family is set to 'serif' elsewhere (matplotlib's default family is
# 'sans-serif'), Helvetica Neue is not the font actually rendered — confirm.
font = {'serif': 'Helvetica Neue', 'size': 6}
matplotlib.rc('font', **font)
matplotlib.rcParams.update({
    'hatch.linewidth': 0.3,
    'lines.linewidth': 0.3,
    'axes.linewidth': 0.3,
    'axes.titlesize': 5,
    'xtick.major.width': 0.3,
    'xtick.labelsize': 5,
    'ytick.labelsize': 5,
    'ytick.major.width': 0.3,
    'xtick.major.size': 2,
    'ytick.major.size': 2,
    'legend.fontsize': 5,
    'legend.labelspacing': 0.25,
})


def construct_argparser():
    """Build the command-line parser for the MLPerf plotting script.

    Returns:
        argparse.ArgumentParser with options for the run points (-p), the
        output sub-directory (-d), the networks (-n), the metrics to plot
        (-m), query names/metrics (-q/-qm), plot style (--line, --overlay),
        the baseline mode (--[no-]use_universal_baseline) and the single
        n value to plot (--nval).
    """
    parser = argparse.ArgumentParser(description='plot')
    parser.add_argument('-p', '--point', nargs='+', help='config')

    # Paired on/off flags sharing one destination; off by default.
    parser.add_argument('--use_universal_baseline', dest='use_universal_baseline', action='store_true')
    parser.add_argument('--no-use_universal_baseline', dest='use_universal_baseline', action='store_false')
    parser.set_defaults(use_universal_baseline=False)

    parser.add_argument('-d', '--dirname', help='dir name', default=None)

    parser.add_argument('--line', help='draw line plot', action='store_true', default=False)

    # Every entry needs its own comma: adjacent string literals are silently
    # concatenated by Python (previously 'UNet''bert_base_uncased' became the
    # single bogus network 'UNetbert_base_uncased').
    default_networks = [
        'resnet50',
        'ssd300_vgg16',
        'UNet',
        'bert_base_uncased',
        'dlrm',
        'RNNT',
    ]
    parser.add_argument('-n', '--mlperf', nargs='+', help='networks in mlperf\n', default=default_networks)

    metric_choices = [
        'perf', 'perf_norm', 'util_impl',
        'rt_abs', 'rt_norm',
        'area_total', 'area_onchip', 'area_dram', 'area_sram', 'area_compute',
        'area_partition', 'area_breakdown', 'energy_partition', 'energy_breakdown',
        'power_total', 'power_onchip', 'power_dram', 'power_sram', 'power_compute',
        'energy_total', 'energy_onchip', 'energy_dram', 'energy_sram', 'energy_compute',
        'throughput/area', 'throughput/energy', 'throughput/power',
        'flops/area', 'flops/energy', 'flops/power',
    ]
    parser.add_argument('-m', '--metric', nargs='+', help='perf to plot', choices=metric_choices)

    parser.add_argument('-q', '--query', nargs='+', help='query name\n')
    parser.add_argument('-qm', '--qm', nargs='+', help='query metric\n',
                        default=['flops per sec', 'throughput/energy'])

    parser.add_argument('--overlay', help='overlay benchmark on plot', action='store_true', default=False)

    # Powers of two from 1 to 4096, as strings ('1', '2', ..., '4096').
    nval_choices = [str(1 << k) for k in range(13)]
    parser.add_argument('--nval', help='plot n=nval only', choices=nval_choices, default='256')

    return parser

def gen_read_sets(data):
    """Translate requested plot metrics into the YAML keys that must be read.

    Args:
        data: iterable of metric names (the ``--metric`` choices). ``None`` is
            treated as an empty selection.

    Returns:
        Tuple ``(perf_read, cost_read, other, cost_layer_read)`` of sets:
        '/'-separated keys read from the performance summary, keys read from
        the cost summary, derived metrics computed afterwards from those
        values, and per-component keys read from the layer-wise cost file.

    Exits the process via ``sys.exit()`` (status 0) when none of the
    requested metrics maps to anything to read.
    """
    print('*** metric: ', data)
    perf_read = set()
    cost_read = set()
    other = set()
    cost_layer_read = set()

    # Hardware components whose per-layer cost feeds the breakdown plots.
    components = ('compu', 'ififo', 'wfifo', 'ofifo', 'oaccu', 'isram',
                  'wsram', 'osram', 'waccu', 'itemp', 'osmux', 'bpipe')
    all_area = ('area/onchip', 'area/sram', 'area/dram', 'area/compute')

    # metric -> (perf-summary keys, cost-summary keys, is a derived metric)
    direct = {
        'perf': (('flops per sec',), (), False),
        'util_impl': (('utilization/impl',), (), False),
        'rt_abs': (('runtime/impl',), (), False),
        'rt_norm': (('runtime/impl',), (), False),
        'power_total': ((), ('power/total/total',), False),
        'power_onchip': ((), ('power/onchip/total',), False),
        'power_sram': ((), ('power/sram/total',), False),
        'power_dram': ((), ('power/dram/total',), False),
        'power_compute': ((), ('power/compute/total',), False),
        'area_total': ((), ('area/onchip', 'area/dram'), True),
        'area_onchip': ((), ('area/onchip',), False),
        'area_sram': ((), ('area/sram',), False),
        'area_dram': ((), ('area/dram',), False),
        'area_compute': ((), ('area/compute',), False),
        'energy_total': ((), ('energy/total/total',), False),
        'energy_onchip': ((), ('energy/onchip/total',), False),
        'energy_sram': ((), ('energy/sram/total',), False),
        'energy_dram': ((), ('energy/dram/total',), False),
        'energy_compute': ((), ('energy/compute/total',), False),
        'flops/energy': (('flops',), ('energy/onchip/total',), True),
        'flops/area': (('flops',), all_area, True),
        'flops/power': (('flops',), ('power/onchip/total',), True),
        'throughput/energy': (('flops per sec',), ('energy/onchip/total',), True),
        'throughput/area': (('flops per sec',), all_area, True),
        'throughput/power': (('flops per sec',), ('power/onchip/total',), True),
    }
    # breakdown metric -> (cost-summary total key, per-component key suffix)
    breakdowns = {
        'area_partition': ('area/onchip', '/area'),
        'area_breakdown': ('area/onchip', '/area'),
        'energy_partition': ('energy/onchip/total', '/energy/total'),
        'energy_breakdown': ('energy/onchip/total', '/energy/total'),
    }

    for metric in data or ():
        if metric == 'perf_norm':
            # achieved flop/s is divided by the peak computed from the inputs
            perf_read.update(('flops per sec', PEAK))
            other.add(metric)
        elif metric in direct:
            perf_keys, cost_keys, derived = direct[metric]
            perf_read.update(perf_keys)
            cost_read.update(cost_keys)
            if derived:
                other.add(metric)
        elif metric in breakdowns:
            total_key, suffix = breakdowns[metric]
            cost_read.add(total_key)
            cost_layer_read.update(part + suffix for part in components)
            # TODO: add more

    if not (perf_read or cost_read or other or cost_layer_read):
        # sys.exit instead of the interactive-only site builtin exit()
        sys.exit()
    return perf_read, cost_read, other, cost_layer_read

# Top-level keys of workloadcost.yaml that are not per-layer entries.
_NON_LAYER_KEYS = ('technology', 'frequency', 'overall')


def _lookup_path(tree, path):
    """Return ``tree[k1][k2]...`` for the '/'-separated keys ``k1/k2/...`` of ``path``."""
    node = tree
    for key in path.split('/'):
        node = node[key]
    return node


def _peak_flops_per_sec(run_dir):
    """Compute the peak flop/s of the design simulated under ``runs/<run_dir>``.

    Dispatches on substrings of ``run_dir``. Systolic runs ('ystolic') call
    ``get_sys_peak_flops_per_sec`` with the first workload layer's ``cycle``.
    Carat runs call ``get_tlut_peak_flops_per_sec`` with a chunk size of
    log2(compu num_instances[1]), plus one for 'i1' runs. Only fp8 Carat
    runs are supported. Raises AssertionError otherwise.
    """
    arch_dict = yaml_load(f'runs/{run_dir}/input/architecture.yaml')
    wl_dict = yaml_load(f'runs/{run_dir}/input/workload.yaml')['workload']
    # The cycle value is taken from the first workload layer only.
    for layer in wl_dict.values():
        cycle = layer['cycle']
        break
    if 'ystolic' in run_dir:
        return get_sys_peak_flops_per_sec(arch_dict, cycle)
    if 'carat' in run_dir:
        chunk_size = math.log2(arch_dict['architecture']['compu']['num_instances'][1])
        if 'i1' in run_dir:
            chunk_size += 1
        if 'fp8' not in run_dir:
            raise AssertionError(f'unsupported data type for carat run: {run_dir}')
        data_length = 4
        return get_tlut_peak_flops_per_sec(arch_dict, chunk_size, data_length)
    raise AssertionError(f'{run_dir}')


def _breakdown_key(run_dir, partition, area_or_energy):
    """Return the coarse breakdown key that ``partition``'s cost accumulates into.

    Returns ``None`` when the uSystolic map assigns the partition to no
    category (empty string). Raises AssertionError for an unrecognised
    architecture directory.
    """
    # 'systolic' (lowercase) does not match 'uSystolic', so the order is safe.
    if 'systolic' in run_dir:
        category = sys_breakdown_map[partition]
    elif 'uSystolic' in run_dir:
        category = usys_breakdown_map[partition]
        if category == '':
            return None
    elif 'carat' in run_dir:
        category = carat_breakdown_map[partition]
    else:
        raise AssertionError(f'unrecognized architecture in run dir: {run_dir}')
    return category + f'/{area_or_energy}'


def read_perf_cost(run_set, perf_read, cost_read, other, cost_layer_read, ref_arch_name, use_universal_baseline):
    """Load performance/cost results into ``run_set`` in place and return it.

    ``run_set`` is ``{group: {arch: {net: net_dict}}}``, where each
    ``net_dict`` carries a ``'dir'`` entry naming its run under ``runs/``.
    For every network, the requested keys (see ``gen_read_sets``) are read
    from the YAML files of that run and stored in ``net_dict``. ``<key>/norm``
    entries are the value divided by the same network's value in the
    reference architecture. The reference is the first architecture of the
    first group when ``use_universal_baseline`` is true, otherwise the first
    architecture of each group. Breakdown costs also get ``/percent`` (share
    of the on-chip total) and ``/norm_percent`` entries, and are accumulated
    into coarse categories keyed by ``hatch_map_cg``.

    ``ref_arch_name`` is currently unused.
    """
    first_arch_in_all = True
    ref_arch = None
    for group_dict in run_set.values():
        first_arch_in_group = True
        for arch in group_dict.values():
            if first_arch_in_all:
                first_arch_in_all = False
                ref_arch = arch
            # Per-group baseline unless a universal baseline was requested.
            if first_arch_in_group and not use_universal_baseline:
                first_arch_in_group = False
                ref_arch = arch

            for net_name, net_dict in arch.items():
                run_dir = net_dict['dir']
                run_root = f'runs/{run_dir}'

                if perf_read:
                    perf_dict = yaml_load(run_root + '/output/performance/workloadperf.summary.yaml')
                    for perf in perf_read:
                        if perf == PEAK:
                            # Peak is derived from the inputs and has no /norm entry.
                            net_dict[perf] = _peak_flops_per_sec(run_dir)
                            continue
                        value = _lookup_path(perf_dict['overall'], perf)
                        net_dict[perf] = value
                        net_dict[perf + '/norm'] = value / ref_arch[net_name][perf]

                if cost_read:
                    cost_dict = yaml_load(run_root + '/output/cost/workloadcost.summary.yaml')
                    for cost in cost_read:
                        value = _lookup_path(cost_dict['overall'], cost)
                        net_dict[cost] = value
                        net_dict[cost + '/norm'] = value / ref_arch[net_name][cost]

                if cost_layer_read:
                    cost_layer_dict = yaml_load(run_root + '/output/cost/workloadcost.yaml')
                    layer_dicts = [v for k, v in cost_layer_dict.items() if k not in _NON_LAYER_KEYS]
                    # Area is read from the first layer entry only; energy sums all layers.
                    first_layer = layer_dicts[0] if layer_dicts else None

                    # Zero the coarse breakdown accumulators.
                    for area_or_energy in {cost.split('/')[1] for cost in cost_layer_read}:
                        suffix = f'/{area_or_energy}'
                        if area_or_energy == 'energy':
                            suffix += '/total'
                        for breakdown in hatch_map_cg.keys():
                            net_dict[breakdown + suffix] = 0

                    for cost in cost_layer_read:
                        level = cost.split('/')
                        partition, area_or_energy = level[0], level[1]

                        # Systolic designs have no table-lookup-only components.
                        if partition in tlut_unique_component and 'ystolic' in run_dir:
                            continue

                        if area_or_energy == 'area':
                            total = cost_layer_dict['overall']['area']['onchip']
                            value = _lookup_path(first_layer, cost)
                            norm_key = 'area/onchip/norm'
                        elif area_or_energy == 'energy':
                            total = cost_layer_dict['overall']['energy']['onchip']['total']
                            value = sum(_lookup_path(layer, cost) for layer in layer_dicts)
                            norm_key = 'energy/onchip/total/norm'
                        else:
                            continue

                        net_dict[cost] = value
                        net_dict[cost + '/percent'] = value / total
                        net_dict[cost + '/norm_percent'] = net_dict[cost + '/percent'] * net_dict[norm_key]

                        cost_cg = _breakdown_key(run_dir, partition, area_or_energy)
                        if cost_cg is None:
                            # Partition belongs to no breakdown category.
                            continue
                        # NOTE(review): energy accumulators are zeroed as
                        # '<cat>/energy/total' above, but cost_cg is '<cat>/energy';
                        # confirm the intended key before relying on energy breakdowns.
                        net_dict[cost_cg] += value
                        net_dict[cost_cg + '/percent'] = net_dict[cost_cg] / total
                        net_dict[cost_cg + '/norm_percent'] = net_dict[cost_cg + '/percent'] * net_dict[norm_key]

                if other:
                    for m in other:
                        if m == 'perf_norm':
                            # utilization in percent of the computed peak
                            net_dict[m] = net_dict['flops per sec'] / net_dict[PEAK] * 100
                        elif m == 'area_total':
                            net_dict[m] = net_dict['area/onchip'] + net_dict['area/dram']
                        # NOTE: efficiency is in terms of onchip cost
                        elif m == 'throughput/area':
                            net_dict[m] = net_dict['flops per sec'] / net_dict['area/onchip']
                        elif m == 'throughput/power':
                            net_dict[m] = net_dict['flops per sec'] / net_dict['power/onchip/total']
                        elif m == 'throughput/energy':
                            net_dict[m] = net_dict['flops per sec'] / net_dict['energy/onchip/total']
                        elif m == 'flops/area':
                            net_dict[m] = net_dict['flops'] / net_dict['area/onchip']
                        elif m == 'flops/power':
                            net_dict[m] = net_dict['flops'] / net_dict['power/onchip/total']
                        elif m == 'flops/energy':
                            net_dict[m] = net_dict['flops'] / net_dict['energy/onchip/total']

                        net_dict[m + '/norm'] = net_dict[m] / ref_arch[net_name][m]
    return run_set

def _sram_bar_style(arch_name):
    """Map the SRAM-size tag in ``arch_name`` to (legend label, opacity, colour).

    Tags are tested most-specific first because 'sram1' is a substring of
    the 'sram1_to_N' tags.
    """
    if 'sram1_to_8' in arch_name:
        return '0.125x', 0.5, bcolors.gray
    if 'sram1_to_4' in arch_name:
        return '0.25x', 0.99, bcolors.orange
    if 'sram1_to_2' in arch_name:
        return '0.5x', 0.99, bcolors.green
    if 'sram2' in arch_name:
        return '2x', 0.99, bcolors.yellow
    if 'sram1' in arch_name:
        return '1x', 0.99, bcolors.blue
    raise AssertionError(f'no SRAM-size tag in arch name: {arch_name}')


def plot_format_func(metric, compond_group_list, group2_list, run_set, dir_name):
    """Draw a grouped, optionally stacked, bar chart for ``metric`` and save it as PDF.

    For each YAML key that ``metric`` maps to in ``yml_metric_map``, every
    architecture of every group in ``compond_group_list`` is reduced to the
    geometric mean of its per-network values. Each x position is one group,
    with one bar per architecture, coloured by the SRAM-size tag in the
    architecture name. Bars for successive YAML keys are stacked.

    Args:
        metric: plot metric; must be a key of ``yml_metric_map`` and ``axis_map``.
        compond_group_list: keys of ``run_set`` to place on the x axis.
        group2_list: unused; kept for interface compatibility.
        run_set: ``{group: {arch: {net: {yaml_key: value}}}}`` as filled by
            ``read_perf_cost``.
        dir_name: sub-directory of ``plot/mlperf`` to write into, or ``None``.

    The figure is written to ``plot/mlperf[/<dir_name>]/<metric>.pdf`` with
    '/' and ' ' in the metric replaced by '_'.
    """
    if metric not in yml_metric_map:
        raise AssertionError(f'unknown metric: {metric}')
    metric_in_yml_list = yml_metric_map[metric]
    if not isinstance(metric_in_yml_list, list):
        metric_in_yml_list = [metric_in_yml_list]
    print(metric_in_yml_list)

    show_legend = metric in legend_map.get(dir_name, default_map)
    size_tuple_local = size_tuple if show_legend else (fig_w, fig_h_short)
    fig, ax = plt.subplots(figsize=size_tuple_local, dpi=my_dpi)
    use_str = np.arange(len(compond_group_list))

    # Running bar tops per architecture index, so bars of later YAML keys
    # stack on top of earlier ones.
    bottoms = {}
    for metric_in_yml in metric_in_yml_list:
        print(metric_in_yml)
        # plot_arr[g][a]: geometric mean over networks for arch a of group g
        plot_arr = []
        name_arr = []
        for group_name in compond_group_list:
            group_dict = run_set.get(group_name)
            plot_arr.append([
                gmean([net_dict.get(metric_in_yml) for net_dict in arch_dict.values()])
                for arch_dict in group_dict.values()
            ])
            name_arr.append(list(group_dict))
        n_arch = len(plot_arr[0])
        width = 1.0 / (n_arch + 1)

        # Only the bars of the last YAML key end up in the legend.
        color_bar_arr = []
        lab_arch_arr = []
        for i in range(n_arch):
            arr_to_plot = [row[i] for row in plot_arr]
            arr_name = [row[i] for row in name_arr]
            print(arr_to_plot)
            print(arr_name)
            lab_arch, op, color_ = _sram_bar_style(arr_name[0])
            prev = bottoms.get(i)
            x = use_str + width * i
            color_bar_ = ax.bar(x, np.array(arr_to_plot), width, color=color_, alpha=op, bottom=prev)
            # A second, unfilled bar draws the black outline.
            ax.bar(x, np.array(arr_to_plot), width, color='none', edgecolor='k', linewidth=0.3, bottom=prev)
            if prev is None:
                bottoms[i] = arr_to_plot
            else:
                bottoms[i] = list(np.array(prev) + np.array(arr_to_plot))
            color_bar_arr.append(color_bar_[0])
            lab_arch_arr.append(lab_arch)

    size_vec = (0, 1.02, 1, 0.1)
    if show_legend:
        ax.legend(
            color_bar_arr, lab_arch_arr,
            bbox_to_anchor=size_vec, loc="lower left",
            mode="expand", borderaxespad=0, ncol=5)
    else:
        print(f'bw line501: no legend for {dir_name} {metric}')

    display_group_list = []
    for group_name in compond_group_list:
        label = group_name.replace('carat', 'Carat').replace('systolic', 'SA')
        for tag in ('_edge', '_128x16', '_256x16', '_int8', '_bf16', '_fp8'):
            label = label.replace(tag, '')
        display_group_list.append(label)
    # Centre each tick under its group of n_arch bars.
    ax.set_xticks(use_str + width * (n_arch - 1) / 2)
    ax.set_xticklabels(display_group_list, rotation=0)

    ylab = axis_map[metric]
    if '(flop/s/nJ)' in ylab:
        ylab = 'Norm. energy\nefficiency\n(flop/s/nJ)'
    ax.set_ylabel(ylab)
    os.makedirs('plot/mlperf', exist_ok=True)
    metric_name = metric.replace('/', '_').replace(' ', '_')

    if dir_name is None:
        fig_name = f'plot/mlperf/{metric_name}'
    else:
        os.makedirs(f'plot/mlperf/{dir_name}', exist_ok=True)
        fig_name = f'plot/mlperf/{dir_name}/{metric_name}'
    fig_name += '.pdf'

    fig.tight_layout()
    fig.savefig(fig_name, bbox_inches='tight', dpi=my_dpi, pad_inches=0.02)
    print(bcolors.OKGREEN + f'Saved fig as {fig_name}' + bcolors.ENDC)

def _plot_sram_sweep_lines(ax, values, group_idx):
    """Draw one architecture group's values on ``ax`` as three lines.

    ``values`` (length divisible by 3) is reshaped row-major into 3 rows.
    Row ``k`` gets the ``k``-th array-size label suffix and ``ls_arr[k]``,
    plus ``marker_arr[k]`` except for uSA. ``group_idx`` selects the series
    label and colour.
    """
    lab_arr = ['bSA', 'uSA', 'Carat', 'Carat-idle8']
    color_arr = [bcolors.gray, bcolors.blue, bcolors.orange, bcolors.cactus]
    rows = np.array(values)
    rows = rows.reshape(3, int(len(rows) / 3)).tolist()
    print(rows)
    for slice_no, row in enumerate(rows):
        lab = lab_arr[group_idx] + ['-51.2', '-128', '-256'][slice_no]
        print(lab)
        mkr = '' if 'uSA' in lab else marker_arr[slice_no]
        ax.plot(
            range(len(row)), row,
            color=color_arr[group_idx],
            alpha=1 - 0.15 * slice_no,
            marker=mkr,
            markersize=3,
            linewidth=1.5,
            linestyle=ls_arr[slice_no],
            label=lab,
        )


def plot_format_func_line(metric, compond_group_list, group2_list, run_set, dir_name, overlay):
    """Plot ``metric`` as lines over the SRAM-size sweep and save it as PDF.

    Without ``overlay``, each architecture is reduced to the geometric mean
    over its networks. With ``overlay``, one set of lines is drawn per
    network index instead. Each group in ``compond_group_list`` is drawn as
    three lines by ``_plot_sram_sweep_lines``. The x axis is labelled
    0.125x to 2x.

    Args:
        metric: plot metric; must be a key of ``yml_metric_map`` and ``axis_map``.
        compond_group_list: keys of ``run_set``, in series-label order.
        group2_list: unused; kept for interface compatibility.
        run_set: ``{group: {arch: {net: {yaml_key: value}}}}``.
        dir_name: sub-directory of ``plot/mlperf`` to write into, or ``None``.
        overlay: plot every network instead of the geometric mean.
    """
    if metric not in yml_metric_map:
        raise AssertionError(f'unknown metric: {metric}')
    metric_in_yml_list = yml_metric_map[metric]
    if not isinstance(metric_in_yml_list, list):
        metric_in_yml_list = [metric_in_yml_list]
    print(metric_in_yml_list)

    show_legend = metric in legend_map.get(dir_name, default_map)
    size_tuple_local = (3.3115, 1.3) if show_legend else (3.3115, 0.95)
    if size_map.get(dir_name, {}).get(metric):
        size_tuple_local = size_map[dir_name][metric]

    fig, ax = plt.subplots(figsize=size_tuple_local, dpi=my_dpi)

    for metric_in_yml in metric_in_yml_list:
        print(metric_in_yml)
        # plot_arr[g][a]: per-network list (overlay) or geometric mean for arch a of group g
        plot_arr = []
        for group_name in compond_group_list:
            print(group_name)
            group_dict = run_set.get(group_name)
            arch_arr = []
            for arch_dict in group_dict.values():
                net_arr = [net_dict.get(metric_in_yml) for net_dict in arch_dict.values()]
                arch_arr.append(net_arr if overlay else gmean(net_arr))
            plot_arr.append(arch_arr)

        if not overlay:
            logger.debug(f'plot_arr={plot_arr}')
            assert len(plot_arr[0]) == 3, f'expected 3 architectures per group, got {len(plot_arr[0])}'
            for num, values in enumerate(plot_arr):
                _plot_sram_sweep_lines(ax, values, num)
        else:
            print(plot_arr)
            for i in range(len(plot_arr[0][0])):
                # regen[g][a]: value of network i for arch a of group g
                regen = [[arch_vals[i] for arch_vals in group] for group in plot_arr]
                print(regen)
                assert len(regen[0]) == 15, f'expected 15 architectures per group, got {len(regen[0])}'
                for num, values in enumerate(regen):
                    _plot_sram_sweep_lines(ax, values, num)

    size_vec = (0, 1.02, 1, 0.1)
    if show_legend:
        ax.legend(
            bbox_to_anchor=size_vec, loc="lower left",
            mode="expand", borderaxespad=0, ncol=3)
    else:
        print(f'bw line601: no legend for {dir_name} {metric}')

    display_group_list = ['0.125x', '0.25x', '0.5x', '1x', '2x']
    ax.set_xticks(range(len(display_group_list)))
    ax.set_xticklabels(display_group_list, rotation=0)

    # Explicit y ticks only when y_ticks_map has an entry for this dir and metric.
    y_ticks_dict = y_ticks_map.get(dir_name, default_map)
    yt = y_ticks_dict.get(metric, default_map)
    if y_ticks_dict != default_map and yt != default_map:
        ax.set_yticks(yt)
        ax.set_yticklabels(yt)

    ylab = axis_map[metric]
    if '(flop/s/nJ)' in ylab:
        ylab = 'Norm. energy\nefficiency\n(flop/s/nJ)'
    ax.set_ylabel(ylab)
    os.makedirs('plot/mlperf', exist_ok=True)
    metric_name = metric.replace('/', '_').replace(' ', '_')

    if dir_name is None:
        fig_name = f'plot/mlperf/{metric_name}'
    else:
        os.makedirs(f'plot/mlperf/{dir_name}', exist_ok=True)
        fig_name = f'plot/mlperf/{dir_name}/{metric_name}'
    fig_name += '.pdf'

    fig.tight_layout()
    fig.savefig(fig_name, bbox_inches='tight', dpi=my_dpi, pad_inches=0.02)
    print(bcolors.OKGREEN + f'Saved fig as {fig_name}' + bcolors.ENDC)

def plot_format_func_line_constant_sram_size(metric, compond_group_list, group2_list, run_set, dir_name, overlay):
    """Plot ``metric`` as line series against SRAM size and save the figure as a PDF.

    Args:
        metric: display metric name; must be a key of ``yml_metric_map``. Its value
            there is one yml metric key or a list of keys; every key is plotted.
        compond_group_list: keys of ``run_set`` to plot, in order. The i-th group
            uses the i-th entry of the label/color lists below.
        group2_list: not used by this function.
        run_set: nested mapping read as ``group -> arch -> network -> {yml key: value}``.
        dir_name: optional sub-directory of ``plot/mlperf`` for the output file. It is
            also the key used to look up legend, size and ytick settings.
        overlay: if False, each group is one line whose points are the per-arch
            geometric means over networks (exactly 3 archs are expected). If True,
            per-network values are kept. For every network, each group's 15 arch
            values are reshaped into 3 series labelled -51.2/-128/-256.

    Raises:
        KeyError: ``metric`` is not in ``yml_metric_map``, or a group is missing
            from ``run_set``.
        ValueError: the collected data does not have the expected number of archs.
    """
    if metric not in yml_metric_map:
        raise KeyError(f'unknown metric {metric!r}: not found in yml_metric_map')
    metric_in_yml_list = yml_metric_map[metric]
    # yml_metric_map values may be a single key or a list of keys; normalise to a list.
    if not isinstance(metric_in_yml_list, list):
        metric_in_yml_list = [metric_in_yml_list]
    print(metric_in_yml_list)

    # Use a taller canvas when a legend will be drawn above the axes.
    if metric in legend_map.get(dir_name, default_map):
        size_tuple_local = (3.3115, 1.3)
    else:
        size_tuple_local = (3.3115, 0.95)
    if size_map.get(dir_name, {}).get(metric):
        size_tuple_local = size_map[dir_name][metric]

    fig, ax = plt.subplots(figsize=size_tuple_local, dpi=my_dpi)

    # Label / color per group position in compond_group_list.
    lab_arr = ['bSA', 'uSA', 'Carat', 'Carat-idle8']
    color_arr = [bcolors.gray, bcolors.blue, bcolors.orange, bcolors.cactus]
    sram_suffixes = ['-51.2', '-128', '-256']

    for metric_in_yml in metric_in_yml_list:
        print(metric_in_yml)
        # plot_arr[g][a] is either the gmean over networks (no overlay) or the
        # per-network value list (overlay) for arch a of group g.
        plot_arr = []
        for group_name in compond_group_list:
            print(group_name)
            arch_arr = []
            for arch_dict in run_set[group_name].values():
                net_arr = [net_dict.get(metric_in_yml) for net_dict in arch_dict.values()]
                arch_arr.append(net_arr if overlay else gmean(net_arr))
            plot_arr.append(arch_arr)

        if not overlay:
            logger.info("no overlay")
            logger.debug(f'plot_arr={plot_arr}')
            # One arch per SRAM size on the x axis (51.2 / 128 / 256).
            n_points = len(plot_arr[0]) if plot_arr else 0
            if n_points != 3:
                raise ValueError(f'expected 3 archs per group (one per SRAM size), got {n_points}')
            for num, group_vals in enumerate(plot_arr):
                arr = np.array(group_vals)
                print(arr)
                ax.plot(
                    range(len(arr)), arr,
                    color=color_arr[num],
                    marker=marker_arr[0],
                    markersize=3,
                    linewidth=1.5,
                    linestyle=ls_arr[0],
                    label=lab_arr[num],
                )
        else:
            print(plot_arr)
            for i in range(len(plot_arr[0][0])):
                # regen[g] = the i-th network's value for every arch of group g.
                regen = [[arch_vals[i] for arch_vals in group_vals] for group_vals in plot_arr]
                print(regen)
                if len(regen[0]) != 15:
                    raise ValueError(f'expected 15 archs per group in overlay mode, got {len(regen[0])}')
                for num, series in enumerate(regen):
                    # Split the arch values into 3 equal slices, one per SRAM size.
                    arr = np.array(series).reshape(3, len(series) // 3).tolist()
                    print(arr)
                    for slice_no, arr_ in enumerate(arr):
                        lab = lab_arr[num] + sram_suffixes[slice_no]
                        print(lab)
                        mkr = '' if 'uSA' in lab else marker_arr[slice_no]
                        ax.plot(
                            range(len(arr_)), arr_,
                            color=color_arr[num],
                            alpha=1 - 0.15 * slice_no,
                            marker=mkr,
                            markersize=3,
                            linewidth=1.5,
                            linestyle=ls_arr[slice_no],
                            label=lab,
                        )

    if metric in legend_map.get(dir_name, default_map):
        ax.legend(
            bbox_to_anchor=(0, 1.02, 1, 0.1), loc="lower left",
            mode="expand", borderaxespad=0, ncol=3)
    else:
        print(f'bw line601: no legend for {dir_name} {metric}')

    # NOTE(review): there are 3 tick labels here, but overlay mode plots
    # len(series) // 3 points per line (5 when 15 archs are given) — confirm intent.
    display_group_list = ['51.2', '128', '256']
    ax.set_xticks(range(len(display_group_list)))
    ax.set_xticklabels(display_group_list, rotation=0)

    # Explicit yticks only when y_ticks_map has an entry for this dir/metric.
    y_ticks_dict = y_ticks_map.get(dir_name, default_map)
    yt = y_ticks_dict.get(metric, default_map)
    if y_ticks_dict != default_map and yt != default_map:
        ax.set_yticks(yt)
        ax.set_yticklabels(yt)

    ylab = axis_map[metric]
    if '(flop/s/nJ)' in ylab:
        ylab = 'Norm. energy\nefficiency\n(flop/s/nJ)'
    ax.set_ylabel(ylab)

    metric_name = metric.replace('/', '_').replace(' ', '_')
    out_dir = 'plot/mlperf' if dir_name is None else f'plot/mlperf/{dir_name}'
    os.makedirs(out_dir, exist_ok=True)
    fig_name = f'{out_dir}/{metric_name}.pdf'

    fig.tight_layout()
    fig.savefig(fig_name, bbox_inches='tight', dpi=my_dpi, pad_inches=0.02)
    # Close the figure: this function runs once per metric, and matplotlib keeps
    # every open figure alive until it is closed explicitly.
    plt.close(fig)
    print(bcolors.OKGREEN + f'Saved fig as {fig_name}' + bcolors.ENDC)


def read_run_set(run_set, metric_str, arch_q_name, nets=None):
    """Return the geometric mean of ``metric_str`` over networks for arch ``arch_q_name``.

    ``run_set`` is walked in order as ``group -> arch -> network -> {metric: value}``.
    The result comes from the first arch whose name equals ``arch_q_name``.

    Args:
        run_set: nested run mapping as described above.
        metric_str: metric key to collect from each network's dict.
        arch_q_name: arch name to query.
        nets: networks to include. Defaults to the module-level ``args.mlperf``
            (set in the ``__main__`` section), which was the original behaviour.

    Returns:
        ``scipy.stats.gmean`` of the collected values. Networks that are not in
        ``nets``, or that lack ``metric_str``, are skipped.

    Exits the process with status 1 if no arch matches ``arch_q_name``.
    """
    if nets is None:
        nets = args.mlperf
    for level1_dict in run_set.values():
        for arch_name, level2_dict in level1_dict.items():
            # Only the queried arch is aggregated; stop at the first match.
            if arch_name != arch_q_name:
                continue
            value_arr = [
                level3_dict[metric_str]
                for net_name, level3_dict in level2_dict.items()
                if net_name in nets and metric_str in level3_dict
            ]
            return gmean(value_arr)
    print(bcolors.FAIL + f'no such query name: {arch_q_name}' + bcolors.ENDC)
    # sys.exit is always available (the bare exit() builtin depends on the site
    # module), and a non-zero status signals the failure to the caller.
    sys.exit(1)

def rescale(run_set):
    """Rewrite ``'/norm'`` metrics in ``run_set`` relative to one reference value.

    ``run_set`` is walked as ``batch -> arch -> network -> {metric: value}``. In each
    network dict, only the first metric whose name contains ``'/norm'`` is rewritten,
    to ``raw / ref``:

    - ``raw`` is the same dict's value under the name with ``'/norm'`` removed.
    - ``ref`` is that raw value taken from the first visited dict that belongs to
      the first-seen network.

    The mapping is modified in place and also returned. Each rewrite is printed.
    """
    # Position of every network by order of first appearance; position 0 is the reference.
    net_position = {}
    for arch_map in run_set.values():
        for net_map in arch_map.values():
            for net_key in net_map:
                net_position.setdefault(net_key, len(net_position))

    have_ref = False
    for batch_key, arch_map in run_set.items():
        for arch_key, net_map in arch_map.items():
            for net_key, metrics in net_map.items():
                norm_key = next((name for name in metrics if '/norm' in name), None)
                if norm_key is None:
                    continue
                old_val = metrics[norm_key]
                pos = net_position[net_key]
                raw_val = metrics[norm_key.replace('/norm', '')]
                if pos == 0 and not have_ref:
                    ref_val = raw_val
                    have_ref = True
                new_val = raw_val / ref_val
                run_set[batch_key][arch_key][net_key][norm_key] = new_val
                print(f'ind={pos}, metric={norm_key}, val={old_val}, ref={ref_val}, new={new_val}, {run_set[batch_key][arch_key][net_key][norm_key]}')
    return run_set

if __name__ == "__main__":
    # Script entry: build the run set from the given points, read results, then plot.
    # args stays a module-level global because read_run_set falls back to args.mlperf.
    parser = construct_argparser()
    args = parser.parse_args()

    print('*** mlperf: ', args.mlperf)
    print('*** dir: ', args.dirname)
    # The assembled run_set is dumped to a timestamped yml under plot/log[/<dirname>].
    os.makedirs('plot/log', exist_ok=True)
    timestr = time.strftime("%Y%m%d-%H%M%S")
    if args.dirname is not None:
        os.makedirs(f'plot/log/{args.dirname}', exist_ok=True)
        yml_name = f'./plot/log/{args.dirname}/plot_{timestr}.yml'
    else:
        yml_name = f'./plot/log/plot_{timestr}.yml'

    # === peak perf lambda functions ===
    get_sys_peak_flops_per_sec = lambda arch_dict, cycle:(arch_dict['architecture']['compu']['num_instances'][0] * \
        arch_dict['architecture']['compu']['num_instances'][1] * 2.0) / cycle * \
        arch_dict['architecture']['frequency'] * 10**6
    # get_sys_peak_flops_per_sec = lambda arch_dict, cycle:(arch_dict['architecture']['compu']['num_instances'][0] * \
    #     arch_dict['architecture']['compu']['num_instances'][1] * 2.0 + \
    #     arch_dict['architecture']['compu']['num_instances'][1]) / cycle * \
    #     arch_dict['architecture']['frequency'] * 10**6
    get_tlut_peak_flops_per_sec = lambda arch_dict, chunk_size, data_length: 2 * \
        arch_dict['architecture']['compu']['num_instances'][0] * \
        (float(chunk_size)/data_length) * arch_dict['architecture']['frequency'] * 10**6
    # === peak perf lambda functions ===

    print('*** points: ', args.point)

    if args.use_universal_baseline:
        print('*** using universal baseline')
    else:
        print('*** normalizing in terms of group (default)')

    # === construct run set ===
    print(bcolors.OKCYAN + f'Constructing run set...' + bcolors.ENDC)

    group_list = ['systolic', 'uSystolic', 'carat']
    # Add at most one carat variant group, taken from the first point that names
    # one (within a single point, i1 is checked before idle8).
    for p in args.point:
        if "i1" in p and "carat" in p:
            group_list.append('carat_i1')
            break
        if "idle8" in p and "carat" in p:
            group_list.append('carat_idle8')
            break

    # group2_list = ['sram1','sram2']
    group2_list = ['fp8']
    print(f'*** group by {group_list}')
    compond_group_list = []

    run_set = OrderedDict()

    # index 1: format
    for group2 in group2_list:
        for group_ in group_list:
            # The number format is fixed per group (uSystolic -> int8, others -> fp8),
            # which overrides the group2_list loop value.
            group2 = 'int8' if 'uSys' in group_ else 'fp8'
            group_dict = OrderedDict()
            # index2: arch
            for point in args.point:
                if group_ not in point or group2 not in point:
                    continue
                # Plain "carat" must not also pick up the carat_i1 / carat_idle8 points.
                if group_ == "carat" and ("i1" in point or "idle8" in point):
                    continue
                point_parts = point.split('/')
                arch_name = point_parts[-1]
                arch_name_ = arch_name + '_' + point_parts[1]
                # Strip only the trailing arch component; str.replace would also
                # delete any earlier occurrence of arch_name inside the path.
                dir_append = point[:len(point) - len(arch_name)]
                myrootdir = f'runs/{dir_append}/'
                print(myrootdir)
                # The sub-directory listing is the same for every network: read it once.
                subdirs = next(os.walk(myrootdir))[1] if args.mlperf else []
                group = OrderedDict()
                # index3: network
                for net in args.mlperf:
                    reg_name = re.compile(
                        re.escape(arch_name) + '_' + re.escape(net) + r'_c\d{1,3}_n' + str(args.nval))
                    for dirnames in subdirs:
                        if reg_name.match(dirnames):
                            dict_ = OrderedDict()
                            dict_['dir'] = f'{dir_append}/{dirnames}'
                            group[net] = dict_
                group_dict[arch_name_] = group
            compond_name = group_ + '_' + group2
            print(compond_name)
            compond_group_list.append(compond_name)
            run_set[compond_name] = group_dict

    # === register needed results ===
    perf_read, cost_read, other, cost_layer_read = gen_read_sets(args.metric)

    # === fill in perf results ===
    print(bcolors.OKCYAN + f'Reading yml...' + bcolors.ENDC)
    ref_arch_name = args.point[0].split('/')[-1]
    run_set = read_perf_cost(run_set, perf_read, cost_read, other, cost_layer_read, ref_arch_name, args.use_universal_baseline)

    if args.overlay:
        # rescale metric/norm
        print(bcolors.WARNING + f'rescale metric/norm' + bcolors.ENDC)
        run_set = rescale(run_set)

    yaml_overwrite(yml_name, run_set) # TODO: remove for speed and space

    # output number for table filling
    if args.query and args.qm:
        for q in args.query:
            for m in args.qm:
                res = read_run_set(run_set, m, q)
                print(bcolors.HEADER + f'{m}-{q}: {res}' + bcolors.ENDC)

    print(bcolors.OKCYAN + f'ploting...' + bcolors.ENDC)
    for metric in args.metric:
        if args.line:
            plot_format_func_line_constant_sram_size(metric, compond_group_list, group2_list, run_set, args.dirname, args.overlay)
            # Alternative x axis (relative scaling instead of SRAM size):
            # plot_format_func_line(metric, compond_group_list, group2_list, run_set, args.dirname,args.overlay)
        else:
            plot_format_func(metric, compond_group_list, group2_list, run_set, args.dirname)