from collections import OrderedDict
import numpy as np
from workload import conv
import math
from utils import bcolors, round2power
import copy


def _halvings(start, count):
    """Yield `count` values: `start`, then each previous value halved (true division).

    Reproduces an ``x = x / 2`` update at the end of a loop body, but lets the loop
    ``continue`` early without skipping the halving step.
    """
    value = start
    for _ in range(count):
        yield value
        value = value / 2


def _sram_result(cfg, stall_rd, stall_wr, word_rd, word_wr):
    """Build one SRAM report: stall cycles, total bytes and total row accesses.

    `cfg` must provide "byte_per_word" (words -> bytes) and "byte_per_brow"
    (bytes -> row accesses).
    """
    byte_per_word = cfg["byte_per_word"]
    byte_per_brow = cfg["byte_per_brow"]
    result = OrderedDict()
    result["access stall"]    = OrderedDict({ "rd": int(stall_rd),
                                              "wr": int(stall_wr)})
    result["byte total"]      = OrderedDict({ "rd": int(math.ceil(word_rd * byte_per_word)),
                                              "wr": int(math.ceil(word_wr * byte_per_word))})
    result["access total"]    = OrderedDict({ "rd": int(math.ceil(word_rd * byte_per_word / byte_per_brow)),
                                              "wr": int(math.ceil(word_wr * byte_per_word / byte_per_brow))})
    return result


