from collections import OrderedDict
import sys, os
sys.path.append(os.getcwd())
from utils import yaml_load, bcolors,yaml_overwrite
import argparse
import numpy as np
import matplotlib.pyplot as plt
import matplotlib
from matplotlib.pyplot import cm
import os
import re
from scipy.stats import gmean
import math
import time
from plot_util import tlut_unique_component, hatch_map, \
    sys_breakdown_map, usys_breakdown_map, carat_breakdown_map, hatch_map_cg, \
    color_map, opac_map, \
    yml_metric_map, reduce_map, single_multiple_map, \
    partition_layerwise_reduction_map, axis_map, \
    mlperf_name_map, legend_map, clip_map, default_map, \
    ignore_vec_map, default_ignore_vec, default_ignore_vec_area,default_for_opac, \
    my_dpi, fig_h, fig_h_short, fig_w, size_tuple
# Key under which read_perf_cost stores a network's peak throughput; it is also
# the marker gen_read_sets puts into the performance read set for 'perf_norm'.
PEAK = 'peak flops per sec'

# matplotlib settings
# Global rc overrides applied at import time: font entries first, then thin
# strokes for hatches, lines, axes and ticks.
font = {'serif':'Helvetica Neue', 'size': 6}
matplotlib.rc('font', **font)
matplotlib.rcParams['hatch.linewidth'] = 0.3
matplotlib.rcParams['lines.linewidth'] = 0.3
matplotlib.rcParams['axes.linewidth'] = 0.5
# Major tick marks: stroke width, then tick length.
matplotlib.rcParams['xtick.major.width'] = 0.5
matplotlib.rcParams['ytick.major.width'] = 0.5
matplotlib.rcParams['xtick.major.size'] = 2
matplotlib.rcParams['ytick.major.size'] = 2

def construct_argparser():
    """Build the command-line parser of this script.

    Options:
        -p/--point: one or more values (help text: 'config').
        --use_universal_baseline / --no-use_universal_baseline: boolean switch
            stored in ``use_universal_baseline``; defaults to False.
        -d/--dirname: optional name; defaults to None.
        -n/--mlperf: one or more network names; defaults to six networks.
        -m/--metric: one or more metric names, restricted to a fixed list.

    Returns:
        argparse.ArgumentParser: the configured parser.
    """
    parser = argparse.ArgumentParser(description='plot')
    parser.add_argument('-p',
                        '--point',
                        nargs='+',
                        help='config',
                        )

    parser.add_argument('--use_universal_baseline', dest='use_universal_baseline', action='store_true')
    parser.add_argument('--no-use_universal_baseline', dest='use_universal_baseline', action='store_false')
    parser.set_defaults(use_universal_baseline=False)

    parser.add_argument('-d',
                        '--dirname',
                        help='dir name',
                        default=None)

    # Every default name needs its own trailing comma: adjacent string
    # literals are concatenated, which would merge two networks into one.
    parser.add_argument('-n',
                        '--mlperf',
                        nargs='+',
                        help='networks in mlperf\n',
                        default=[
                            'resnet50',
                            'ssd300_vgg16',
                            'UNet',
                            'bert_base_uncased',
                            'dlrm',
                            'RNNT',
                        ]
                        )

    parser.add_argument('-m',
                        '--metric',
                        nargs='+',
                        help='perf to plot',
                        choices=['perf', 'perf_abs', 'perf_norm', 'util_impl',
                                 'rt_abs', 'rt_norm',
                                 'area_total', 'area_onchip', 'area_onchip_abs', 'area_dram', 'area_sram', 'area_compute',
                                 'area_partition', 'area_breakdown', 'energy_partition', 'energy_breakdown',
                                 'power_total', 'power_onchip', 'power_onchip_abs', 'power_dram', 'power_sram', 'power_compute',
                                 'energy_total', 'energy_total_abs', 'energy_onchip', 'energy_onchip_abs', 'energy_dram', 'energy_sram', 'energy_compute',
                                 'throughput/area', 'throughput/energy', 'throughput/energy_abs', 'throughput/power',
                                 'throughput/power_abs',
                                 'flops/area', 'flops/energy', 'flops/power'
                                 ]
                        )
    return parser

