import sys, os
sys.path.append(os.getcwd())
from utils import yaml_load, bcolors
import argparse
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.pyplot import cm
import math
import os
from workload import conv

def construct_argparser():
    """Build the command-line parser for the roofline script.

    Options:
        -p/--point: one or more run names to place on the roofline
            (defaults to ``['example_systolic_eyeriss']``).
        -f/--file:  optional file name (defaults to ``None``).
        -o/--oname: optional output file name (defaults to ``None``).

    Returns:
        argparse.ArgumentParser: the configured parser.
    """
    # Each entry pairs the option flags with the keyword arguments that are
    # forwarded to ``add_argument``.
    option_table = (
        (('-p', '--point'), {'nargs': '+',
                             'help': 'point on roofline',
                             'default': ['example_systolic_eyeriss']}),
        (('-f', '--file'), {'help': 'file name', 'default': None}),
        (('-o', '--oname'), {'help': 'out file name', 'default': None}),
    )

    cli = argparse.ArgumentParser(description='Roofline')
    for flags, settings in option_table:
        cli.add_argument(*flags, **settings)
    return cli

if __name__ == "__main__":
    # Roofline plot driver.
    #
    # For every run name ("point") this script loads
    #   runs/<point>/output/performance/workloadperf.yaml
    #   runs/<point>/input/architecture.yaml
    #   runs/<point>/input/workload.yaml
    # derives the peak compute / peak DRAM bandwidth roofline plus two
    # arithmetic-intensity bounds, plots the measured point on log-log axes,
    # and saves the figure under plot/rooflines/.
    parser = construct_argparser()
    args = parser.parse_args()

    def get_sys_peak_flops_per_sec(arch_dict, cycle):
        """Return the peak ops/s of a systolic array described by arch_dict.

        Computed as (rows * cols * 2.0) / cycle * frequency * 10**6.
        NOTE(review): the 10**6 factor suggests 'frequency' is in MHz and the
        2.0 presumably counts a multiply-accumulate as two ops -- confirm.
        """
        num_instances = arch_dict['architecture']['compu']['num_instances']
        return (num_instances[0] * num_instances[1] * 2.0) / cycle * \
            arch_dict['architecture']['frequency'] * 10**6

    def get_tlut_peak_flops_per_sec(arch_dict, chunk_size, data_length):
        """Return the peak ops/s of a 'tlut' architecture.

        Computed as 2 * num_instances[0] * (chunk_size / data_length)
        * frequency * 10**6.
        """
        return 2 * \
            arch_dict['architecture']['compu']['num_instances'][0] * \
            (float(chunk_size) / data_length) * \
            arch_dict['architecture']['frequency'] * 10**6

    if args.file is not None:  # override args.point
        # The file lists one run name per line; blank lines are ignored so a
        # trailing newline does not turn into an empty run name.
        with open(args.file) as file:
            args.point = [line.rstrip() for line in file if line.strip()]
    if not args.point:
        # Validate after the file has been read: an empty file leaves nothing
        # to plot.
        parser.error('no point specified')
    print('*** points: ', args.point)

    # ---------plot----------
    fig, ax = plt.subplots(figsize=(14, 9))
    ax.grid(color='grey', linestyle='--', linewidth=0.5, axis='both')
    ax.set_yscale('log')
    ax.set_xscale('log')
    # -------------------------
    tuple_set = set()   # (peak compute, peak bandwidth) pairs already drawn
    names = []
    color = cm.rainbow(np.linspace(0, 1, len(args.point)))
    x_list = []
    y_list = []
    y_list_nostall = []
    for index, p in enumerate(args.point):
        dir_ = f'runs/{p}/'
        # Non-systolic run names carry the cycle count as '..._c<cycle>_n...'.
        if 'systolic' in p:
            cycle = 1
        else:
            cycle = int(p.split('_c')[-1].split('_n')[0])

        perf_dict = yaml_load(dir_ + 'output/performance/workloadperf.yaml')
        array_utilization_arch = perf_dict['overall']['utilization']['arch']
        array_utilization_impl = perf_dict['overall']['utilization']['impl']
        arch_dict = yaml_load(dir_ + 'input/architecture.yaml')
        wl_dict = yaml_load(dir_ + 'input/workload.yaml')['workload']
        arch = arch_dict['architecture']
        wdram_exist = 'wdram' in arch

        # ==== compute theoretical max arithmetic intensity bounds ====
        # 1. theoretical max intensity bound: when the three data tensors are
        #    read from DRAM exactly once
        # 2. n infinity max intensity bound: when the three tensors are read
        #    from DRAM exactly once AND N = inf (max reuse possible)
        total_flops = perf_dict['overall']['flops']
        n_infinity_total_flops = 0
        min_dram_byte = 0
        input_output_byte = 0

        for layer in wl_dict.values():
            prob = conv("", layer)

            input_words = prob.N * prob.X_valid * prob.Y_valid * prob.C
            weight_words = prob.K * prob.R * prob.S * prob.C
            output_words = prob.N * prob.P * prob.Q * prob.K

            input_byte = input_words * arch['isram']['byte_per_word']
            output_byte = output_words * arch['osram']['byte_per_word']
            if 'wsram' in arch:
                weight_byte = weight_words * arch['wsram']['byte_per_word']
            else:  # assume weights align with the input byte per word
                weight_byte = weight_words * arch['isram']['byte_per_word']

            min_dram_byte += input_byte + output_byte + weight_byte
            # Per-sample activation traffic; the weight term is left out of
            # the N = inf bound.
            input_output_byte += (input_byte + output_byte) / prob.N
            n_infinity_total_flops += (2 * prob.P * prob.Q * prob.R * prob.S * prob.C * prob.K)

        # Sum the DRAM traffic reported for every per-layer entry of the
        # performance report (the listed keys are not per-layer entries).
        perf_dram_byte = 0
        for key, layer_perf in perf_dict.items():
            if key in ('overall', 'frequency', 'technology'):
                continue
            perf_dram_byte += layer_perf['odram']['byte total']['rd'] + layer_perf['odram']['byte total']['wr']
            if wdram_exist:
                perf_dram_byte += layer_perf['wdram']['byte total']['rd'] + layer_perf['wdram']['byte total']['wr']

        # Reported as a warning rather than an assertion so plotting continues.
        if min_dram_byte > perf_dram_byte:
            print(bcolors.FAIL + f'perf reported dram byte {perf_dram_byte} should not be less than sum of three tensor size {min_dram_byte}' + bcolors.ENDC)
        theoretical_max_intensity_x = total_flops / min_dram_byte
        # n infinity bound
        n_infinity_max_intensity_x = n_infinity_total_flops / input_output_byte
        assert n_infinity_max_intensity_x > theoretical_max_intensity_x, f'theoretical intensity bound {theoretical_max_intensity_x} should not exceed n infinity intensity bound {n_infinity_max_intensity_x}'
        # =============================================

        arch_type = arch['template']
        if 'systolic' in arch_type:
            peak_compute_perf = get_sys_peak_flops_per_sec(arch_dict, cycle)
        elif 'tlut' in arch_type:
            chunk_size = math.log2(arch['compu']['num_instances'][1])
            # The data length is inferred from the run name.
            # NOTE(review): 'bf16' maps to 8, same as 'int8' -- confirm this
            # is intentional.
            for dtype_tag, dtype_length in (('int8', 8), ('int16', 16), ('bf16', 8)):
                if dtype_tag in p:
                    data_length = dtype_length
                    break
            else:
                raise ValueError(f'cannot infer data length from run name: {p}')
            peak_compute_perf = get_tlut_peak_flops_per_sec(arch_dict, chunk_size, data_length)
        else:
            raise ValueError(f'unsupported architecture template: {arch_type}')

        peak_memory_bw = arch['odram']['bandwidth'] * 10**9
        if wdram_exist:
            # NOTE(review): this mixes wdram 'bits_per_chip' with odram
            # 'chip_per_dimm' and a hard-coded 1.6 GHz DRAM clock -- confirm
            # that using the odram chip count here is intended.
            peak_memory_bw_wdram = arch['wdram']['bits_per_chip'] / 8 * \
                arch['odram']['chip_per_dimm'] * \
                1.6*(10**9)  # 1.6GHz dram clock
            peak_memory_bw += peak_memory_bw_wdram

        intensity_x = perf_dict['overall']['flops per byte']['dram']
        gflops_per_s_y = perf_dict['overall']['flops per sec'] / (10**9)
        gflops_per_s_y_nostall = perf_dict['overall']['flops'] / perf_dict['overall']['runtime']['arch'] / (10**9)

        # Derive a short display name: drop the part after the last '_c' and
        # keep the second path component. Names without '_c' or without '/'
        # fall back to what is available instead of raising IndexError.
        p_split = p.split('_c')[-1]
        proper_name = p.replace(p_split, '') if '_c' in p else p
        name_parts = proper_name.split('/')
        proper_name = name_parts[1] if len(name_parts) > 1 else name_parts[0]
        names.append(proper_name)

        # Draw each distinct roofline only once.
        tuple_ = (peak_compute_perf, peak_memory_bw)
        if tuple_ not in tuple_set:
            tuple_set.add(tuple_)

            print('peak compute perf (Gops/s)', peak_compute_perf / 10**9)
            print('peak memory bw (Byte/s): ', peak_memory_bw)
            real_bw = perf_dict['overall']['bandwidth']['dram']['impl']
            print(f'real bw:    {real_bw}')
            if real_bw > peak_memory_bw:
                print(bcolors.FAIL + f'real bw is greater than peak bw: {real_bw/peak_memory_bw}' + bcolors.ENDC)

            ridge_point_intensity = peak_compute_perf / peak_memory_bw
            # Sample the roofline at powers of two, from below the ridge point
            # (or the measured point, whichever is lower) up to the N = inf
            # intensity bound.
            start = min(math.floor(math.log2(ridge_point_intensity - 0.01)), math.floor(math.log2(intensity_x)))
            end = math.ceil(math.log2(n_infinity_max_intensity_x))
            exponents = np.linspace(start, end, end - start + 1)
            flops_per_byte = [2**e for e in exponents]

            print(f'ridge point arithmetic intensity: {ridge_point_intensity}')
            print(f'n infinity arithmetic intensity: {n_infinity_max_intensity_x}')
            if ridge_point_intensity > n_infinity_max_intensity_x:
                print(bcolors.WARNING + 'memory bound!' + bcolors.ENDC)
            flops_per_sec = [min(intensity * peak_memory_bw, peak_compute_perf) for intensity in flops_per_byte]
            roofline_gflops = [fps / (10**9) for fps in flops_per_sec]

            ax.plot(flops_per_byte, roofline_gflops, 'k-', label=f'roofline {p_split}')

        if 'inter' in p:
            marker_ = '*'
        elif 'intra' in p:
            marker_ = 'x'
        else:
            marker_ = 'o'

        x_list.append(intensity_x)
        y_list.append(gflops_per_s_y)
        y_list_nostall.append(gflops_per_s_y_nostall)

        # === plot data points ====
        ax.plot(intensity_x, gflops_per_s_y, marker_, markersize=10, label=f'{p}', color=color[index])

        # === annotate data points with array utilization ===
        ax.annotate('%.1f, %.1f' % (array_utilization_arch, array_utilization_impl), xy=(intensity_x, gflops_per_s_y), textcoords='data')

        ax.plot(intensity_x, gflops_per_s_y_nostall, '+', markersize=10, label='no stall', color=color[index])
        ax.axvline(x=theoretical_max_intensity_x, ls=':', color=color[index], label='max arithmetic intensity')
        ax.axvline(x=n_infinity_max_intensity_x, ls='--', color='k', label='n infinity arithmetic intensity')

        # NOTE(review): exact float equality; the measured intensity reaches
        # the bound only when the reported DRAM bytes equal the minimum.
        if intensity_x == theoretical_max_intensity_x:
            assert min_dram_byte == perf_dram_byte, "?"
        else:
            print(bcolors.WARNING + 'SRAM size too small to fit data tensors' + bcolors.ENDC)

    # De-duplicate display names, keeping first-seen order for the file name.
    unique_names = list(dict.fromkeys(names))
    names_str = ''.join('_' + name_ for name_ in unique_names)
    name_set = set(unique_names)
    print(x_list)

    # fill_between connects the nodes in the order given, so sort the points
    # by intensity to avoid a self-crossing fill region.
    stall_points = sorted(zip(x_list, y_list, y_list_nostall))
    ax.fill_between([pt[0] for pt in stall_points],
                    y1=[pt[1] for pt in stall_points],
                    y2=[pt[2] for pt in stall_points],
                    color='#e0e0e0', label='memory stall')

    ax.set_xlabel('Operational Intensity (Flops/Byte)')
    ax.set_ylabel('Attainable GFlops/s')
    # Shrink current axis by 20%
    box = ax.get_position()
    ax.set_position([box.x0, box.y0, box.width * 0.8, box.height])

    # Put a legend to the right of the current axis
    ax.legend(loc='center left', bbox_to_anchor=(1, 0.5))
    plt.title('roofline model for ' + str(name_set))

    # --oname defaults to None; fall back to a name built from the plotted
    # runs instead of failing on None.split().
    oname = args.oname or f'roofline{names_str}'
    fig_name = f'plot/rooflines/{oname}.png'
    # oname may contain sub-directories; create the full parent path.
    os.makedirs(os.path.dirname(fig_name), exist_ok=True)
    plt.savefig(fig_name)
    print(f'Saved fig as {fig_name}')