from collections import OrderedDict
import numpy as np
from workload import conv
import math
from utils import bcolors, round2power
import copy


def _check_fifo_shapes(ififo_cfg, wfifo_cfg, ofifo_cfg, num_rows, num_cols):
    """Check that the fifo instance grids match a double-buffered num_rows x num_cols array.

    The output and weight fifos are expected as ``[2, num_cols]`` instance grids,
    the input fifo as ``[num_rows, 2]``.

    Raises:
        AssertionError: if a fifo is not double buffered or its count does not
            match the corresponding array dimension.
    """
    assert ofifo_cfg["num_instances"][0] == 2, \
        bcolors.FAIL + "The output fifo is not double buffered." + bcolors.ENDC
    assert ofifo_cfg["num_instances"][1] == num_cols, \
        bcolors.FAIL + "The output fifo count is unequal to the array column count." + bcolors.ENDC

    assert ififo_cfg["num_instances"][1] == 2, \
        bcolors.FAIL + "The input fifo is not double buffered." + bcolors.ENDC
    assert ififo_cfg["num_instances"][0] == num_rows, \
        bcolors.FAIL + "The input fifo count is unequal to the array row count." + bcolors.ENDC

    assert wfifo_cfg["num_instances"][0] == 2, \
        bcolors.FAIL + "The weight fifo is not double buffered." + bcolors.ENDC
    assert wfifo_cfg["num_instances"][1] == num_cols, \
        bcolors.FAIL + "The weight fifo count is unequal to the array column count." + bcolors.ENDC


def _tile_candidates(prob, wsram_word):
    """Yield ``(k_idx, k_appr, c_appr, n_appr)`` tile candidates in search order.

    K starts at ``round2power(prob.K, mode='ceil')`` and is halved each step.
    For each K, C doubles from 1 up to the largest power-of-two channel count
    whose weights fit in half of the weight sram. For each (K, C), N halves
    starting from ``round2power(prob.N)``.

    The stepping lives here, not in the consumer, so a consumer that rejects
    a candidate with ``continue`` still advances to the next (smaller) N.
    """
    k_appr = round2power(prob.K, mode='ceil')
    k_loops = int(math.ceil(math.log(prob.K, 2))) + 1
    n_loops = int(math.ceil(math.log(prob.N, 2))) + 1
    for k_idx in range(k_loops):
        c_appr_max = math.floor(wsram_word / k_appr / (prob.R * prob.S))
        # If not even one input channel of weights fits for this K, every C
        # candidate would overflow the wsram (and math.log(0) below would
        # raise), so go straight to the next, smaller K.
        if c_appr_max >= 1:
            c_appr_max = round2power(min(c_appr_max, prob.C))
            c_appr = 1
            c_loops = int(math.ceil(math.log(c_appr_max / c_appr, 2))) + 1
            for _ in range(c_loops):
                n_appr = round2power(prob.N)
                for _ in range(n_loops):
                    yield k_idx, k_appr, c_appr, n_appr
                    n_appr = n_appr / 2
                c_appr = c_appr * 2
        k_appr = k_appr / 2