def gen_read_sets(data):
    """Translate requested metric names into the sets of values to read.

    Args:
        data: iterable of metric names (the ``-m/--metric`` values).

    Returns:
        A 4-tuple of sets ``(perf_read, cost_read, other, cost_layer_read)``:
        keys to read for performance, keys to read for cost, names of metrics
        registered in ``other``, and per-component keys to read for
        partition/breakdown metrics.

    Unrecognised names contribute nothing. If all four sets end up empty the
    process exits via ``exit()``.
    """
    print('*** metric: ', data)

    flops_per_sec = ('flops per sec',)
    runtime = ('runtime/impl',)
    all_area = ('area/onchip', 'area/sram', 'area/dram', 'area/compute')
    onchip_energy = ('energy/onchip/total',)
    onchip_power = ('power/onchip/total',)

    # metric -> (performance keys, cost keys, also registered in `other`?)
    requirements = {
        'perf': (flops_per_sec, (), False),
        'perf_abs': (flops_per_sec, (), False),
        'perf_norm': (flops_per_sec, (), True),  # PEAK is added separately below
        'util_impl': (('utilization/impl',), (), False),
        'rt_abs': (runtime, (), False),
        'rt_norm': (runtime, (), False),

        'power_total': ((), ('power/total/total',), True),
        'power_onchip': ((), onchip_power, True),
        'power_sram': ((), ('power/sram/total',), False),
        'power_dram': ((), ('power/dram/total',), False),
        'power_compute': ((), ('power/compute/total',), False),

        'area_total': ((), ('area/onchip', 'area/dram'), True),
        'area_onchip': ((), ('area/onchip',), False),
        'area_onchip_abs': ((), ('area/onchip',), False),
        'area_sram': ((), ('area/sram',), False),
        'area_dram': ((), ('area/dram',), False),
        'area_compute': ((), ('area/compute',), False),

        'energy_total': ((), ('energy/total/total',), False),
        'energy_total_abs': ((), ('energy/total/total',), False),
        'energy_onchip': ((), onchip_energy, True),
        'energy_sram': ((), ('energy/sram/total',), False),
        'energy_dram': ((), ('energy/dram/total',), False),
        'energy_compute': ((), ('energy/compute/total',), False),

        'flops/energy': (('flops',), onchip_energy, True),
        'flops/area': (('flops',), all_area, True),
        'flops/power': (('flops',), ('power/onchip/total', 'power/sram/total'), True),

        'throughput/energy': (flops_per_sec, onchip_energy, True),
        'throughput/energy_abs': (flops_per_sec, onchip_energy, True),
        'throughput/area': (flops_per_sec, all_area, True),
        'throughput/power': (flops_per_sec, onchip_power, True),
        'throughput/power_abs': (flops_per_sec, onchip_power, True),

        'area_partition': ((), ('area/onchip',), False),
        'area_breakdown': ((), ('area/onchip',), False),
        'energy_partition': ((), onchip_energy, False),
        'energy_breakdown': ((), onchip_energy, False),
    }

    # Partition/breakdown metrics read one key per hardware component.
    components = ('compu', 'ififo', 'wfifo', 'ofifo', 'oaccu', 'isram',
                  'wsram', 'osram', 'waccu', 'itemp', 'osmux', 'bpipe')
    layer_suffix = {
        'area_partition': '/area',
        'area_breakdown': '/area',
        'energy_partition': '/energy/total',
        'energy_breakdown': '/energy/total',
    }
    # TODO: add more

    perf_read, cost_read, other, cost_layer_read = set(), set(), set(), set()
    nothing = ((), (), False)

    for metric in data:
        perf_keys, cost_keys, in_other = requirements.get(metric, nothing)
        perf_read.update(perf_keys)
        cost_read.update(cost_keys)
        if in_other:
            other.add(metric)
        if metric == 'perf_norm':
            perf_read.add(PEAK)

        suffix = layer_suffix.get(metric)
        if suffix is not None:
            cost_layer_read.update(component + suffix for component in components)

    if not (perf_read or cost_read or other or cost_layer_read):
        exit()
    return perf_read, cost_read, other, cost_layer_read

