import argparse
import os
import pathlib
import shutil
import itertools
import sys, os
sys.path.append(os.getcwd())
from utils import yaml_load, yaml_overwrite
import random
import math
from utils.utils import bcolors

def construct_argparser():
    """Build the command-line parser for the configuration generator.

    Every sweep axis (arch, workload, format, dram, sram, dim1, dim2, noc1,
    noc2, nocbw, batch_size) accepts one or more values (``nargs='+'``) and
    is therefore parsed into a list; ``--conference`` and ``--tech_override``
    are single values.

    ``--dim1``, ``--noc1`` and ``--noc2`` have no default and parse to
    ``None`` when omitted.

    Returns:
        argparse.ArgumentParser: the configured parser.
    """
    parser = argparse.ArgumentParser(description='Generate Configuration')

    parser.add_argument(
        '-a', '--arch',
        nargs='+',
        choices=[
            'systolic', 'carat', 'uSystolic',
            'systolic-noc', 'carat-noc', 'uSystolic-noc',
            'carat-i1', 'carat-i1-noc',
            'carat-idle8', 'carat-idle8-noc',
            'cris', 'cris-noc',
        ],
        help='arch type',
        default=['systolic'],
    )
    parser.add_argument(
        '-w', '--workload',
        nargs='+',
        help='workload/network name',
        choices=[
            'resnet50', 'bertbase', 'dlrm', 'rnnt', 'ssd', 'alexnet', 'unet',
            'fc', 'conv',
        ],
        default=['resnet50', 'bertbase', 'dlrm', 'rnnt', 'ssd', 'unet'],
    )
    parser.add_argument(
        '-f', '--format',
        nargs='+',
        help='data format',
        choices=['fp8'],
        default=['fp8'],
    )
    parser.add_argument(
        '--dram',
        nargs='+',
        help='dram bw',
        type=float,
        choices=[25.6, 51.2, 128, 256],
        default=[128],
    )
    parser.add_argument(
        '--sram',
        nargs='+',
        type=float,
        help='sram size',
        choices=[0.125, 0.25, 0.5, 1, 2],
        default=[1],
    )
    parser.add_argument(
        '-r', '--dim1',
        nargs='+',
        type=int,
        help='row dimension',
        # Union of the row sizes used by carat (16..4096) and by the
        # (u)systolic arrays (4..64); each value is listed once so that
        # --help and error messages do not show repeated entries.
        choices=[4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096],
    )
    parser.add_argument(
        '-c', '--dim2',
        nargs='+',
        type=int,
        help='col dimension',
        choices=[
            4, 8, 16, 32, 64, 128, 256,
            1,  # cris
        ],
        default=[16],
    )
    parser.add_argument(
        '--noc1',
        nargs='+',
        type=int,
        help='noc first dimension',
        choices=[2, 4, 8],
    )
    parser.add_argument(
        '--noc2',
        nargs='+',
        type=int,
        help='noc second dimension',
        choices=[1, 2, 4, 8],
    )
    parser.add_argument(
        '--nocbw',
        nargs='+',
        type=int,
        help='noc bw',
        choices=[64, 128, 256],
        default=[128],
    )
    parser.add_argument(
        '-n', '--batch_size',
        nargs='+',
        type=int,
        help='batch size',
        default=[1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024],
    )
    parser.add_argument(
        '--conference',
        type=str,
        help='to which conference',
        choices=['asplos2024'],
        default='asplos2024',
    )
    parser.add_argument(
        '--tech_override',
        type=int,
        help='technology override',
        default=32,
    )
    return parser

def _require_fp8(format_):
    """Fail unless ``format_`` contains 'fp8', the only format handled here."""
    if 'fp8' not in format_:
        # Raised explicitly (not via ``assert``) so the check survives
        # ``python -O``; the exception type matches the previous assert.
        raise AssertionError(f'unsupported data format: {format_!r}')


