from collections import OrderedDict
import numpy as np
from workload import conv
import math
from utils import bcolors, round2power
import copy


def _check(condition, message):
    """Raise AssertionError with a colored message when ``condition`` is false.

    An explicit raise is used instead of the ``assert`` statement so the
    configuration checks are still enforced when Python runs with ``-O``.
    """
    if not condition:
        raise AssertionError(bcolors.FAIL + message + bcolors.ENDC)


def _tile_sizes(total, tile_max):
    """Yield the size of every tile when ``total`` is cut into ``tile_max`` chunks.

    Full tiles come first; a non-zero remainder is yielded last.
    """
    full_tiles, remainder = np.divmod(total, tile_max)
    for _ in range(int(full_tiles)):
        yield tile_max
    if remainder != 0:
        yield remainder


def _split_pq(pq_budget, prob):
    """Split an output-pixel budget into (p, q) tile extents capped by prob.P and prob.Q."""
    pq = round2power(pq_budget)
    p = round2power(min(math.sqrt(pq), prob.P))
    q = round2power(min(pq / p, prob.Q))
    return p, q


def _sram_result(cfg, stall_rd, stall_wr, word_rd, word_wr, num_nodes):
    """Build the stall / byte / access report of one SRAM, summed over all nodes."""
    byte_rd = word_rd * cfg["byte_per_word"] * num_nodes
    byte_wr = word_wr * cfg["byte_per_word"] * num_nodes

    result = OrderedDict()
    result["access stall"] = OrderedDict({"rd": int(stall_rd),
                                          "wr": int(stall_wr)})
    result["byte total"] = OrderedDict({"rd": int(math.ceil(byte_rd)),
                                        "wr": int(math.ceil(byte_wr))})
    result["access total"] = OrderedDict({"rd": int(math.ceil(byte_rd / cfg["byte_per_brow"])),
                                          "wr": int(math.ceil(byte_wr / cfg["byte_per_brow"]))})
    return result


