from collections import OrderedDict
from ctypes.wintypes import HACCEL
from utils import yaml_load, yaml_overwrite
import numpy as np

# Workload file that is loaded, extended with the LSTM layers, and rewritten.
WORKLOAD_PATH = '/mnt/ssd1/zwp/repo/Archer/workload/RNNT/workload.yaml'

# One tuple per LSTM group, laid out as (B, H, duplicates, D):
#   B          -> written to 'Y' of every layer (and to 'S' of layers 9-11)
#   H          -> written to 'C' and/or 'K'
#   duplicates -> how many times the 11-layer set is emitted for the group
#   D          -> written to 'C' of the even-numbered layers 2, 4, 6, 8
# NOTE(review): presumably B/H/D are batch, hidden and input sizes of the
# RNNT LSTMs -- confirm against the model definition.
LSTM_CONFIGS = [(26, 250, 2, 1), (161, 250, 3, 2)]


def build_conv_layer(y, s, c, k):
    """Return one 'conv' workload entry with unit stride and dilation.

    Args:
        y: Value stored under 'Y'.
        s: Value stored under 'S'.
        c: Value stored under 'C'.
        k: Value stored under 'K'.

    Returns:
        A new OrderedDict whose keys appear in the order the workload file
        uses: cycle, type, dilations, strides, N, X, Y, R, S, C, K.
    """
    return OrderedDict([
        ('cycle', 1),
        ('type', 'conv'),
        ('Hdilation', 1),
        ('Wdilation', 1),
        ('Hstride', 1),
        ('Wstride', 1),
        ('N', 1),
        ('X', 1),
        ('Y', y),
        ('R', 1),
        ('S', s),
        ('C', c),
        ('K', k),
    ])


def build_lstm_layers(index, duplicate, B, H, D):
    """Build the 11 layer entries emitted for one LSTM duplicate.

    Layers 1-8 alternate between (S=1, C=H, K=H) and (S=1, C=D, K=H);
    layers 9-11 all use (S=B, C=H, K=1). Every layer has Y=B.

    Args:
        index: 1-based number of the LSTM group, used in the layer name.
        duplicate: 0-based duplicate number, used in the layer name.
        B: Third-from-last naming aside, the value placed in 'Y' (and 'S'
            for layers 9-11).
        H: Value placed in 'C'/'K' as described above.
        D: Value placed in 'C' of the even-numbered layers 2-8.

    Returns:
        An OrderedDict mapping 'rnnt_lstm{index}_{duplicate}_{n}' (n = 1..11,
        in order) to a freshly created layer dict.
    """
    # (S, C, K) for sub-layers 1..11, in emission order.
    shapes = [(1, H, H), (1, D, H)] * 4 + [(B, H, 1)] * 3

    layers = OrderedDict()
    for sub, (s, c, k) in enumerate(shapes, start=1):
        # A new dict per layer: shared objects would be serialised as YAML
        # aliases by some dumpers instead of as independent entries.
        layers[f'rnnt_lstm{index}_{duplicate}_{sub}'] = build_conv_layer(B, s, c, k)
    return layers


def add_lstm_layers(wl_dict, lstm_arr):
    """Insert the generated LSTM layers into ``wl_dict`` in place.

    Args:
        wl_dict: Mutable mapping of layer name -> layer description. Entries
            whose names already exist are overwritten.
        lstm_arr: Iterable of (B, H, duplicates, D) 4-tuples; groups are
            numbered from 1 in iteration order.

    Returns:
        The same ``wl_dict`` object, for convenience.
    """
    for index, (B, H, duplicates, D) in enumerate(lstm_arr, start=1):
        for duplicate in range(duplicates):
            for name, layer in build_lstm_layers(index, duplicate, B, H, D).items():
                wl_dict[name] = layer
    return wl_dict


def main():
    """Load the workload YAML, add the LSTM layers, and write it back."""
    wl_path = WORKLOAD_PATH
    wl_dict_ = yaml_load(wl_path)
    # Layers live under the top-level 'workload' key; mutating this mapping
    # updates wl_dict_, which is what gets written back.
    add_lstm_layers(wl_dict_['workload'], LSTM_CONFIGS)
    yaml_overwrite(wl_path, wl_dict_)
    print(f'Done appending lstm to {wl_path}')


if __name__ == "__main__":
    main()