import sys, os
sys.path.append(os.getcwd())
from utils import yaml_load, bcolors
import argparse
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.pyplot import cm
import os
import re
from scipy.stats import gmean
import math

def construct_argparser():
    """Build the command-line parser for the mlperf bar-plot script.

    Options:
      -p/--point    one or more '<arch dir>/<run prefix>' entries (looked up under runs/)
      -f/--file     optional file listing points, one per line; the main block uses it
                    to override --point
      -n/--mlperf   network names whose runs are aggregated per batch index
      -g/--geomean  accepted on the command line but not read by the code shown here
      -d/--data     metrics to plot; one figure is produced per metric

    Returns:
        argparse.ArgumentParser: the configured parser.
    """
    default_points = [
        'arch_systolic_edge/systolic_edge_8x8_int8',
        'arch_systolic_edge/systolic_edge_16x16_int8',
        'arch_systolic_edge/systolic_edge_32x32_int8',
        'arch_tlut_intra/tlut_intra_1024_int8_t1',
        'arch_tlut_intra/tlut_intra_1024_int8_t2',
        'arch_tlut_intra/tlut_intra_1024_int8_t4',
        'arch_tlut_intra/tlut_intra_1024_int8_t8',
        'arch_tlut_inter/tlut_inter_1024_int8_t1',
        'arch_tlut_inter/tlut_inter_1024_int8_t2',
        'arch_tlut_inter/tlut_inter_1024_int8_t4',
        'arch_tlut_inter/tlut_inter_1024_int8_t8',
    ]
    default_networks = ['RNNT', 'UNet', 'bert_base_uncased', 'dlrm', 'resnet50', 'ssd300_vgg16']
    metric_choices = [
        'perf', 'perf_norm', 'util_impl', 'rt',
        'area_total', 'area_onchip', 'area_dram', 'area_sram', 'area_compute',
        'power_total', 'power_onchip', 'power_dram', 'power_sram', 'power_compute',
    ]
    default_metrics = ['perf', 'perf_norm', 'util_impl', 'rt']

    parser = argparse.ArgumentParser(description='Roofline')
    parser.add_argument('-p', '--point', nargs='+', default=default_points,
                        help='point on roofline')
    parser.add_argument('-f', '--file', default=None, help='file name')
    parser.add_argument('-n', '--mlperf', nargs='+', default=default_networks,
                        help='networks in mlperf')
    parser.add_argument('-g', '--geomean', default=False)
    parser.add_argument('-d', '--data', nargs='+', choices=metric_choices,
                        default=default_metrics, help='perf to plot')
    return parser

# === peak perf helpers ===
def sys_peak_flops_per_sec(arch_dict, cycle):
    """Return the modelled peak flops/s of a systolic-array run.

    Evaluates ``(a * b * 2 + b) / cycle * frequency * 10**6`` where ``a`` and ``b`` are
    ``arch_dict['architecture']['compu']['num_instances'][0]`` and ``[1]``.
    NOTE(review): the 10**6 factor suggests ``frequency`` is given in MHz — confirm.
    """
    arch = arch_dict['architecture']
    num_instances = arch['compu']['num_instances']
    return (num_instances[0] * num_instances[1] * 2.0 + num_instances[1]) / cycle * \
        arch['frequency'] * 10**6


def tlut_peak_flops_per_sec(arch_dict, chunk_size, data_length):
    """Return the modelled peak flops/s of a tlut run.

    Evaluates ``2 * num_instances[0] * (chunk_size / data_length) * frequency * 10**6``.
    """
    arch = arch_dict['architecture']
    return 2 * arch['compu']['num_instances'][0] * \
        (float(chunk_size) / data_length) * arch['frequency'] * 10**6


def peak_flops_per_sec_for(run_dir, arch_dict, cycle):
    """Pick the peak-flops/s model from the run directory name.

    'systolic' runs use sys_peak_flops_per_sec. 'tlut' runs use tlut_peak_flops_per_sec
    with chunk_size = log2(num_instances[1]) and an 8-bit data length; only int8 tlut
    runs are supported.

    Raises:
        ValueError: for an unknown architecture or an unsupported tlut data type.
    """
    if 'systolic' in run_dir:
        return sys_peak_flops_per_sec(arch_dict, cycle)
    if 'tlut' in run_dir:
        if 'int8' not in run_dir:
            raise ValueError(f'unsupported tlut data type in run: {run_dir}')
        chunk_size = math.log2(arch_dict['architecture']['compu']['num_instances'][1])
        return tlut_peak_flops_per_sec(arch_dict, chunk_size, 8)
    raise ValueError(f'unknown architecture for run: {run_dir}')


