import sys, os
sys.path.append(os.getcwd())
from utils import yaml_load, bcolors
import argparse
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.pyplot import cm
import os
import re
import math

def construct_argparser():
    """Build the command-line parser for this plotting script.

    Options:
      -p/--point     one or more "<arch_dir>/<run_prefix>" entries; runs are
                     looked up under runs/<arch_dir>/.
      -w/--workload  workloads to plot, restricted to the known benchmark set.
      -f/--file      text file listing points, one per line (overrides -p).
      -c/--cfg       metric to plot instead of raw flops/s (default: None).

    Returns:
        argparse.ArgumentParser: the configured, not-yet-parsed parser.
    """
    known_workloads = ['resnet50', 'bert_base_uncased', 'dlrm', 'RNNT', 'ssd300_vgg16', 'UNet']
    metric_choices = [None, 'per_area', 'per_power', 'per_energy', 'rt',
                      'area', 'energy', 'power', 'util', 'perf_norm']

    # Default points: three systolic array sizes, then the intra/inter TLUT
    # variants for each of the t1..t8 settings.
    systolic_points = [f'arch_systolic_edge/systolic_edge_{dim}x{dim}_int8' for dim in (8, 16, 32)]
    tlut_points = [
        f'arch_tlut_{kind}/tlut_{kind}_1024_int8_t{t}'
        for kind in ('intra', 'inter')
        for t in (1, 2, 4, 8)
    ]

    cli = argparse.ArgumentParser(description='Roofline')
    cli.add_argument('-p', '--point', nargs='+',
                     default=systolic_points + tlut_points,
                     help='point on roofline')
    cli.add_argument('-w', '--workload', nargs='+',
                     default=list(known_workloads), choices=known_workloads,
                     help='workload to plot')
    cli.add_argument('-f', '--file', default=None, help='file name')
    cli.add_argument('-c', '--cfg', default=None, choices=metric_choices,
                     help='perf normalized to')
    return cli