def _resolve_cycle_and_brow(arch, dim1, dim2, format_, conference, bank_ct):
    """Derive the workload cycle count and the per-row SRAM byte overrides.

    ``arch`` is expected with '-' already replaced by '_'.

    Returns:
        tuple: ``(cycle, i_byte_per_row, w_byte_per_row, o_byte_per_row)``.
        The three byte values are ``None`` when no override is derived
        (the asplos2024 setup for every arch except carat i1/idle8).

    Raises:
        AssertionError: unknown arch, non-fp8 format where fp8 is required,
            or a violated asplos2024 sanity check.
    """
    is_asplos = conference == 'asplos2024'
    byte = 1   # fp8, and int8 for uSystolic, are both one byte per element
    # Carat fp8 parameters, kept as in the original table
    # (presumably chunk size and data width in bits -- TODO confirm).
    chunk = 4
    dw = 4

    # For each arch family pick: the cycle count, the width that drives the
    # weight-SRAM row size, and whether row sizes are derived at all.
    # The order of the tests matters: 'carat_i1'/'carat_idle8' must be
    # matched before the plain 'carat' case.
    if 'uSystolic' in arch:
        # uSystolic does not inspect format_ (it always uses int8).
        cycle = 256 if is_asplos else 128
        w_width, derive = dim2, not is_asplos
    elif 'systolic' in arch:
        _require_fp8(format_)
        cycle = 1
        w_width, derive = dim2, not is_asplos
    elif 'carat' in arch and 'i1' in arch:
        _require_fp8(format_)
        cycle = int((2 ** chunk) / 2 * (dw / chunk))
        w_width, derive = 2 ** chunk, True
    elif 'carat' in arch and 'idle8' in arch:
        _require_fp8(format_)
        cycle = int((2 ** chunk) * (dw / chunk))
        w_width, derive = 2 ** chunk, True
    elif 'carat' in arch:
        _require_fp8(format_)
        if is_asplos:
            cycle = int((2 ** (chunk - 1)) * (dw / chunk))
        else:
            cycle = int((2 ** chunk) * (dw / chunk))
        w_width, derive = 2 ** chunk, not is_asplos
    elif 'cris' in arch:
        _require_fp8(format_)
        cycle = 1
        w_width, derive = dim2, not is_asplos
    else:
        raise AssertionError(f'unsupported arch: {arch!r}')

    def brow(width):
        # Largest power of two not exceeding width*byte/bank_ct/cycle,
        # with a minimum of 1.
        return max(2 ** int(math.log2(int(width) * byte / bank_ct / cycle)), 1)

    if derive:
        i_byte = brow(dim1)
        w_byte = brow(w_width)
        o_byte = i_byte   # the output SRAM reuses the input row size
    else:
        i_byte = w_byte = o_byte = None

    if is_asplos:
        # Sanity checks for the asplos2024 setup. Note that carat i1/idle8
        # always derive row sizes and therefore fail the first check.
        assert i_byte is None and w_byte is None and o_byte is None, (
            f'asplos2024 expects no per-row byte overrides, '
            f'got {(i_byte, w_byte, o_byte)}')
        if 'carat' in arch:
            assert cycle == 8, f'carat expects cycle 8, got {cycle}'
        if 'uSystolic' in arch:
            assert cycle == 256, f'uSystolic expects cycle 256, got {cycle}'
        if 'cris' in arch:
            assert cycle == 1, f'cris expects cycle 1, got {cycle}'

    return cycle, i_byte, w_byte, o_byte


def _resolve_base_dir(arch, dim1, dim2, conference):
    """Return the per-arch directory name under ``runs/`` and check dim1/dim2.

    ``arch`` is expected with '-' already replaced by '_'.

    Raises:
        AssertionError: ``dim1`` (or ``dim2`` for cris) is outside the set
            accepted for the selected arch/conference.
    """
    is_asplos = conference == 'asplos2024'
    large_dim1 = (1024, 2048, 4096)

    base_dir = ''
    if 'uSystolic' in arch:
        base_dir = 'arch_uSystolic'
    elif 'systolic' in arch:
        base_dir = 'arch_systolic'
    elif 'carat' in arch and 'i1' in arch:
        assert dim1 in large_dim1, f'carat i1 does not support dim1={dim1}'
        base_dir = 'arch_carat_i1'
    elif 'carat' in arch and 'idle8' in arch:
        assert dim1 in large_dim1, f'carat idle8 does not support dim1={dim1}'
        base_dir = 'arch_carat_idle8'
    elif 'carat' in arch:
        if is_asplos:
            assert dim1 in (16, 32, 64, 128, 256, 512), (
                f'carat (asplos2024) does not support dim1={dim1}')
        else:
            assert dim1 in large_dim1, f'carat does not support dim1={dim1}'
        base_dir = 'arch_carat'
    elif 'cris' in arch:
        if is_asplos:
            assert dim1 in (16, 32, 64, 128, 256), (
                f'cris (asplos2024) does not support dim1={dim1}')
            assert dim2 == 1, f'cris (asplos2024) requires dim2=1, got {dim2}'
        base_dir = 'arch_cris'

    if 'noc' in arch:
        base_dir += '_noc'
        # The carat variants put the '_noc' marker before the variant suffix.
        if 'carat' in arch and 'i1' in arch:
            base_dir = 'arch_carat_noc_i1'
        if 'carat' in arch and 'idle8' in arch:
            base_dir = 'arch_carat_noc_idle8'
    return base_dir