def read_perf_cost(run_set, perf_read, cost_read, other, cost_layer_read, ref_arch_name, use_universal_baseline):
    """Fill every network entry of ``run_set`` with the requested results.

    ``run_set`` is walked as group -> architecture -> network. Each network
    dict must already hold a ``'dir'`` entry and is updated in place from the
    YAML files below ``runs/<dir>/``:

    * ``perf_read`` keys are looked up in
      ``output/performance/workloadperf.summary.yaml``; the ``PEAK`` key is
      instead computed from ``input/architecture.yaml`` and
      ``input/workload.yaml``.
    * ``cost_read`` keys are looked up in
      ``output/cost/workloadcost.summary.yaml``.
    * ``cost_layer_read`` keys are looked up in
      ``output/cost/workloadcost.yaml`` and also accumulated into the
      coarse-grained groups named by ``hatch_map_cg``.
    * ``other`` names metrics that are derived from the values above.

    Apart from ``PEAK`` and the per-component values, each value also gets a
    ``'<key>/norm'`` companion: the value divided by the same key of the
    reference architecture for the same network.

    Args:
        run_set: nested mapping group -> architecture name -> network name ->
            network dict.
        perf_read: performance keys to read.
        cost_read: cost keys to read.
        other: derived metric names to compute.
        cost_layer_read: per-component cost keys to read.
        ref_arch_name: unused; kept for interface compatibility.
        use_universal_baseline: when false, the first architecture of each
            group is the reference; when true, the first architecture of the
            first group is the reference for every group.

    Returns:
        The same ``run_set`` object, updated in place.

    Raises:
        AssertionError: if a run directory name matches none of the known
            architecture families.
        KeyError: if a requested key is missing from a YAML file or from the
            reference architecture, or a derived metric has no rule here.
    """
    # Entries of workloadcost.yaml that are not per-layer records.
    non_layer_keys = ('technology', 'frequency', 'overall')
    # For architectures whose name contains 'uSys', the SRAM share is removed
    # from these summary costs as they are read:
    # cost key -> path (below 'overall') of the SRAM value to subtract.
    usys_sram_path = {
        'area/onchip': ('area', 'sram'),
        'energy/onchip/total': ('energy', 'sram', 'total'),
        'power/onchip/total': ('power', 'sram', 'total'),
        'energy/total/total': ('energy', 'sram', 'total'),
    }

    first_arch_in_all = True
    for group_dict in run_set.values():
        first_arch_in_group = True
        for arch_name, arch in group_dict.items():
            is_usys = 'uSys' in arch_name
            if first_arch_in_all:
                first_arch_in_all = False
                ref_arch = arch
            # optional update if not universal baseline
            if first_arch_in_group and not use_universal_baseline:
                first_arch_in_group = False
                ref_arch = arch

            for net_name, net_dict in arch.items():
                run_dir = net_dict['dir']

                if len(perf_read) != 0:
                    # process performance
                    perf_yml_file = f'runs/{run_dir}'+'/output/performance/workloadperf.summary.yaml'
                    perf_dict = yaml_load(perf_yml_file)
                    for perf in perf_read:
                        if perf == PEAK:
                            # NOTE: get_sys_peak_flops_per_sec and
                            # get_tlut_peak_flops_per_sec are assigned under
                            # this file's ``__main__`` guard, so this branch
                            # only works when the file is run as a script.
                            arch_yml_file = f'runs/{run_dir}'+'/input/architecture.yaml'
                            arch_dict = yaml_load(arch_yml_file)
                            wl_yml_file = f'runs/{run_dir}'+'/input/workload.yaml'
                            wl_dict = yaml_load(wl_yml_file)['workload']
                            # Only the first layer's 'cycle' is used.
                            for layer in wl_dict.values():
                                cycle = layer['cycle']
                                break
                            # 'ystolic' matches both 'systolic' and 'uSystolic'.
                            if 'ystolic' in run_dir:
                                peak = get_sys_peak_flops_per_sec(arch_dict, cycle)
                            elif 'carat' in run_dir:
                                chunk_size = math.log2(arch_dict['architecture']['compu']['num_instances'][1])
                                if 'i1' in run_dir:
                                    chunk_size += 1
                                if 'fp8' not in run_dir:
                                    raise AssertionError(f'unsupported data format in run dir: {run_dir}')
                                data_length = 4
                                peak = get_tlut_peak_flops_per_sec(arch_dict, chunk_size, data_length)
                            else:
                                raise AssertionError(f'unrecognised architecture in run dir: {run_dir}')
                            # No '/norm' companion is stored for the peak.
                            net_dict[perf] = peak
                            continue

                        thing = perf_dict['overall']
                        for key in perf.split('/'):
                            thing = thing[key]
                        net_dict[perf] = thing
                        net_dict[perf + '/norm'] = thing / ref_arch[net_name][perf]

                if len(cost_read) != 0:
                    # process cost
                    cost_yml_file = f'runs/{run_dir}'+'/output/cost/workloadcost.summary.yaml'
                    cost_dict = yaml_load(cost_yml_file)
                    for cost in cost_read:
                        thing = cost_dict['overall']
                        for key in cost.split('/'):
                            thing = thing[key]
                        if is_usys and cost in usys_sram_path:
                            print(bcolors.FAIL+f'usys {cost}'+bcolors.ENDC)
                            sram = cost_dict['overall']
                            for key in usys_sram_path[cost]:
                                sram = sram[key]
                            thing -= sram
                        net_dict[cost] = thing
                        net_dict[cost + '/norm'] = thing / ref_arch[net_name][cost]

                if len(cost_layer_read) != 0:
                    # process cost partition
                    cost_layer_yml_file = f'runs/{run_dir}'+'/output/cost/workloadcost.yaml'
                    cost_layer_dict = yaml_load(cost_layer_yml_file)

                    # Area values are taken from the first per-layer record.
                    # It is kept under its own name so that the per-layer
                    # energy loop below cannot overwrite it.
                    first_layer = None
                    for entry_name, entry in cost_layer_dict.items():
                        if entry_name not in non_layer_keys:
                            first_layer = entry
                            break

                    # init breakdown record
                    for cost in cost_layer_read:
                        area_or_energy = cost.split('/')[1]
                        for breakdown in hatch_map_cg.keys():
                            cost_bd = breakdown + f'/{area_or_energy}'
                            if area_or_energy == 'energy':
                                cost_bd += '/total'
                            net_dict[cost_bd] = 0

                    for cost in cost_layer_read:
                        level = cost.split('/')
                        partition = level[0]
                        area_or_energy = level[1]

                        if partition in tlut_unique_component and 'ystolic' in run_dir:
                            continue

                        if area_or_energy == 'area':
                            if first_layer is None:
                                raise KeyError(f'no per-layer entry in {cost_layer_yml_file}')
                            thing = first_layer
                            total = cost_layer_dict['overall'][area_or_energy]['onchip']
                            for key in level:
                                thing = thing[key]
                            net_dict[cost] = thing
                            net_dict[cost + '/percent'] = thing / total
                            net_dict[cost + '/norm_percent'] = net_dict[cost + '/percent'] * net_dict[f'{area_or_energy}/onchip/norm']
                            if 'systolic' in run_dir:
                                cost_cg = sys_breakdown_map[partition] + f'/{area_or_energy}'
                            elif 'uSystolic' in run_dir:
                                if usys_breakdown_map[partition] == '':
                                    # No coarse group for this component:
                                    # skip the aggregation instead of adding
                                    # it to whichever group was used last.
                                    continue
                                cost_cg = usys_breakdown_map[partition] + f'/{area_or_energy}'
                            elif 'carat' in run_dir:
                                cost_cg = carat_breakdown_map[partition] + f'/{area_or_energy}'
                            else:
                                raise AssertionError(f'unrecognised architecture in run dir: {run_dir}')
                            net_dict[cost_cg] += thing
                            net_dict[cost_cg + '/percent'] = net_dict[cost_cg] / total
                            net_dict[cost_cg + '/norm_percent'] = net_dict[cost_cg + '/percent'] * net_dict[f'{area_or_energy}/onchip/norm']

                        elif area_or_energy == 'energy':
                            # Energy is summed over every per-layer record.
                            net_dict[cost] = 0
                            for entry_name, entry in cost_layer_dict.items():
                                if entry_name in non_layer_keys:
                                    continue
                                thing = entry
                                for key in level:
                                    thing = thing[key]
                                net_dict[cost] += thing
                            onchip_total = cost_layer_dict['overall'][area_or_energy]['onchip']['total']
                            onchip_norm = net_dict[f'{area_or_energy}/onchip/total/norm']
                            net_dict[cost + '/percent'] = net_dict[cost] / onchip_total
                            net_dict[cost + '/norm_percent'] = net_dict[cost + '/percent'] * onchip_norm

                            # NOTE(review): the lowercase test 'sys' does not
                            # match the substring 'uSystolic'; confirm how
                            # uSys runs are meant to be grouped here.
                            if 'sys' in run_dir:
                                cost_cg = sys_breakdown_map[partition] + f'/{area_or_energy}/total'
                            elif 'carat' in run_dir:
                                cost_cg = carat_breakdown_map[partition] + f'/{area_or_energy}/total'
                            else:
                                raise AssertionError(f'unrecognised architecture in run dir: {run_dir}')
                            net_dict[cost_cg] += net_dict[cost]
                            net_dict[cost_cg + '/percent'] = net_dict[cost_cg] / onchip_total
                            net_dict[cost_cg + '/norm_percent'] = net_dict[cost_cg + '/percent'] * onchip_norm

                if len(other) != 0:
                    # process other metrics
                    for m in other:
                        # The summary values are stored under '.../total'
                        # keys. For uSys the cost-reading step above already
                        # removed the SRAM share, so they are used as is.
                        if m == 'energy_onchip_abs' or m == 'energy_onchip':
                            net_dict[m] = net_dict['energy/onchip/total']
                        if m == 'power_onchip_abs' or m == 'power_onchip':
                            net_dict[m] = net_dict['power/onchip/total']
                        if m == 'perf_norm':
                            net_dict[m] = net_dict['flops per sec'] / net_dict[PEAK] * 100

                        # NOTE(review): in the uSys branches below, the
                        # '.../onchip' value has already had SRAM removed by
                        # the cost-reading step above, so SRAM is subtracted a
                        # second time here - confirm this is intended.
                        if m == 'area_total':
                            if is_usys:
                                net_dict[m] = net_dict['area/onchip'] - net_dict['area/sram'] + net_dict['area/dram']
                            else:
                                net_dict[m] = net_dict['area/onchip'] + net_dict['area/dram']

                        # NOTE: efficiency is in terms of onchip cost
                        if m == 'throughput/area':
                            if is_usys:
                                net_dict[m] = net_dict['flops per sec'] / \
                                    (net_dict['area/onchip'] - net_dict['area/sram'])
                            else:
                                net_dict[m] = net_dict['flops per sec'] / \
                                    (net_dict['area/onchip'])
                        if m == 'throughput/power' or m == 'throughput/power_abs':
                            net_dict[m] = net_dict['flops per sec'] / net_dict['power/onchip/total']
                        if m == 'throughput/energy' or m == 'throughput/energy_abs':
                            net_dict[m] = net_dict['flops per sec'] / net_dict['energy/onchip/total']

                        if m == 'flops/area':
                            if is_usys:
                                net_dict[m] = net_dict['flops'] / \
                                    (net_dict['area/onchip'] - net_dict['area/sram'])
                            else:
                                net_dict[m] = net_dict['flops'] / \
                                    (net_dict['area/onchip'])
                        if m == 'flops/power':
                            if is_usys:
                                net_dict[m] = net_dict['flops'] / \
                                    (net_dict['power/onchip/total'] - net_dict['power/sram/total'])
                            else:
                                net_dict[m] = net_dict['flops'] / net_dict['power/onchip/total']
                        if m == 'flops/energy':
                            if is_usys:
                                net_dict[m] = net_dict['flops'] / \
                                    (net_dict['energy/onchip/total'] - net_dict['energy/sram/total'])
                            else:
                                net_dict[m] = net_dict['flops'] / net_dict['energy/onchip/total']

                        # Metrics without a rule above (e.g. 'power_total')
                        # have no value to normalise; fail with a clear
                        # message rather than a bare KeyError.
                        if m not in net_dict:
                            raise KeyError(f"no rule to derive metric '{m}'")
                        net_dict[m + '/norm'] = net_dict[m] / ref_arch[net_name][m]

    return run_set