if __name__ == "__main__":
    parser = construct_argparser()
    args = parser.parse_args()

    # ---- peak-throughput models (roofline line and the 'perf_norm' metric) ----
    def get_sys_peak_flops_per_sec(arch_dict, cycle):
        """Return the peak flops/s of a systolic-array run.

        Evaluates (n0 * n1 * 2 + n1) / cycle * frequency * 1e6, where
        (n0, n1) = arch_dict['architecture']['compu']['num_instances'][0:2]
        and *cycle* is a per-layer cycle count read from the run's
        workload.yaml. The 1e6 factor presumably converts a MHz frequency to
        Hz -- confirm against the architecture.yaml schema.
        """
        instances = arch_dict['architecture']['compu']['num_instances']
        ops_per_cycle_group = instances[0] * instances[1] * 2.0 + instances[1]
        return ops_per_cycle_group / cycle * arch_dict['architecture']['frequency'] * 10**6

    def get_tlut_peak_flops_per_sec(arch_dict, chunk_size, data_length):
        """Return the peak flops/s of a table-lookup (TLUT) run.

        Evaluates 2 * n0 * (chunk_size / data_length) * frequency * 1e6 with
        n0 = arch_dict['architecture']['compu']['num_instances'][0].
        """
        instances = arch_dict['architecture']['compu']['num_instances']
        return (2 * instances[0] * (float(chunk_size) / data_length)
                * arch_dict['architecture']['frequency'] * 10**6)

    if args.file is not None:  # a run-list file overrides --point
        with open(args.file, encoding='utf-8') as run_list:
            # One "<arch_dir>/<run_prefix>" per line. Blank lines (e.g. a
            # trailing newline) are skipped instead of becoming empty points.
            args.point = [line.rstrip() for line in run_list if line.strip()]

    # Validate after the file override so an empty run-list file is caught too.
    if not args.point:
        parser.error('no point specified')
    malformed = [point for point in args.point if '/' not in point]
    if malformed:
        parser.error(f'points must look like <arch_dir>/<run_prefix>: {malformed}')

    print('*** points: ', args.point)
    print('*** metric: ', args.cfg)
    
    # Bar colours; cycled when more points are plotted than colours are listed.
    colors = [bcolors.yellow, bcolors.orange, bcolors.green, bcolors.gray, bcolors.blue, bcolors.red, bcolors.brown,
              '#d3d3d3', '#bebebe', '#949494', '#808080', '#7e7e7e', '#616161']
    # Y-axis label for every --cfg value.
    ylabels = {
        None: 'Attainable Flops/s',
        'per_power': 'Flops/s/w',
        'per_area': r'$Flops/s/mm^{2}$',
        'per_energy': 'Flops/s/energy',
        'rt': 'runtime',
        'area': 'area',
        'power': 'power',
        'energy': 'energy',
        'util': 'Utilization %',
        'perf_norm': 'Normalized Flops/s',
    }
    # Metrics that additionally need output/cost/workloadcost.yaml.
    cost_metrics = ('per_power', 'power', 'area', 'per_area', 'energy', 'per_energy')
    # Run index in a directory name such as "<prefix>_n16". Only the text after
    # the matched prefix is searched, so an "_n" inside the prefix is ignored.
    run_index_re = re.compile(r'_n(\d+)')

    def _find_runs(point, wl):
        """Return the (run_dir, n) pairs of *point* on workload *wl*, sorted by n.

        *point* is "<arch_dir>/<run_prefix>". Every sub-directory of
        runs/<arch_dir>/ whose name starts with "<run_prefix>_<wl>" and carries
        an "_n<digits>" suffix is a run; run_dir is relative to runs/.
        A missing runs/<arch_dir>/ directory yields no runs.
        """
        dir_append, prefix = f'{point}_{wl}'.split('/')[:2]
        root = f'runs/{dir_append}/'
        # Default tuple: os.walk yields nothing for a missing/unreadable root.
        sub_dirs = next(os.walk(root), (root, [], []))[1]
        runs = []
        for sub in sub_dirs:
            if not sub.startswith(prefix):
                continue
            match = run_index_re.search(sub, len(prefix))
            if match is None:
                print(f'*** no _n<index> in run directory {sub}, skipping')
                continue
            runs.append((f'{dir_append}/{sub}', int(match.group(1))))
        runs.sort(key=lambda run: run[1])
        return runs

    def _peak_flops_per_sec(run_dir, run_name):
        """Return the modelled peak flops/s for the run stored in *run_dir*.

        The architecture family is taken from *run_name* ('systolic' or
        'tlut'); TLUT runs are only modelled for int8 data.

        Raises:
            ValueError: unknown architecture/data type, or a systolic run whose
                workload.yaml has no layers to read the cycle count from.
        """
        arch_dict = yaml_load(run_dir + 'input/architecture.yaml')
        if 'systolic' in run_name:
            layers = yaml_load(run_dir + 'input/workload.yaml')['workload']
            first_layer = next(iter(layers.values()), None)
            if first_layer is None:
                raise ValueError(f'no layers in {run_dir}input/workload.yaml')
            return get_sys_peak_flops_per_sec(arch_dict, first_layer['cycle'])
        if 'tlut' in run_name:
            if 'int8' not in run_name:
                raise ValueError(f'unsupported data type for {run_name}')
            chunk_size = math.log2(arch_dict['architecture']['compu']['num_instances'][1])
            return get_tlut_peak_flops_per_sec(arch_dict, chunk_size, 8)
        raise ValueError(f'unknown architecture for {run_name}')

    def _metric_value(run_dir, peak_flops_per_sec):
        """Return the --cfg metric (raw flops/s when --cfg is None) for one run."""
        overall = yaml_load(run_dir + 'output/performance/workloadperf.yaml')['overall']
        perf = overall['flops per sec']
        cfg = args.cfg
        if cfg is None:
            return perf
        if cfg == 'rt':
            return overall['runtime']['impl']
        if cfg == 'util':
            return overall['utilization']['impl']
        if cfg == 'perf_norm':
            return perf / peak_flops_per_sec
        if cfg in cost_metrics:
            cost = yaml_load(run_dir + 'output/cost/workloadcost.yaml')['overall']
            if cfg in ('power', 'per_power'):
                value = cost['power']['total']['total']
            elif cfg in ('area', 'per_area'):
                value = cost['area']['onchip']
            else:
                value = cost['energy']['total']['total']
            # 'per_*' metrics are throughput divided by the cost figure.
            return float(perf) / value if cfg.startswith('per_') else value
        raise ValueError(f'unknown metric {cfg!r}')

    for wl in args.workload:
        print('*** workload: ', wl)

        # One (run_dir, n)-list per point, in the same order as args.point.
        run_set = [_find_runs(point, wl) for point in args.point]
        n_list = sorted({n for runs in run_set for _, n in runs})
        if not n_list:
            print(f'*** no runs found for workload {wl}, skipping')
            continue
        print('n_max', n_list[-1])

        # x slot of every n value, so a run covering only some batch sizes
        # still lands its bars over the right ticks.
        slot_of = {n: k for k, n in enumerate(n_list)}
        n_str = np.arange(len(n_list))
        width = 1.0 / (len(run_set) + 1)

        # ---------plot----------
        fig, ax = plt.subplots(figsize=(8, 5))
        ax.grid(color='grey', linestyle='--', linewidth=0.5, axis='both')

        for i, (point, runs) in enumerate(zip(args.point, run_set)):
            if not runs:
                print(f'*** no runs found for {point} on {wl}')
                continue
            first_run = runs[0][0]
            peak_flops_per_sec = _peak_flops_per_sec(f'runs/{first_run}/', first_run)
            values = [_metric_value(f'runs/{run_name}/', peak_flops_per_sec) for run_name, _ in runs]
            x = np.array([slot_of[n] for _, n in runs]) + width * i
            ax.bar(x, np.array(values), width, label=point, color=colors[i % len(colors)])
            if args.cfg is None:
                # Roofline (peak flops/s) for this point; empty label keeps it out of the legend.
                ax.axhline(y=peak_flops_per_sec, color=bcolors.brown, linestyle='--', label='')

        ax.legend(loc='upper left')
        # Centre each tick under its group of bars.
        ax.set_xticks(n_str + width * (len(run_set) - 1) / 2)
        # NOTE(review): labels are 2**slot, which equals the batch size only if
        # n_list is a contiguous run of exponents from 0 (or powers of two from
        # 1 if n is the batch size itself) -- confirm the run naming scheme.
        ax.set_xticklabels(2 ** n_str)
        ax.set_xlabel('batch size')
        ax.set_ylabel(ylabels.get(args.cfg, ''))

        out_dir = f'plot/perf/{wl}'
        os.makedirs(out_dir, exist_ok=True)
        # File name: sorted, de-duplicated run prefixes with the '_int8...' tail removed.
        names = sorted({point.split('/')[1].split('_int8')[0] for point in args.point})
        namestr = '_'.join(names)
        fig_name = f'{out_dir}/{namestr}_{args.cfg}.png'
        fig.tight_layout()
        fig.savefig(fig_name)
        plt.close(fig)  # free the figure; one is created per workload
        print(bcolors.OKGREEN + f'Saved fig as {fig_name}' + bcolors.ENDC)

    