def gen_single_config(arch:str, workload:str, dim1:int, dim2:int, format_:str,
    dram:float, sram:float, batch_size:int, noc1:int, noc2:int, nocbw:int, conference:str, tech:int):
    """Create one run directory with its architecture/program/workload inputs.

    The run directory is ``runs/<base_dir>/<subdir>/<name>/input`` where the
    path components encode arch, NoC size, array dimensions, data format,
    workload, cycle count and batch size. Three files are copied into it
    (``architecture.yaml`` from the per-arch template under ``runs/``,
    ``runs/program_perf_cost.yaml`` as ``program.yaml``, and
    ``workload/<workload>/workload.yaml``); the architecture and workload
    copies are then loaded with ``yaml_load``, patched and written back with
    ``yaml_overwrite``.

    Args:
        arch: architecture name; '-' is normalised to '_'.
        workload: workload directory name under ``workload/``.
        dim1: row dimension.
        dim2: column dimension (not part of the name for carat/cris).
        format_: data format; must contain 'fp8' except for uSystolic.
        dram: DRAM bandwidth written to ``odram.bandwidth``.
        sram: SRAM size selector; only applied when ``conference`` is not
            'asplos2024', where it scales ``byte_tot_sram``.
        batch_size: written as ``N`` into every workload layer.
        noc1, noc2: NoC dimensions; only used when 'noc' is in ``arch``.
        nocbw: NoC bandwidth; only used when 'noc' is in ``arch``.
        conference: experiment setup name ('asplos2024' or anything else).
        tech: value written to ``architecture.technology``.

    Raises:
        AssertionError: unsupported arch/format/sram value, or a dimension
            outside the set accepted for the arch.
    """
    arch = arch.replace('-', '_')
    bank_ct = 32
    has_noc = 'noc' in arch
    srams = ('isram', 'osram', 'wsram')

    format_cache = 'int8_acc24' if 'uSys' in arch else format_ + '_acc16'

    # All validation happens here, before anything touches the filesystem.
    cycle, i_byte_per_row, w_byte_per_row, o_byte_per_row = _resolve_cycle_and_brow(
        arch, dim1, dim2, format_, conference, bank_ct)
    base_dir = _resolve_base_dir(arch, dim1, dim2, conference)

    # Fractional SRAM sizes are spelled '1_to_<n>' in directory names.
    sram_tag = int(sram) if sram >= 1 else f'1_to_{int(1/sram)}'
    subdir = f'dram{int(dram)}_sram{sram_tag}'
    if has_noc:
        subdir += f'/bw{int(nocbw)}'

    # Run directory name: carat/cris are identified by dim1 only.
    arch_tag = f'{arch}{noc1}x{noc2}' if has_noc else arch
    if 'carat' in arch or 'cris' in arch:
        dims_tag = f'{dim1}'
    else:
        dims_tag = f'{dim1}x{dim2}'
    dir_name = (f'{base_dir}/{subdir}/'
                f'{arch_tag}_{dims_tag}_{format_cache}_{workload}_c{cycle}_n{batch_size}')

    # makedirs creates every missing parent and tolerates an existing
    # directory, so regenerating a config simply overwrites its inputs.
    input_dir = f'runs/{dir_name}/input'
    os.makedirs(input_dir, exist_ok=True)

    wl_yaml_file = f'{input_dir}/workload.yaml'
    arch_yaml_file = f'{input_dir}/architecture.yaml'

    # Template directory name: only the (u)systolic arrays carry dim2.
    if 'carat' not in arch and 'ystolic' in arch:
        template_dims = f'{dim1}x{dim2}'
    else:
        template_dims = f'{dim1}'
    template_dir = f'{base_dir}/{template_dims}_{format_cache}'

    shutil.copy2(f'runs/{template_dir}/input/architecture.yaml', arch_yaml_file)
    shutil.copy2('runs/program_perf_cost.yaml', f'{input_dir}/program.yaml')  # both perf and cost
    shutil.copy2(f'workload/{workload}/workload.yaml', wl_yaml_file)

    # === arch file ===
    arch_dict = yaml_load(arch_yaml_file)
    arch_cfg = arch_dict['architecture']

    if has_noc:
        arch_cfg['noc']['num_instances'] = [noc1, noc2]
        arch_cfg['noc']['instance']['noc']['bandwidth'] = nocbw

    # The asplos2024 setup keeps the SRAM settings of the template.
    if conference != 'asplos2024':
        brow_overrides = (
            ('isram', i_byte_per_row),
            ('osram', o_byte_per_row),
            ('wsram', w_byte_per_row),
        )
        for sram_name, byte_per_brow in brow_overrides:
            if byte_per_brow is not None:
                arch_cfg[sram_name]['byte_per_brow'] = byte_per_brow

        for sram_name in srams:
            arch_cfg[sram_name]['bank_per_sram'] = bank_ct

        # The template capacity corresponds to sram == 0.25; other sizes
        # scale it. 0.125 divides (as before) instead of multiplying by 0.5.
        multiplier = {0.5: 2, 1: 4, 2: 8}.get(sram)
        if sram == 0.125:
            for sram_name in srams:
                arch_cfg[sram_name]['byte_tot_sram'] /= 2
        elif sram == 0.25:
            pass
        elif multiplier is not None:
            for sram_name in srams:
                arch_cfg[sram_name]['byte_tot_sram'] *= multiplier
        else:
            raise AssertionError(f'unsupported sram size: {sram}')

    # dram cannot be None here: int(dram) above would already have failed.
    arch_cfg['odram']['bandwidth'] = dram
    arch_cfg['technology'] = tech

    yaml_overwrite(arch_yaml_file, arch_dict)

    # === workload file ===
    # override cycle and batch for every layer
    wl_dict = yaml_load(wl_yaml_file)
    is_cris = 'cris' in arch
    for layer_dict in wl_dict['workload'].values():
        layer_dict['cycle'] = int(cycle)
        layer_dict['N'] = int(batch_size)
        if is_cris:
            # cris needs the similarity attribute
            layer_dict['similarity'] = 0.5
    yaml_overwrite(wl_yaml_file, wl_dict)