def plot_format_func(metric, group_list, run_set, dir_name, use_universal_baseline):
    """Print, per group and architecture, the geometric mean of ``metric``.

    ``metric`` is translated through ``yml_metric_map`` into one key or a
    list of keys of the network dicts. For every key, the value of that key
    is collected from all networks of an architecture and reduced with
    ``gmean``. The architecture names and the resulting means are printed as
    nested lists (one inner list per group).

    Args:
        metric: metric name; must be a key of ``yml_metric_map``.
        group_list: group names to look up in ``run_set``.
        run_set: nested mapping group -> architecture -> network -> values.
        dir_name: unused.
        use_universal_baseline: unused.

    Returns:
        None. Results are only printed.
    """
    if metric not in yml_metric_map:
        assert False
    yml_keys = yml_metric_map[metric]
    # A single key is wrapped so both forms share one code path.
    if type(yml_keys) is not list:
        yml_keys = [yml_keys]
    print(yml_keys)

    for yml_key in yml_keys:
        print(yml_key)
        plot_arr = []
        name_arr = []
        for group_name in group_list:
            archs = run_set.get(group_name)
            group_means = []
            group_names = []
            for arch_name, nets in archs.items():
                values = [net_dict.get(yml_key) for net_dict in nets.values()]
                group_means.append(gmean(values))
                group_names.append(arch_name)
            plot_arr.append(group_means)
            name_arr.append(group_names)
        print(name_arr)
        print(bcolors.OKBLUE +str(plot_arr)+bcolors.ENDC)