def _search_tile(prob, num_cols, isram_cfg, wsram_cfg, osram_cfg):
    """Choose the (N, K, C, P, Q) tile with the best weighted utilization.

    The score is the array column utilization of the K mapping plus 0.05 times
    the mean utilization of the three srams. It therefore prioritizes compute
    utilization and, among tiles, prefers large K and N and small C. A
    candidate is rejected when its weight or output tile overflows half of the
    corresponding sram, or when any sram utilization is zero.

    Returns:
        tuple: ``(n_tile, k_tile, c_tile, p_tile, q_tile)``.

    Raises:
        AssertionError: if no candidate tile is feasible.
    """
    # words available in one half of each double-buffered sram
    isram_word = isram_cfg["byte_tot_sram"] / isram_cfg["byte_per_word"] / 2
    wsram_word = wsram_cfg["byte_tot_sram"] / wsram_cfg["byte_per_word"] / 2
    osram_word = osram_cfg["byte_tot_sram"] / osram_cfg["byte_per_word"] / 2

    best_tile = None
    total_util_max = 0
    for k_idx, k_appr, c_appr, n_appr in _tile_candidates(prob, wsram_word):
        # largest output plane (P*Q) that fits the output sram for this N, K
        pq_appr = osram_word / n_appr / k_appr
        if pq_appr < 1:
            continue
        pq_appr = round2power(min(pq_appr, prob.P * prob.Q))
        # assume pq is square
        p_appr = round2power(min(math.sqrt(pq_appr), prob.P))
        q_appr = round2power(min(pq_appr / p_appr, prob.Q))

        # largest input plane (X*Y) that fits the input sram for this N, C
        xy_appr = isram_word / n_appr / c_appr
        if xy_appr < 1:
            continue
        xy_appr = min(xy_appr, prob.X * prob.Y)

        # array utilization: K is mapped onto the array columns
        k2col_col_ratio = (prob.K / 2**k_idx) / num_cols
        array_util = k2col_col_ratio / math.ceil(k2col_col_ratio)

        # memory utilization of each sram half
        nxyc_appr = n_appr * xy_appr * c_appr
        rsck_appr = prob.R * prob.S * c_appr * k_appr
        npqk_appr = n_appr * pq_appr * k_appr
        isram_util = nxyc_appr / isram_word
        wsram_util = rsck_appr / wsram_word
        osram_util = npqk_appr / osram_word

        # the isram only has to hold at least one XY slice, so it may exceed 1
        xsram_util_bad = (
            isram_util == 0
            or wsram_util > 1 or wsram_util == 0
            or osram_util > 1 or osram_util == 0
        )
        xsram_util = (isram_util + wsram_util + osram_util) / 3

        # weighted total utilization
        total_util = array_util + xsram_util * 0.05

        if total_util > total_util_max and not xsram_util_bad:
            best_tile = (n_appr, k_appr, c_appr, p_appr, q_appr)
            total_util_max = total_util

    if best_tile is None:
        raise AssertionError(
            bcolors.FAIL + "No feasible tile fits the on-chip srams." + bcolors.ENDC)
    return best_tile


def _sram_result(sram_cfg, word_rd, word_wr, stall_rd, stall_wr):
    """Build the stall / byte / access report of one sram from its word traffic.

    Returns:
        OrderedDict with keys "access stall", "byte total" and "access total",
        each mapping to an OrderedDict with "rd" and "wr" integer entries.
    """
    byte_per_word = sram_cfg["byte_per_word"]
    byte_per_brow = sram_cfg["byte_per_brow"]

    result = OrderedDict()
    result["access stall"] = OrderedDict([
        ("rd", int(stall_rd)),
        ("wr", int(stall_wr)),
    ])
    result["byte total"] = OrderedDict([
        ("rd", int(math.ceil(word_rd * byte_per_word))),
        ("wr", int(math.ceil(word_wr * byte_per_word))),
    ])
    result["access total"] = OrderedDict([
        ("rd", int(math.ceil(word_rd * byte_per_word / byte_per_brow))),
        ("wr", int(math.ceil(word_wr * byte_per_word / byte_per_brow))),
    ])
    return result