def safe_gmean(values):
    """Geometric mean of values, or NaN when there are none (an n group with no runs)."""
    return gmean(values) if len(values) > 0 else float('nan')


def split_point(point):
    """Split a '<arch dir>/<run prefix>' point into its first two '/'-separated parts.

    Raises:
        ValueError: if the point has no '/'.
    """
    parts = point.split('/')
    if len(parts) < 2:
        raise ValueError(f"point must look like '<arch dir>/<run prefix>': {point!r}")
    return parts[0], parts[1]


def list_subdirs(root):
    """Return the names of the immediate subdirectories of root.

    Raises:
        FileNotFoundError: if root is not a directory. Without this check, os.walk
            yields nothing and the failure would surface as a bare StopIteration.
    """
    if not os.path.isdir(root):
        raise FileNotFoundError(f'run directory not found: {root}')
    return next(os.walk(root))[1]


# Metrics reduced with a geometric mean over the networks of one n group.
GMEAN_METRICS = ('perf_norm', 'util_arch', 'util_impl',
                 'power_total', 'power_onchip', 'power_compute', 'power_dram', 'power_sram')
# Area metrics are read once per point (from its first run) and repeated for every n.
AREA_METRICS = ('area_total', 'area_onchip', 'area_compute', 'area_sram', 'area_dram')
ALL_METRICS = ('perf', 'rt') + GMEAN_METRICS + AREA_METRICS

# y-axis label for every metric accepted by --data.
# NOTE(review): the power labels say "normalized to baseline", but the plotted power
# values are raw geometric means; no normalization is applied in this script.
YLABELS = {
    'perf': 'Attainable Flops/s',
    'perf_norm': 'Attainable Flops/s normalized to respective peak',
    'util_impl': 'Array Utilization %',
    'rt': 'Runtime (s)',
    'area_total': r'total area $mm^{2}$',
    'area_compute': r'compute area $mm^{2}$',
    'area_sram': r'sram area $mm^{2}$',
    'area_dram': r'dram area $mm^{2}$',
    'area_onchip': r'onchip area $mm^{2}$',
    'power_total': 'total power',
    'power_compute': 'compute power normalized to baseline',
    'power_sram': 'sram power normalized to baseline',
    'power_dram': 'dram power normalized to baseline',
    'power_onchip': 'onchip power normalized to baseline',
}


def collect_run_metrics(run):
    """Aggregate one point's run directories into one value per n group.

    Args:
        run: list of (n, [run_dir, ...]) pairs, one per batch index n. Each run_dir
            is relative to runs/.

    Returns:
        (peak_flops_per_sec, series). The peak is NaN if the point has no runs at all.
        series maps every name in ALL_METRICS to a list with one entry per n group:
          - perf: total flops / total impl runtime over the group's networks
          - rt: total impl runtime over the group's networks
          - GMEAN_METRICS: geometric mean over the group's networks
          - AREA_METRICS: read from the point's first run (constant across n)
        Groups without run directories get NaN for perf, rt and GMEAN_METRICS.
    """
    nan = float('nan')
    series = {metric: [] for metric in ALL_METRICS}

    first_dir = next((d for _, dirs in run for d in dirs), None)
    if first_dir is None:
        for metric in ALL_METRICS:
            series[metric] = [nan] * len(run)
        return nan, series

    # --- peak flops/s and area are extracted only once, from the first run.
    # NOTE(review): cycle is the first layer's 'cycle' of that run's workload. This
    # assumes it is the same for every network of the point — confirm.
    first_root = f'runs/{first_dir}/'
    arch_dict = yaml_load(first_root + 'input/architecture.yaml')
    wl_dict = yaml_load(first_root + 'input/workload.yaml')['workload']
    cycle = next(iter(wl_dict.values()))['cycle']
    peak = peak_flops_per_sec_for(first_dir, arch_dict, cycle)
    print(f'{first_dir}: {peak}')

    area = yaml_load(first_root + 'output/cost/workloadcost.yaml')['overall']['area']
    area_values = {
        'area_onchip': area['onchip'],
        'area_compute': area['compute'],
        'area_sram': area['sram'],
        'area_dram': area['dram'],
        'area_total': area['onchip'] + area['dram'],
    }

    for _, dirs in run:
        samples = {metric: [] for metric in GMEAN_METRICS}
        total_flops = 0
        total_rt = 0
        for run_dir in dirs:
            root = f'runs/{run_dir}/'
            perf = yaml_load(root + 'output/performance/workloadperf.yaml')['overall']
            power = yaml_load(root + 'output/cost/workloadcost.yaml')['overall']['power']

            total_flops += perf['flops']
            total_rt += perf['runtime']['impl']
            samples['perf_norm'].append(perf['flops per sec'] / peak)
            samples['util_arch'].append(perf['utilization']['arch'])
            samples['util_impl'].append(perf['utilization']['impl'])
            for part in ('total', 'onchip', 'compute', 'dram', 'sram'):
                samples[f'power_{part}'].append(power[part]['total'])

        for metric, values in samples.items():
            series[metric].append(safe_gmean(values))
        # An empty group has no runtime: report NaN instead of dividing by zero.
        series['perf'].append(total_flops / total_rt if total_rt else nan)
        series['rt'].append(total_rt if dirs else nan)
        for metric, value in area_values.items():
            series[metric].append(value)

    return peak, series