if __name__ == "__main__":
    # Script entry point: parse options, locate the run directories of every
    # requested point/network, read the requested results, dump them to a
    # timestamped YAML log and print the per-metric summaries.
    parser = construct_argparser()
    args = parser.parse_args()

    # Both options are iterated below; report a missing one as a usage error
    # instead of failing later while iterating None.
    if not args.point:
        parser.error('at least one -p/--point is required')
    if not args.metric:
        parser.error('at least one -m/--metric is required')

    print('*** mlperf: ', args.mlperf)
    print('*** dir: ', args.dirname)
    os.makedirs('plot/log', exist_ok=True)
    timestr = time.strftime("%Y%m%d-%H%M%S")
    if args.dirname is not None:
        os.makedirs(f'plot/log/{args.dirname}', exist_ok=True)
        yml_name = f'./plot/log/{args.dirname}/plot_{timestr}.yml'
    else:
        yml_name = f'./plot/log/plot_{timestr}.yml'

    # === peak perf lambda functions ===
    # These are module globals on purpose: read_perf_cost looks them up by
    # name when it computes the peak throughput.
    get_sys_peak_flops_per_sec = lambda arch_dict, cycle:(arch_dict['architecture']['compu']['num_instances'][0] * \
        arch_dict['architecture']['compu']['num_instances'][1] * 2.0 + \
        arch_dict['architecture']['compu']['num_instances'][1]) / cycle * \
        arch_dict['architecture']['frequency'] * 10**6
    get_tlut_peak_flops_per_sec = lambda arch_dict, chunk_size, data_length: 2 * \
        arch_dict['architecture']['compu']['num_instances'][0] * \
        (float(chunk_size)/data_length) * arch_dict['architecture']['frequency'] * 10**6
    # === peak perf lambda functions ===

    print('*** points: ', args.point)

    if args.use_universal_baseline:
        print('*** using universal baseline')
    else:
        print('*** normalizing in terms of group (default)')

    # === construct run set ===
    print(bcolors.OKCYAN + 'Constructing run set...' + bcolors.ENDC)

    group_set = set()
    # group_list.sort(reverse=True)
    group_list = ['fp8']
    print(f'*** group by {group_list}')

    run_set = OrderedDict()

    # index 1: format
    for format_ in group_list:
        format_dict = OrderedDict()
        # index2: arch
        for point in args.point:
            arch_name = point.split('/')[-1]
            # Drop only the trailing component. str.replace would also strip
            # the arch name from any parent directory that contains it.
            dir_append = point[:len(point) - len(arch_name)]
            myrootdir = f'runs/{dir_append}/'
            if not os.path.isdir(myrootdir):
                raise FileNotFoundError(f'run directory not found: {myrootdir}')
            # The listing does not depend on the network, so take it once.
            subdirs = next(os.walk(myrootdir))[1]
            group = OrderedDict()
            # index3: network
            for net in args.mlperf:
                reg_name = re.compile(re.escape(arch_name) + '_' + net + r'_c\d{1,3}_n'+'256')
                for subdir in subdirs:
                    if reg_name.match(subdir):
                        dict_ = OrderedDict()
                        dict_['dir'] = f'{dir_append}/{subdir}'
                        # If several directories match, the last one wins.
                        group[net] = dict_
                if net not in group:
                    # The network is left out of this architecture's results.
                    print(bcolors.FAIL + f'no run found for {arch_name} / {net} in {myrootdir}' + bcolors.ENDC)
            format_dict[arch_name] = group
        run_set[format_] = format_dict

    # === register needed results ===
    perf_read, cost_read, other, cost_layer_read = gen_read_sets(args.metric)

    # === fill in perf results ===
    print(bcolors.OKCYAN + 'Reading yml...' + bcolors.ENDC)
    ref_arch_name = args.point[0].split('/')[-1]
    run_set = read_perf_cost(run_set, perf_read, cost_read, other, cost_layer_read, ref_arch_name, args.use_universal_baseline)

    yaml_overwrite(yml_name, run_set) # TODO: remove for speed and space

    print(bcolors.OKCYAN + 'ploting...' + bcolors.ENDC)
    for metric in args.metric:
        plot_format_func(metric, group_list, run_set, args.dirname, args.use_universal_baseline)