def gen_comb_config(arch:list, workload:list, dim1:list, dim2:list, format_:list, 
    dram:list, sram:list, batch_size:list, noc1:list, noc2:list, nocbw:list, conference:str, tech_override:int):
    """Generate one config for every combination of the sweep axes.

    Each list argument is one axis of a Cartesian product; ``conference``
    and ``tech_override`` are scalars shared by all combinations. Every
    resulting tuple is passed positionally to ``gen_single_config``, so the
    axis order below must match that function's parameter order.

    Combinations are produced with ``arch`` varying slowest and are
    consumed lazily instead of being collected into a list first.
    """
    axes = (
        arch, workload, dim1, dim2, format_, dram, sram, batch_size,
        noc1, noc2, nocbw, (conference,), (tech_override,),
    )
    for combination in itertools.product(*axes):
        gen_single_config(*combination)

def get_workload_name(wl):
    """Translate short workload names into the names used downstream.

    A name containing 'bert', 'rnnt', 'ssd' or 'unet' (checked in that
    order, first match wins) is replaced by the corresponding full name;
    any other name is returned unchanged.

    Returns:
        list: translated names, in the same order as ``wl``.
    """
    aliases = (
        ('bert', 'bert_base_uncased'),
        ('rnnt', 'RNNT'),
        ('ssd', 'ssd300_vgg16'),
        ('unet', 'UNet'),
    )

    def translate(short_name):
        for marker, full_name in aliases:
            if marker in short_name:
                return full_name
        return short_name

    return [translate(short_name) for short_name in wl]

def delete_dir(folder_path):
    """Remove ``folder_path`` whether or not it is empty.

    A missing path only prints a notice. An empty directory is removed with
    ``os.rmdir``; a non-empty one is announced and removed recursively with
    ``shutil.rmtree``.
    """
    if not os.path.exists(folder_path):
        print("File not found in the directory")
        return

    if os.listdir(folder_path):
        # There is still content inside: say so, then delete the whole tree.
        print("Folder is not empty. removing anyway")
        shutil.rmtree(folder_path)
    else:
        os.rmdir(folder_path)
        

if __name__ == "__main__":
    parser = construct_argparser()
    args = parser.parse_args()

    # -r/--dim1 has no default; without it the sweep would receive None as
    # an axis and fail with an unhelpful TypeError.
    if args.dim1 is None:
        parser.error('-r/--dim1 is required')

    # NoC dimensions are mandatory as soon as any selected arch is a NoC
    # variant. Check before generating anything.
    if any('noc' in arch for arch in args.arch) and not (args.noc1 and args.noc2):
        parser.error('--noc1 and --noc2 are required when a *-noc arch is selected')

    if args.batch_size is None:
        parser.error('Specify batch size')

    wl_name_list = get_workload_name(args.workload)

    # Sweep one arch at a time so that a non-NoC arch in the list cannot
    # replace the NoC dimensions requested for a NoC arch: NoC archs use the
    # values from the command line, all others use a fixed 1x1 placeholder.
    # The generation order is unchanged (arch was already the outermost axis).
    for arch in args.arch:
        uses_noc = 'noc' in arch
        gen_comb_config([arch], wl_name_list, args.dim1, args.dim2,
            args.format, args.dram, args.sram, args.batch_size,
            args.noc1 if uses_noc else [1],
            args.noc2 if uses_noc else [1],
            args.nocbw, args.conference, args.tech_override)

    