def dataflow(
    arch_cfg: OrderedDict,
    prob_cfg: OrderedDict
) -> OrderedDict:
    """
    Model a weight stationary systolic array replicated over multiple NoC nodes.
    All memories are double buffered.

    Tensor format (Tensorflow-preferred, systolic array):
        Input format: NXYC
        Weight format: RSCK
        Output format: NPQK

    Please refer to https://inst.eecs.berkeley.edu/~eecs151/sp20/files/lec18-DNN-QijingHuang.pdf page 29 for how to schedule GEMM.

    Args:
        arch_cfg: architecture description. The entries "compu", "ififo",
            "wfifo", "ofifo", "isram", "wsram", "osram", "odram", "noc" and
            "frequency" are read.
        prob_cfg: problem description. It is handed to ``conv`` and its
            "cycle" entry is used as the cycle count of one array pass.

    Returns:
        An OrderedDict with one report per component, in this order:
        "compu", "noc", "ififo", "wfifo", "ofifo", "oaccu", "isram", "wsram",
        "osram", "odram".

    Raises:
        AssertionError: if a FIFO is not double buffered or its instance count
            does not match the array dimension it serves.
    """

    # intra-image parallel with weight stationary

    compu_cfg = arch_cfg["compu"]
    ififo_cfg = arch_cfg["ififo"]
    wfifo_cfg = arch_cfg["wfifo"]
    ofifo_cfg = arch_cfg["ofifo"]
    isram_cfg = arch_cfg["isram"]
    wsram_cfg = arch_cfg["wsram"]
    osram_cfg = arch_cfg["osram"]
    odram_cfg = arch_cfg["odram"]

    num_rows = compu_cfg["num_instances"][0]
    num_cols = compu_cfg["num_instances"][1]

    # configuration of NoC
    # the rows and cols define the repetition of a single node
    # it also determines the split of input, weight and output
    noc_cfg = arch_cfg["noc"]
    # presumably bandwidth is in GB/s and frequency in MHz, which makes this
    # byte per cycle -- verify against the configuration files
    noc_bw = noc_cfg["instance"]["noc"]["bandwidth"] * 10**9 / (arch_cfg["frequency"] * 10**6)
    noc_rows = noc_cfg["num_instances"][0]
    noc_cols = noc_cfg["num_instances"][1]
    num_nodes = noc_rows * noc_cols

    # initialize the problem
    prob = conv("", prob_cfg)

    # dram bandwidth in byte per cycle
    odram_bw = odram_cfg["bandwidth"] * 10**9 / (arch_cfg["frequency"] * 10**6)

    # sram bandwidth: one bank row from every bank
    isram_bw = isram_cfg["byte_per_brow"] * isram_cfg["bank_per_sram"]
    wsram_bw = wsram_cfg["byte_per_brow"] * wsram_cfg["bank_per_sram"]
    osram_bw = osram_cfg["byte_per_brow"] * osram_cfg["bank_per_sram"]

    isram_bpw = isram_cfg["byte_per_word"]
    wsram_bpw = wsram_cfg["byte_per_word"]
    osram_bpw = osram_cfg["byte_per_word"]

    # words available to one tile: half of each sram, the other half being
    # the second buffer
    isram_tile_words = isram_cfg["byte_tot_sram"] / isram_bpw / 2
    wsram_tile_words = wsram_cfg["byte_tot_sram"] / wsram_bpw / 2
    osram_tile_words = osram_cfg["byte_tot_sram"] / osram_bpw / 2

    # the fifos must be double buffered and match the array dimension they feed
    _check(ofifo_cfg["num_instances"][0] == 2,
           "The output fifo is not double buffered.")
    _check(ofifo_cfg["num_instances"][1] == num_cols,
           "The output fifo count is unequal to the array column count.")
    _check(ififo_cfg["num_instances"][1] == 2,
           "The input fifo is not double buffered.")
    _check(ififo_cfg["num_instances"][0] == num_rows,
           "The input fifo count is unequal to the array row count.")
    _check(wfifo_cfg["num_instances"][0] == 2,
           "The weight fifo is not double buffered.")
    _check(wfifo_cfg["num_instances"][1] == num_cols,
           "The weight fifo count is unequal to the array column count.")

    # sram traffic in words
    isram_word_rd, isram_word_wr = 0, 0
    wsram_word_rd, wsram_word_wr = 0, 0
    osram_word_rd, osram_word_wr = 0, 0

    compute_cycle = 0
    stream_in_cycle = 0
    stream_out_cycle = 0

    # assume only sram stall, i.e., fifos do not stall due to double buffer
    # isram/wsram reads and osram writes are modelled as never stalling
    isram_stall_rd, isram_stall_wr = 0, 0
    wsram_stall_rd, wsram_stall_wr = 0, 0
    osram_stall_rd, osram_stall_wr = 0, 0
    odram_stall = 0

    noc_bytes = 0
    first_tile = True

    # adopted loop tiling
    # one potential optimization is to ensure that k and c are close to each other for the best sram utilization

    # tile npqk among noc nodes
    if prob.N == 1:
        N_split = 1
        K_split = noc_rows * noc_cols
        prob_N_node = 1
        prob_K_node = prob.K / K_split
    else:
        N_split = min(noc_rows, noc_cols)
        K_split = max(noc_rows, noc_cols)
        prob_N_node = prob.N / N_split
        prob_K_node = prob.K / K_split

    npq_node = prob_N_node * prob.P * prob.Q
    rs = prob.R * prob.S

    # search k_max: double k until the input tile implied by the output and
    # weight tiles fits in the input sram, or k covers the whole node share
    k_appr = min(num_cols, prob_K_node)
    while True:
        k_appr = round2power(k_appr)
        npq_appr = min(math.floor(osram_tile_words / k_appr), npq_node)
        p_appr, q_appr = _split_pq(npq_appr / prob_N_node, prob)

        x_appr = (p_appr - 1) * prob.Hstride + prob.R
        y_appr = (q_appr - 1) * prob.Wstride + prob.S
        c_appr = min(math.floor(wsram_tile_words / k_appr / rs), prob.C)

        nxyc_appr = prob_N_node * x_appr * y_appr * c_appr
        if nxyc_appr < isram_tile_words or k_appr >= prob_K_node:
            break
        k_appr *= 2

    k_max = round2power(min(k_appr, prob_K_node))

    for k_this in _tile_sizes(prob_K_node, k_max):
        npq_appr = min(math.floor(osram_tile_words / k_this), npq_node)
        # n is not tiled inside a node
        n_this = prob_N_node
        p_appr, q_appr = _split_pq(npq_appr / n_this, prob)

        # input extent needed to produce a p_appr x q_appr output tile
        x_tile = (p_appr - 1) * prob.Hstride + prob.R
        y_tile = (q_appr - 1) * prob.Wstride + prob.S

        npq_max = n_this * p_appr * q_appr
        for npq_this in _tile_sizes(npq_node, npq_max):
            osram_this = npq_this * k_this
            osram_bytes = osram_this * osram_bpw
            # overwritten per tile, so the final value is that of the last tile
            stream_out_cycle = osram_bytes / min(osram_bw, odram_bw)

            odram_byte_iw = 0
            compute_cycle_c_delta = 0
            # r and s are not tiled
            c_max = min(math.floor(wsram_tile_words / k_this / rs), prob.C)
            c_max = round2power(c_max)
            for c_this in _tile_sizes(prob.C, c_max):
                crs_array_folds = math.ceil(c_this * rs / num_rows)
                compute_cycle_c = npq_this * crs_array_folds * prob_cfg["cycle"] * math.ceil(k_this / num_cols)
                compute_cycle += compute_cycle_c
                compute_cycle_c_delta += compute_cycle_c

                # for every c tile, outputs are both loaded and read once by the array/ofifo
                # no stall will happen for this
                osram_word_wr += osram_this
                osram_word_rd += osram_this

                # for every c tile, weights are loaded only once
                wsram_this = c_this * rs * k_this
                wsram_word_wr += wsram_this
                wsram_word_rd += wsram_this

                # no read stall
                wsram_bytes = wsram_this * wsram_bpw
                wsram_stall_wr += math.ceil(max(wsram_bytes / min(wsram_bw, noc_bw / K_split) - compute_cycle_c, 0))
                noc_bytes += wsram_bytes * K_split

                # input footprint of this tile, computed directly so that a
                # slice larger than the half sram cannot cause a division by zero
                isram_this = c_this * n_this * x_tile * y_tile
                isram_word_wr += isram_this
                isram_word_rd += isram_this
                isram_bytes = isram_this * isram_bpw

                if first_tile:
                    # only the very first tile cannot overlap its load with compute
                    stream_in_cycle = isram_bytes / min(isram_bw, odram_bw) + \
                                      wsram_bytes / min(wsram_bw, odram_bw)
                    first_tile = False

                # no read stall
                isram_stall_wr += math.ceil(max(isram_bytes / min(isram_bw, noc_bw / N_split) - compute_cycle_c, 0))
                noc_bytes += isram_bytes * N_split

                odram_byte_iw += wsram_bytes + isram_bytes

            # rd stall, odram read osram
            # no wr stall
            osram_stall_rd += math.ceil(max(osram_bytes / min(osram_bw, noc_bw / N_split / K_split) - compute_cycle_c_delta, 0))
            noc_bytes += osram_bytes * N_split * K_split

            # all three terms are in bytes: the output tile is converted from
            # words before being added to the input/weight byte count
            odram_stall += math.ceil(max((osram_bytes + odram_byte_iw) / odram_bw - compute_cycle_c_delta, 0))

    # NOTE(review): these two divide by byte_per_word, whereas a words-to-bytes
    # conversion would multiply -- confirm the intended formula
    first_ififo_map = num_rows * num_cols / isram_bpw / isram_bw
    first_wfifo_map = num_rows * num_cols / wsram_bpw / wsram_bw

    # count for initial streaming latency
    compute_cycle += num_rows + num_cols
    compute_cycle += first_ififo_map + first_wfifo_map
    compute_cycle += stream_in_cycle + stream_out_cycle

    utilization = prob.flops * prob_cfg["cycle"] / (num_rows * num_cols * compute_cycle * 2 + num_cols * compute_cycle) * 100. / num_nodes

    compu_result = OrderedDict()
    compu_result["cycle"] = int(compute_cycle)
    compu_result["utilization"] = float(utilization)
    compu_result["flops"] = float(prob.flops)

    noc_result = copy.deepcopy(compu_result)
    noc_result["cycle"] = int(noc_bytes / noc_bw)

    odram_result = OrderedDict()
    odram_byte_rd = isram_word_wr * isram_bpw * N_split + wsram_word_wr * wsram_bpw * K_split
    odram_byte_wr = prob.N * prob.P * prob.Q * prob.K * osram_bpw
    odram_byte_tot = odram_byte_rd + odram_byte_wr
    # the dram stall is attributed to reads and writes in proportion to their bytes
    odram_result["access stall"] = OrderedDict({"rd": 0 if odram_byte_rd == 0 else int(math.ceil(odram_stall * odram_byte_rd / odram_byte_tot)),
                                                "wr": 0 if odram_byte_wr == 0 else int(math.ceil(odram_stall * odram_byte_wr / odram_byte_tot))})
    odram_result["byte total"] = OrderedDict({"rd": int(odram_byte_rd),
                                              "wr": int(odram_byte_wr)})
    # NOTE(review): bytes are divided by a quantity named bits_per_chip without
    # the * 8 used for pages below -- confirm the unit of bits_per_chip
    odram_result["access total"] = OrderedDict({"rd": int(math.ceil(odram_byte_rd / odram_cfg["bits_per_chip"])),
                                                "wr": int(math.ceil(odram_byte_wr / odram_cfg["bits_per_chip"]))})
    odram_result["page total"] = OrderedDict({"rd": int(math.ceil(odram_byte_rd * 8 / odram_cfg["bits_per_page"])),
                                              "wr": int(math.ceil(odram_byte_wr * 8 / odram_cfg["bits_per_page"]))})

    # output shall include the ordered output of all components
    # every component owns an independent report object
    output = OrderedDict()
    output["compu"] = compu_result
    output["noc"] = noc_result
    output["ififo"] = copy.deepcopy(compu_result)
    output["wfifo"] = copy.deepcopy(compu_result)
    output["ofifo"] = copy.deepcopy(compu_result)
    output["oaccu"] = copy.deepcopy(compu_result)
    output["isram"] = _sram_result(isram_cfg, isram_stall_rd, isram_stall_wr,
                                   isram_word_rd, isram_word_wr, num_nodes)
    output["wsram"] = _sram_result(wsram_cfg, wsram_stall_rd, wsram_stall_wr,
                                   wsram_word_rd, wsram_word_wr, num_nodes)
    output["osram"] = _sram_result(osram_cfg, osram_stall_rd, osram_stall_wr,
                                   osram_word_rd, osram_word_wr, num_nodes)
    output["odram"] = odram_result

    return output