def dataflow(
    arch_cfg: OrderedDict,
    prob_cfg: OrderedDict
):
    """
    Model one conv layer on a weight-stationary systolic array (intra-image schedule).

    The array has num_rows x num_cols PEs (arch_cfg["compu"]["num_instances"]).
    A search over power-of-two tile sizes for K, C, N and P/Q picks the tiling with the
    best weighted array + SRAM utilization. The array rows then hold either K or N
    (times a PQ factor, `row_pq_factor`), and the columns hold the other, whichever
    mapping gives the higher estimated array utilization. The FIFOs must be double
    buffered (asserted below), and only half of each SRAM is used per tile.

    Tensor format (TensorFlow-preferred, systolic array):
        Input format:  NXYC
        Weight format: RSCK
        Output format: NPQK

    See https://inst.eecs.berkeley.edu/~eecs151/sp20/files/lec18-DNN-QijingHuang.pdf
    page 29 for how a GEMM is scheduled on the array.

    Args:
        arch_cfg: architecture config. Reads the "compu", "ififo", "wfifo", "ofifo",
            "isram", "wsram", "osram" and "odram" entries, plus "frequency".
            The `* 10**9 / (... * 10**6)` scaling suggests GB/s and MHz; confirm.
        prob_cfg: layer config passed to `conv`; its "cycle" entry is also read
            directly.

    Returns:
        OrderedDict keyed by component name: "compu", "ififo", "wfifo", "ofifo",
        "oaccu", "waccu", "itemp", "osmux", "bpipe", "isram", "wsram", "osram",
        "odram".
        - Compute-side entries hold cycle / utilization / flops.
        - Memory entries hold stall, byte and access counts.

    Raises:
        AssertionError: if a FIFO is not double buffered or its count does not match
            the array, if the input FIFO is too shallow, or if no tiling fits the SRAMs.
    """

    # intra-image parallel with weight stationary

    compu_cfg = arch_cfg["compu"]
    ififo_cfg = arch_cfg["ififo"]
    wfifo_cfg = arch_cfg["wfifo"]
    ofifo_cfg = arch_cfg["ofifo"]

    num_rows = compu_cfg["num_instances"][0]
    num_cols = compu_cfg["num_instances"][1]

    # output collects the results of all components, in a fixed order
    output = OrderedDict()

    # initialize the problem
    prob = conv("", prob_cfg)

    # fanout of vertical weights
    fanout = 8

    # 1.1 output fifo: double buffered, one per array row
    assert ofifo_cfg["num_instances"][1] == 2, \
        bcolors.FAIL + "The output fifo is not double buffered." + bcolors.ENDC
    assert ofifo_cfg["num_instances"][0] == num_rows, \
        bcolors.FAIL + "The output fifo count is unequal to the array row count." + bcolors.ENDC
    # max outputs one tile can hold given the ofifo depth (read but currently unused)
    ts_o_bd = ofifo_cfg["instance"]["queue"]["depth"]

    # 1.2 input fifo: double buffered, one per array row, deep enough to hide compute latency
    assert ififo_cfg["num_instances"][1] == 2, \
        bcolors.FAIL + "The input fifo is not double buffered." + bcolors.ENDC
    assert ififo_cfg["num_instances"][0] == num_rows, \
        bcolors.FAIL + "The input fifo count is unequal to the array row count." + bcolors.ENDC
    ts_i_bd = ififo_cfg["instance"]["queue"]["depth"]
    assert ts_i_bd >= (num_rows / fanout / prob.cycle), \
        bcolors.FAIL + "The input fifo depth is not enough to hide the compute latency." + bcolors.ENDC

    # 1.3 weight fifo: double buffered, one per array column
    assert wfifo_cfg["num_instances"][0] == 2, \
        bcolors.FAIL + "The weight fifo is not double buffered." + bcolors.ENDC
    assert wfifo_cfg["num_instances"][1] == num_cols, \
        bcolors.FAIL + "The weight fifo count is unequal to the array column count." + bcolors.ENDC
    ts_w_bd = wfifo_cfg["instance"]["queue"]["depth"]

    osram_cfg = arch_cfg["osram"]
    odram_cfg = arch_cfg["odram"]

    # dram bandwidth in bytes per cycle
    odram_bw = odram_cfg["bandwidth"] * 10**9 / (arch_cfg["frequency"] * 10**6)

    # osram traffic (with odram and compute); outputs stay in the accumulation fifo
    osram_word_rd, osram_word_wr = 0, 0
    osram_bw = osram_cfg["byte_per_word"] * num_rows / prob_cfg["cycle"]

    # isram and wsram are assumed identical: same data width, same total size
    isram_cfg = arch_cfg["isram"]
    wsram_cfg = arch_cfg["wsram"]

    # isram traffic (with compute and odram)
    isram_word_rd, isram_word_wr = 0, 0
    isram_bw = isram_cfg["byte_per_word"] * num_rows / prob_cfg["cycle"]

    # wsram traffic (with odram and compute)
    wsram_word_rd, wsram_word_wr = 0, 0
    wsram_bw = wsram_cfg["byte_per_word"] * num_cols / prob_cfg["cycle"]

    compute_cycle = 0
    stream_in_cycle = 0
    stream_out_cycle = 0

    # only SRAM stalls are modeled; the double-buffered fifos are assumed not to stall
    isram_stall_rd, isram_stall_wr = 0, 0
    wsram_stall_rd, wsram_stall_wr = 0, 0
    osram_stall_rd, osram_stall_wr = 0, 0
    odram_stall = 0

    last_map_rows = 0
    first_wfifo_flag = True

    # ------------------------------------------------------------------
    # Tiling search
    # ------------------------------------------------------------------
    # words available in one SRAM half (the other half is the double buffer)
    isram_word = isram_cfg["byte_tot_sram"] / isram_cfg["byte_per_word"] / 2
    wsram_word = wsram_cfg["byte_tot_sram"] / wsram_cfg["byte_per_word"] / 2
    osram_word = osram_cfg["byte_tot_sram"] / osram_cfg["byte_per_word"] / 2

    total_util_max = 0

    n_max_util = prob.N
    k_max_util = prob.K
    c_max_util = prob.C
    p_max_util = prob.P
    q_max_util = prob.Q

    k2row = True

    success_flag = False

    # This tiling prioritizes compute utilization and also checks memory utilization.
    # It prefers large K and N and small C.
    k_loops = int(math.ceil(math.log2(prob.K))) + 1
    for k_idx, k_appr in enumerate(_halvings(round2power(prob.K, mode='ceil'), k_loops)):
        # largest C whose RSCK weights for this K tile fit in one wsram half
        c_fit = math.floor(wsram_cfg["byte_tot_sram"] / wsram_cfg["byte_per_word"] / 2 / k_appr / (prob.R * prob.S))
        if c_fit < 1:
            # Not even one input channel fits; such a tile can never pass the wsram
            # check, and log2 of a zero C bound below would fail.
            continue
        c_appr_max = round2power(min(c_fit, prob.C))
        c_appr = 1
        c_loops = int(math.ceil(math.log2(c_appr_max / c_appr))) + 1
        for _ in range(c_loops):
            n_loops = int(math.ceil(math.log2(prob.N))) + 1
            for n_idx, n_appr in enumerate(_halvings(round2power(prob.N), n_loops)):
                # PQ tile that fits in one osram half for this N x K tile
                pq_appr = osram_cfg["byte_tot_sram"] / osram_cfg["byte_per_word"] / 2 / n_appr / k_appr
                if pq_appr < 1:
                    continue
                pq_appr = round2power(min(pq_appr, prob.P * prob.Q))
                # assume pq is square
                p_appr = round2power(min(math.sqrt(pq_appr), prob.P))
                q_appr = round2power(min(pq_appr / p_appr, prob.Q))

                # XY tile that fits in one isram half for this N x C tile
                xy_appr = isram_cfg["byte_tot_sram"] / isram_cfg["byte_per_word"] / 2 / n_appr / c_appr
                if xy_appr < 1:
                    continue
                xy_appr = min(xy_appr, prob.X * prob.Y)

                # array utilization when N is mapped to rows (K * PQ to columns)
                n2row_row_ratio = (prob.N / 2**n_idx) / num_rows
                n2row_row_ratio_list = [n2row_row_ratio * 2**x for x in range(int(math.log2(pq_appr)) + 1)]
                # largest PQ factor that still keeps the row occupancy <= 1
                n2row_row_ratio_list = [0 if x > 1 else x for x in n2row_row_ratio_list]
                n2row_pq_factor = min(2**np.argmax(n2row_row_ratio_list), p_appr * q_appr)

                n2row_row_util = n2row_row_ratio * n2row_pq_factor
                n2row_col_ratio = (prob.K / 2**k_idx) * p_appr * q_appr / n2row_pq_factor / num_cols
                n2row_col_util = n2row_col_ratio / math.ceil(n2row_col_ratio)
                n2row_map_util = n2row_row_util * n2row_col_util

                # array utilization when K is mapped to rows (N * PQ to columns)
                k2row_row_ratio = (prob.K / 2**k_idx) / num_rows
                k2row_row_ratio_list = [k2row_row_ratio * 2**x for x in range(int(math.log2(pq_appr)) + 1)]
                k2row_row_ratio_list = [0 if x > 1 else x for x in k2row_row_ratio_list]
                k2row_pq_factor = min(2**np.argmax(k2row_row_ratio_list), p_appr * q_appr)

                k2row_row_util = k2row_row_ratio * k2row_pq_factor
                k2row_col_ratio = (prob.N / 2**n_idx) * p_appr * q_appr / k2row_pq_factor / num_cols
                k2row_col_util = k2row_col_ratio / math.ceil(k2row_col_ratio)
                k2row_map_util = k2row_row_util * k2row_col_util

                array_util = max(n2row_map_util, k2row_map_util)

                # memory utilization of one SRAM half per operand
                nxyc_appr = n_appr * xy_appr * c_appr
                rsck_appr = prob.R * prob.S * c_appr * k_appr
                npqk_appr = n_appr * pq_appr * k_appr

                isram_util = nxyc_appr / isram_word
                wsram_util = rsck_appr / wsram_word
                osram_util = npqk_appr / osram_word

                # isram may overflow (it streams XY); wsram and osram must fit the tile
                isram_util_bad = (isram_util == 0)
                wsram_util_bad = (wsram_util > 1) or (wsram_util == 0)
                osram_util_bad = (osram_util > 1) or (osram_util == 0)
                xsram_util_bad = isram_util_bad or wsram_util_bad or osram_util_bad
                xsram_util = (isram_util + wsram_util + osram_util) / 3

                # weighted total utilization: the array dominates, SRAM usage breaks ties
                total_util = array_util + xsram_util * 0.05

                if (total_util > total_util_max) and (not xsram_util_bad):
                    success_flag = True
                    n_max_util = n_appr
                    k_max_util = k_appr
                    c_max_util = c_appr
                    p_max_util = p_appr
                    q_max_util = q_appr
                    k2row = (k2row_map_util > n2row_map_util)
                    row_pq_factor = k2row_pq_factor if k2row else n2row_pq_factor
                    total_util_max = total_util
            c_appr = c_appr * 2

    assert success_flag, \
        bcolors.FAIL + "No tiling fits the given SRAM sizes." + bcolors.ENDC

    # ------------------------------------------------------------------
    # Walk the chosen tiling: K tiles -> NPQ tiles -> C tiles
    # ------------------------------------------------------------------
    k_max = k_max_util
    k_map = np.divmod(prob.K, k_max)
    for k in range(int(k_map[0] + 1)):
        # full K tiles first, then the remainder tile (if any)
        k_this = k_max if k < k_map[0] else k_map[1]
        if k_this == 0:
            break

        n_this = prob.N
        p_max = p_max_util
        q_max = q_max_util
        npq_max = n_this * p_max * q_max
        npq_map = np.divmod(prob.N * prob.P * prob.Q, npq_max)
        for npq in range(int(npq_map[0] + 1)):
            npq_this = npq_max if npq < npq_map[0] else npq_map[1]
            if npq_this == 0:
                break

            pq_this = min(npq_this / n_this, prob.P * prob.Q)
            # assume pq is square
            p_this = min(round2power(math.sqrt(pq_this)), prob.P)
            q_this = min(pq_this / p_this, prob.Q)
            row_pq_factor_this = min(row_pq_factor, p_this * q_this)
            osram_this = n_this * k_this * pq_this
            # stream-out cycles for one output tile (n_max_util). This is overwritten per
            # tile, so only the last tile's value is added to compute_cycle at the end.
            stream_out_cycle = (n_max_util * p_this * q_this * k_this) * osram_cfg["byte_per_word"] / min(osram_bw, odram_bw)

            x_this = (p_this - 1) * prob.Hstride + prob.R
            y_this = (q_this - 1) * prob.Wstride + prob.S

            # input + weight bytes fetched from dram for this output tile
            odram_byte_iw = 0
            compute_cycle_c_delta = 0

            c_max = c_max_util
            c_map = np.divmod(prob.C, c_max)
            for c in range(int(c_map[0] + 1)):
                c_this = c_max if c < c_map[0] else c_map[1]
                if c_this == 0:
                    break
                # each output element needs this many MACs
                crs_array_folds = c_this * prob.R * prob.S
                if k2row:
                    row_mem_access_width = num_rows / prob.cycle
                    if row_mem_access_width >= k_this:
                        under_util_iosram_bw = math.ceil(row_mem_access_width / k_this)
                    else:
                        row_mapping = k_this / (row_mem_access_width)
                        under_util_iosram_bw = math.ceil(row_mapping) / math.floor(row_mapping)
                    under_util_wsram_bw = math.ceil(num_cols / prob.cycle / n_max_util)

                    # The first term captures the row mapping, the second the column mapping.
                    # If the row mapping is an exact multiple of the columns: 1 cycle per output.
                    total_col = (n_max_util * pq_this / row_pq_factor_this)
                    total_col_mapping = total_col / num_cols
                    if math.floor(total_col_mapping) == 0:
                        cycle = num_cols / math.ceil(total_col)
                    else:
                        cycle = math.ceil(total_col_mapping) / math.floor(total_col_mapping)
                    compute_cycle_c = math.ceil(osram_this * crs_array_folds / (k_this * row_pq_factor_this)) * cycle * math.ceil(k_this / num_rows)
                else:
                    row_mem_access_width = num_rows / prob.cycle
                    if row_mem_access_width >= n_max_util:
                        under_util_iosram_bw = math.ceil(row_mem_access_width / n_max_util)
                    else:
                        row_mapping = n_max_util / (row_mem_access_width)
                        under_util_iosram_bw = math.ceil(row_mapping) / math.floor(row_mapping)
                    under_util_wsram_bw = math.ceil(num_cols / prob.cycle / k_this)
                    total_col = k_this * pq_this / row_pq_factor_this
                    total_col_mapping = total_col / num_cols
                    if math.floor(total_col_mapping) == 0:
                        cycle = num_cols / math.ceil(total_col)
                    else:
                        cycle = math.ceil(total_col_mapping) / math.floor(total_col_mapping)
                    compute_cycle_c = math.ceil(osram_this * crs_array_folds / (n_max_util * row_pq_factor_this)) * cycle * math.ceil(n_max_util / num_rows)
                compute_cycle += compute_cycle_c
                compute_cycle_c_delta += compute_cycle_c

                # for every c tile, outputs are written and read once by the array/ofifo;
                # no stall is modeled for this
                osram_word_wr += osram_this
                osram_word_rd += osram_this

                # for every c tile, weights are loaded only once
                wsram_this = c_this * (prob.R * prob.S) * k_this
                wsram_word_wr += wsram_this
                wsram_word_rd += wsram_this

                # no read stall; a write stall occurs when the refill outlasts the compute
                wsram_bytes = wsram_this * wsram_cfg["byte_per_word"]
                wsram_stall_wr += math.ceil(max(wsram_bytes / (wsram_bw / under_util_wsram_bw) - compute_cycle_c, 0))

                isram_this = n_this * x_this * y_this * c_this
                isram_word_wr += isram_this
                isram_word_rd += isram_this

                if first_wfifo_flag:
                    # latency to stream in the very first input and weight tiles
                    stream_in_cycle =   isram_this * isram_cfg["byte_per_word"] / min((isram_bw / under_util_iosram_bw), odram_bw) + \
                                        wsram_this * wsram_cfg["byte_per_word"] / min((wsram_bw / under_util_wsram_bw), odram_bw)
                    first_wfifo_flag = False

                # no read stall; a write stall occurs when the refill outlasts the compute
                isram_bytes = isram_this * isram_cfg["byte_per_word"]
                isram_stall_wr += math.ceil(max(isram_bytes / (isram_bw / under_util_iosram_bw) - compute_cycle_c, 0))

                odram_byte_iw += wsram_bytes + isram_bytes

            # read stall: odram drains the osram; no write stall
            osram_bytes = osram_this * osram_cfg["byte_per_word"]
            osram_stall_rd += math.ceil(max(osram_bytes / (osram_bw / under_util_iosram_bw) - compute_cycle_c_delta, 0))

            # dram traffic for this tile, all in bytes (outputs + inputs + weights)
            odram_stall += math.ceil(max((osram_bytes + odram_byte_iw) / odram_bw - compute_cycle_c_delta, 0))

    first_map_cols = num_rows / fanout
    # NOTE(review): num_rows * ts_i_bd looks like a word count, yet it is divided (not
    # multiplied) by byte_per_word before dividing by a bytes/cycle bandwidth; confirm units.
    first_ififo_map = num_rows * ts_i_bd / isram_cfg["byte_per_word"] / isram_bw
    first_wfifo_map = ts_w_bd * num_cols / wsram_cfg["byte_per_word"] / wsram_bw

    # count the initial streaming latency
    compute_cycle += first_map_cols + last_map_rows
    compute_cycle += first_ififo_map + first_wfifo_map
    compute_cycle += stream_in_cycle + stream_out_cycle

    # denominator: peak flops over compute_cycle, i.e. 2 * num_rows * num_cols / prob.cycle per cycle
    utilization = prob.flops / (num_rows * 2 / (prob.cycle / num_cols) * compute_cycle) * 100.

    compu_result = OrderedDict()
    compu_result["cycle"] = int(compute_cycle)
    compu_result["utilization"] = float(utilization)
    compu_result["flops"] = float(prob.flops)

    isram_result = _sram_result(isram_cfg, isram_stall_rd, isram_stall_wr, isram_word_rd, isram_word_wr)
    wsram_result = _sram_result(wsram_cfg, wsram_stall_rd, wsram_stall_wr, wsram_word_rd, wsram_word_wr)
    osram_result = _sram_result(osram_cfg, osram_stall_rd, osram_stall_wr, osram_word_rd, osram_word_wr)

    odram_result = OrderedDict()
    odram_byte_rd = isram_word_wr * isram_cfg["byte_per_word"] + wsram_word_wr * wsram_cfg["byte_per_word"]
    odram_byte_wr = prob.N * prob.P * prob.Q * prob.K * osram_cfg["byte_per_word"]
    # the total dram stall is split between rd and wr in proportion to their byte counts
    odram_result["access stall"]    = OrderedDict({ "rd": 0 if odram_byte_rd == 0 else int(math.ceil(odram_stall * odram_byte_rd / (odram_byte_rd + odram_byte_wr))),
                                                    "wr": 0 if odram_byte_wr == 0 else int(math.ceil(odram_stall * odram_byte_wr / (odram_byte_rd + odram_byte_wr)))})
    odram_result["byte total"]      = OrderedDict({ "rd": int(odram_byte_rd),
                                                    "wr": int(odram_byte_wr)})
    # NOTE(review): bytes are divided by "bits_per_chip" without the * 8 used for
    # "page total" below; confirm the intended unit before relying on this count.
    odram_result["access total"]    = OrderedDict({ "rd": int(math.ceil(odram_byte_rd / odram_cfg["bits_per_chip"])),
                                                    "wr": int(math.ceil(odram_byte_wr / odram_cfg["bits_per_chip"]))})
    odram_result["page total"]      = OrderedDict({ "rd": int(math.ceil(odram_byte_rd * 8 / odram_cfg["bits_per_page"])),
                                                    "wr": int(math.ceil(odram_byte_wr * 8 / odram_cfg["bits_per_page"]))})

    output["compu"] = compu_result
    # every compute-side component reports the compute result; each gets its own copy
    # so that mutating one entry cannot silently change the others
    for name in ("ififo", "wfifo", "ofifo", "oaccu", "waccu", "itemp", "osmux", "bpipe"):
        output[name] = copy.deepcopy(compu_result)
    output["isram"] = isram_result
    output["wsram"] = wsram_result
    output["osram"] = osram_result
    output["odram"] = odram_result

    return output