def plot_metric(metric, points, peaks, series_per_run, n_list, colors, out_dir='plot/mlperf'):
    """Draw a grouped bar chart of metric and save it as a PNG.

    There is one bar group per n and one bar per point. For 'perf', each point's own
    peak flops/s is drawn as a dashed line in that point's color.

    Returns:
        str: path of the saved figure.
    """
    fig, ax = plt.subplots(figsize=(8, 5))
    ax.grid(color='grey', linestyle='--', linewidth=0.5, axis='both')

    n_runs = len(series_per_run)
    width = 1.0 / (n_runs + 1)
    x = np.arange(len(n_list))
    for i, (point, peak, series) in enumerate(zip(points, peaks, series_per_run)):
        color = colors[i % len(colors)]  # cycle rather than IndexError on long point lists
        ax.bar(x + width * i, series[metric], width, label=point, color=color)
        if metric == 'perf' and not math.isnan(peak):
            ax.axhline(y=peak, color=color, linestyle='--')

    ax.legend(loc='upper left')
    # Center each tick under its group of bars.
    plt.xticks(x + width * (n_runs - 1) / 2, 2**x)
    plt.title('mlperf')
    ax.set_xlabel('batch size')
    ax.set_ylabel(YLABELS[metric])

    os.makedirs(out_dir, exist_ok=True)
    namestr = '_'.join(points).replace('/', '_')
    fig_name = f'{out_dir}/{namestr}_{metric}.png'
    fig.tight_layout()
    fig.savefig(fig_name)
    plt.close(fig)  # one figure per metric; release it once saved
    return fig_name


if __name__ == "__main__":
    parser = construct_argparser()
    args = parser.parse_args()

    if args.file is not None:  # override args.point
        # Parse the list of run names, one per line; blank lines are skipped.
        with open(args.file) as file:
            args.point = [line.rstrip() for line in file if line.strip()]
    if not args.point:
        parser.error('no point specified')
    print('*** points: ', args.point)

    # --- discover every batch index n present for any point
    n_set = set()
    for point in args.point:
        dir_append, prefix = split_point(point)
        reg_name = re.compile(re.escape(prefix) + r'\S*_c\d{1,3}_n(\d{1,4})')
        for dirname in list_subdirs(f'runs/{dir_append}/'):
            match = reg_name.match(dirname)
            if match:
                n_set.add(int(match.group(1)))
    n_list = sorted(n_set)
    print('n_max', max(n_list, default=0))
    print('*** n: ', n_list)

    # --- group run directories by point, then by n
    # Network names come from the command line: escape them so they match literally.
    mlperf_or = '(' + '|'.join(re.escape(net) for net in args.mlperf) + ')'
    run_set = []
    for point in args.point:
        dir_append, prefix = split_point(point)
        subdirs = list_subdirs(f'runs/{dir_append}/')  # same listing for every n
        n_group = []
        for n in n_list:
            reg_name = re.compile(re.escape(prefix) + '_' + mlperf_or + r'_c\d{1,3}'
                                  + re.escape(f'_n{n}') + r'$')
            dir_set = sorted(f'{dir_append}/{d}' for d in subdirs if reg_name.match(d))
            n_group.append((n, dir_set))
        run_set.append(n_group)

    # === collect data arrays ===
    peaks = []
    series_per_run = []
    for point, run in zip(args.point, run_set):
        print(f'*** collecting: {point}')
        peak, series = collect_run_metrics(run)
        peaks.append(peak)
        series_per_run.append(series)

    # === plot ===
    colors = [bcolors.yellow, bcolors.orange, bcolors.green, bcolors.gray, bcolors.blue,
              bcolors.red, bcolors.brown,
              '#d3d3d3', '#bebebe', '#949494', '#808080', '#7e7e7e', '#616161']
    for metric in args.data:
        print(f'*** metric: {metric}')
        fig_name = plot_metric(metric, args.point, peaks, series_per_run, n_list, colors)
        print(bcolors.OKGREEN + f'Saved fig as {fig_name}' + bcolors.ENDC)
    