def dataflow(
    arch_cfg: OrderedDict,
    prob_cfg: OrderedDict
):
    """
    This function implements the weight stationary systolic array.
    All memories are double buffered.

    Tensor format:

    Tensorflow-preferred format: (systolic array)
    Input format: NXYC
    Weight format: RSCK
    Output format: NPQK

    Please refer to https://inst.eecs.berkeley.edu/~eecs151/sp20/files/lec18-DNN-QijingHuang.pdf page 29 for how to schedule GEMM.

    Args:
        arch_cfg: architecture description. It must provide the "compu",
            "ififo", "wfifo", "ofifo", "isram", "wsram", "osram", "odram" and
            "frequency" entries.
        prob_cfg: problem description passed to ``conv``. It must also provide
            "cycle", the cycles per array pass.

    Returns:
        OrderedDict keyed by "compu", "ififo", "wfifo", "ofifo", "oaccu",
        "isram", "wsram", "osram" and "odram". Each entry is an independent
        OrderedDict of cycle, utilization, traffic and stall statistics.

    Raises:
        AssertionError: if the fifo shapes do not match the array, or if no
            tile fits the on-chip srams.
    """
    # intra-image parallel with weight stationary
    compu_cfg = arch_cfg["compu"]
    ififo_cfg = arch_cfg["ififo"]
    wfifo_cfg = arch_cfg["wfifo"]
    ofifo_cfg = arch_cfg["ofifo"]
    isram_cfg = arch_cfg["isram"]
    wsram_cfg = arch_cfg["wsram"]
    osram_cfg = arch_cfg["osram"]
    odram_cfg = arch_cfg["odram"]

    num_rows = compu_cfg["num_instances"][0]
    num_cols = compu_cfg["num_instances"][1]

    # initialize the problem
    prob = conv("", prob_cfg)

    # dram bandwidth in byte per cycle
    # (presumably "bandwidth" is in GB/s and "frequency" in MHz - confirm)
    odram_bw = odram_cfg["bandwidth"] * 10**9 / (arch_cfg["frequency"] * 10**6)

    isram_bpw = isram_cfg["byte_per_word"]
    wsram_bpw = wsram_cfg["byte_per_word"]
    osram_bpw = osram_cfg["byte_per_word"]

    # sram bandwidth in byte per cycle: one word per array row / column
    isram_bw = isram_bpw * num_rows
    wsram_bw = wsram_bpw * num_cols
    osram_bw = osram_bpw * num_cols

    _check_fifo_shapes(ififo_cfg, wfifo_cfg, ofifo_cfg, num_rows, num_cols)

    n_tile, k_tile, c_tile, p_tile, q_tile = _search_tile(
        prob, num_cols, isram_cfg, wsram_cfg, osram_cfg)

    # sram traffic in words
    isram_word_rd, isram_word_wr = 0, 0
    wsram_word_rd, wsram_word_wr = 0, 0
    osram_word_rd, osram_word_wr = 0, 0

    # total compute cycle
    compute_cycle = 0

    # stall cycle
    odram_stall = 0
    isram_stall_rd, isram_stall_wr = 0, 0
    wsram_stall_rd, wsram_stall_wr = 0, 0
    osram_stall_rd, osram_stall_wr = 0, 0

    # latency to stream in the very first tile and out the last NPQK tile
    stream_in_flag = True
    stream_in_cycle = 0
    stream_out_cycle = 0

    # walk the K tiles; the final, partial tile uses the divmod remainder
    k_map = np.divmod(prob.K, k_tile)
    for k in range(int(k_map[0] + 1)):
        k_this = k_tile if k < k_map[0] else k_map[1]
        if k_this == 0:
            break
        # K-dependent only: wsram/osram bandwidth is divided when fewer than
        # num_cols output channels occupy the columns
        under_util_wosram_bw = math.ceil(num_cols / k_this)
        col_fold = math.ceil(k_this / num_cols)

        n_this = prob.N
        npq_max = n_this * p_tile * q_tile
        npq_map = np.divmod(prob.N * prob.P * prob.Q, npq_max)
        for npq in range(int(npq_map[0] + 1)):
            npq_this = npq_max if npq < npq_map[0] else npq_map[1]
            if npq_this == 0:
                break

            pq_this = min(npq_this / n_this, prob.P * prob.Q)
            # assume pq is square
            p_this = min(round2power(math.sqrt(pq_this)), prob.P)
            q_this = min(pq_this / p_this, prob.Q)
            osram_this = n_this * k_this * pq_this
            # stream out cycle is per n tile, i.e., n_tile
            stream_out_cycle = (n_tile * p_this * q_this * k_this) * osram_bpw / min(osram_bw, odram_bw)

            # input halo needed to produce a p_this x q_this output plane
            x_this = (p_this - 1) * prob.Hstride + prob.R
            y_this = (q_this - 1) * prob.Wstride + prob.S

            odram_byte_iw = 0
            compute_cycle_c_delta = 0

            c_map = np.divmod(prob.C, c_tile)
            for c in range(int(c_map[0] + 1)):
                c_this = c_tile if c < c_map[0] else c_map[1]
                if c_this == 0:
                    break
                # isram bandwidth is divided when fewer than num_rows input
                # channels occupy the rows
                under_util_isram_bw = math.ceil(num_rows / c_this)
                row_fold = math.ceil(c_this / num_rows) * prob.R * prob.S
                row_fold_cycle = row_fold * prob_cfg["cycle"]
                compute_cycle_c = npq_this * row_fold_cycle * col_fold
                compute_cycle += compute_cycle_c
                compute_cycle_c_delta += compute_cycle_c

                # for every c tile, outputs are both loaded and read once by the array/ofifo
                # no stall will happen for this
                osram_word_wr += osram_this
                osram_word_rd += osram_this

                # for every c tile, weights are loaded only once
                wsram_this = c_this * (prob.R * prob.S) * k_this
                wsram_word_wr += wsram_this
                wsram_word_rd += wsram_this

                # no read stall; write stall when refilling outlasts compute
                wsram_bytes = wsram_this * wsram_bpw
                wsram_stall_wr += math.ceil(max(wsram_bytes / (wsram_bw / under_util_wosram_bw) - compute_cycle_c, 0))

                isram_this = n_this * x_this * y_this * c_this
                isram_word_wr += isram_this
                isram_word_rd += isram_this

                if stream_in_flag:
                    # only the very first tile is exposed; later ones are double buffered
                    stream_in_cycle = isram_this * isram_bpw / min(isram_bw / under_util_isram_bw, odram_bw) + \
                        wsram_this * wsram_bpw / min(wsram_bw / under_util_wosram_bw, odram_bw)
                    stream_in_flag = False

                # no read stall; write stall when refilling outlasts compute
                isram_bytes = isram_this * isram_bpw
                isram_stall_wr += math.ceil(max(isram_bytes / (isram_bw / under_util_isram_bw) - compute_cycle_c, 0))

                odram_byte_iw += wsram_bytes + isram_bytes

            # rd stall, odram read osram
            # no wr stall
            osram_bytes = osram_this * osram_bpw
            osram_stall_rd += math.ceil(max(osram_bytes / (osram_bw / under_util_wosram_bw) - compute_cycle_c_delta, 0))

            # dram carries this npq tile's outputs plus all of its input/weight
            # refills; every term is in bytes (osram_this is in words)
            odram_stall += math.ceil(max((osram_bytes + odram_byte_iw) / odram_bw - compute_cycle_c_delta, 0))

    # pipeline fill/drain of the array and of the first fifo maps
    first_map_cols = num_rows
    last_map_rows = num_cols
    # NOTE(review): words / byte_per_word / (byte/cycle) is not a cycle count;
    # words * byte_per_word / bw may be intended - confirm before changing.
    first_ififo_map = num_rows * num_cols / isram_bpw / isram_bw
    first_wfifo_map = num_rows * num_cols / wsram_bpw / wsram_bw

    # count for initial streaming latency
    compute_cycle += first_map_cols + last_map_rows
    compute_cycle += first_ififo_map + first_wfifo_map
    compute_cycle += stream_in_cycle + stream_out_cycle

    utilization = prob.flops * prob_cfg["cycle"] / (num_rows * num_cols * compute_cycle * 2 + num_cols * compute_cycle) * 100.

    compu_result = OrderedDict()
    compu_result["cycle"] = int(compute_cycle)
    compu_result["utilization"] = float(utilization)
    compu_result["flops"] = float(prob.flops)

    odram_byte_rd = isram_word_wr * isram_bpw + wsram_word_wr * wsram_bpw
    odram_byte_wr = prob.N * prob.P * prob.Q * prob.K * osram_bpw
    odram_byte_tot = odram_byte_rd + odram_byte_wr
    odram_result = OrderedDict()
    # the single dram stall counter is split between rd/wr by byte share
    odram_result["access stall"] = OrderedDict([
        ("rd", 0 if odram_byte_rd == 0 else int(math.ceil(odram_stall * odram_byte_rd / odram_byte_tot))),
        ("wr", 0 if odram_byte_wr == 0 else int(math.ceil(odram_stall * odram_byte_wr / odram_byte_tot))),
    ])
    odram_result["byte total"] = OrderedDict([
        ("rd", int(odram_byte_rd)),
        ("wr", int(odram_byte_wr)),
    ])
    # NOTE(review): bytes are divided by "bits_per_chip" here while "page total"
    # converts to bits first; confirm the intended unit of bits_per_chip.
    odram_result["access total"] = OrderedDict([
        ("rd", int(math.ceil(odram_byte_rd / odram_cfg["bits_per_chip"]))),
        ("wr", int(math.ceil(odram_byte_wr / odram_cfg["bits_per_chip"]))),
    ])
    odram_result["page total"] = OrderedDict([
        ("rd", int(math.ceil(odram_byte_rd * 8 / odram_cfg["bits_per_page"]))),
        ("wr", int(math.ceil(odram_byte_wr * 8 / odram_cfg["bits_per_page"]))),
    ])

    # output shall include the ordered output of all components;
    # every entry is an independent copy so callers can annotate them separately
    output = OrderedDict()
    output["compu"] = compu_result
    output["ififo"] = copy.deepcopy(compu_result)
    output["wfifo"] = copy.deepcopy(compu_result)
    output["ofifo"] = copy.deepcopy(compu_result)
    output["oaccu"] = copy.deepcopy(compu_result)
    output["isram"] = _sram_result(isram_cfg, isram_word_rd, isram_word_wr, isram_stall_rd, isram_stall_wr)
    output["wsram"] = _sram_result(wsram_cfg, wsram_word_rd, wsram_word_wr, wsram_stall_rd, wsram_stall_wr)
    output["osram"] = _sram_result(osram_cfg, osram_word_rd, osram_word_wr, osram_stall_rd, osram_stall_wr)
    output["odram"] = odram_result

    return output

