from collections import OrderedDict
import numpy as np
from workload import conv
import math
from utils import bcolors
import copy


def dataflow(
    arch_cfg: OrderedDict,
    prob_cfg: OrderedDict
) -> OrderedDict:
    """
    Estimate cycles, SRAM/DRAM traffic and stalls for the carat architecture
    with an inter-image data schedule (weight stationary).
    All memories are double buffered.

    Tensor format:

    Tensorflow-preferred format: (systolic array)
    Input format: NXYC
    Weight format: RSCK
    Output format: NPQK

    PyTorch-preferred format: (SIMD)
    Input format: NCXY
    Weight format: KCRS
    Output format: NKPQ

    Please refer to https://inst.eecs.berkeley.edu/~eecs151/sp20/files/lec18-DNN-QijingHuang.pdf page 29 for how to schedule GEMM.

    Args:
        arch_cfg: architecture description. This function reads the keys
            "compu", "ififo", "wfifo", "ofifo", "isram", "wsram", "osram",
            "odram" and "frequency".
            - "compu" provides "num_instances" as [rows, cols].
            - each fifo entry provides "num_instances" and
              ["instance"]["queue"]["depth"].
            - each sram entry provides "byte_per_brow", "bank_per_sram",
              "byte_tot_sram" and "byte_per_word".
            - "odram" provides "bandwidth", "bits_per_chip" and
              "bits_per_page".
        prob_cfg: problem description, passed unchanged to ``conv("", prob_cfg)``.
            The resulting object is read for the attributes cycle, C, R, S, K,
            N, P, Q, Y_valid, Hstride and flops.

    Returns:
        OrderedDict keyed by component name, in this order: "compu", "ififo",
        "wfifo", "ofifo", "oaccu", "waccu", "itemp", "osmux", "bpipe",
        "isram", "wsram", "osram", "odram".
        - the first nine entries are independent copies of the compute result
          ("cycle", "utilization", "flops").
        - the sram entries hold "access stall", "byte total" and
          "access total", each split into "rd" and "wr".
        - the odram entry additionally holds "page total".

    Raises:
        AssertionError: if the fifo configuration is not double buffered, the
            fifo counts do not match the array dimensions, the output fifo
            depth differs from the array width, or the input fifo is too
            shallow.
    """

    # inter-image parallel with weight stationary

    compu_cfg = arch_cfg["compu"]
    ififo_cfg = arch_cfg["ififo"]
    wfifo_cfg = arch_cfg["wfifo"]
    ofifo_cfg = arch_cfg["ofifo"]
    isram_cfg = arch_cfg["isram"]
    wsram_cfg = arch_cfg["wsram"]
    osram_cfg = arch_cfg["osram"]
    odram_cfg = arch_cfg["odram"]

    num_rows = compu_cfg["num_instances"][0]
    num_cols = compu_cfg["num_instances"][1]

    # output shall include the ordered output of all components
    output = OrderedDict()

    # initialize the problem
    prob = conv("", prob_cfg)

    # fanout of vertical weights
    fanout = 4

    # dram bandwidth in byte per cycle
    # presumably "bandwidth" is in GB/s and "frequency" in MHz -- TODO confirm against the config schema
    odram_bw = arch_cfg["odram"]["bandwidth"] * 10**9 / (arch_cfg["frequency"] * 10**6)

    # isram traffic total
    isram_word_rd, isram_word_wr = 0, 0
    isram_bw = isram_cfg["byte_per_brow"] * isram_cfg["bank_per_sram"]

    # wsram traffic total
    wsram_word_rd, wsram_word_wr = 0, 0
    wsram_bw = wsram_cfg["byte_per_brow"] * wsram_cfg["bank_per_sram"]

    # osram traffic total, output is stationary in the accumulation fifo, so transferred once
    osram_word_rd, osram_word_wr = 0, 0
    osram_bw = osram_cfg["byte_per_brow"] * osram_cfg["bank_per_sram"]

    # analyze tile size based on buffer size
    # npq is the inner most tile, and we use this as a start
    # 1.1 calculate npq tile size based on ofifo size
    assert ofifo_cfg["num_instances"][1] == 2, \
        bcolors.FAIL + "The output fifo is not double buffered." + bcolors.ENDC
    assert ofifo_cfg["num_instances"][0] == num_rows, \
        bcolors.FAIL + "The output fifo count is unequal to the array row count." + bcolors.ENDC
    # max output that can be held for this tile due to limited ofifo depth
    ts_o_bd = ofifo_cfg["instance"]["queue"]["depth"]

    assert ts_o_bd == num_cols, \
        bcolors.FAIL + "Unmatched output fifo size and array width." + bcolors.ENDC

    # 1.2 calculate npq tile size based on ififo size
    assert ififo_cfg["num_instances"][1] == 2, \
        bcolors.FAIL + "The input fifo is not double buffered." + bcolors.ENDC
    assert ififo_cfg["num_instances"][0] == num_rows, \
        bcolors.FAIL + "The input fifo count is unequal to the array row count." + bcolors.ENDC
    # max input that can be held for this tile due to limited ififo depth
    ts_i_bd = ififo_cfg["instance"]["queue"]["depth"]
    assert ts_i_bd >= (num_rows / fanout / prob.cycle), \
        bcolors.FAIL + "The input fifo depth is not enough to hide the compute latency." + bcolors.ENDC

    # 1.3 calculate crs tile size based on wfifo size
    assert wfifo_cfg["num_instances"][0] == 2, \
        bcolors.FAIL + "The weight fifo is not double buffered." + bcolors.ENDC
    assert wfifo_cfg["num_instances"][1] == num_cols, \
        bcolors.FAIL + "The weight fifo count is unequal to the array column count." + bcolors.ENDC
    # max weight that can be held for this tile due to limited wfifo depth
    ts_w_bd = wfifo_cfg["instance"]["queue"]["depth"]

    # tile k according to wsram
    # weight size per output channel
    w_sz = prob.C * prob.R * prob.S
    # max k value bounded by wsram (the "/ 2" accounts for double buffering).
    # Clamp to at least 1, like n_max/p_max/c_max below: a zero here would make
    # np.divmod divide by zero and the loop nest would silently do no work.
    k_max = max(math.floor(wsram_cfg["byte_tot_sram"] / wsram_cfg["byte_per_word"] / 2 / w_sz), 1)
    # (number of full k tiles, size of the remainder tile)
    w_map = np.divmod(prob.K, k_max)

    # prioritize Q more than N
    # use Y_valid ensures that one Q can be generated
    n_max_i = math.floor(isram_cfg["byte_tot_sram"] / isram_cfg["byte_per_word"] / 2 / (prob.Y_valid * prob.R))
    n_max_o = math.floor(osram_cfg["byte_tot_sram"] / osram_cfg["byte_per_word"] / 2 / (num_cols * prob.Q))
    # n tile is rounded down to a power of two, and is at least 1
    n_max = 2**math.floor(math.log2(max(min(n_max_i, n_max_o, prob.N), 1)))

    p_max_i = math.floor(isram_cfg["byte_tot_sram"] / isram_cfg["byte_per_word"] / 2 / (prob.Y_valid * prob.R * n_max))
    p_max_o = math.floor(osram_cfg["byte_tot_sram"] / osram_cfg["byte_per_word"] / 2 / (num_cols * prob.Q * n_max))
    p_max = max(min(p_max_i, p_max_o, prob.P), 1)

    c_max_i = math.floor(isram_cfg["byte_tot_sram"] / isram_cfg["byte_per_word"] / 2 / (prob.Y_valid * prob.R * n_max * p_max))
    c_max_o = math.floor(osram_cfg["byte_tot_sram"] / osram_cfg["byte_per_word"] / 2 / (num_cols * prob.Q * n_max * p_max))
    c_max = max(min(c_max_i, c_max_o, prob.C), 1)

    # each *_map is (number of full tiles, size of the remainder tile)
    n_map = np.divmod(prob.N, n_max)
    p_map = np.divmod(prob.P, p_max)
    c_map = np.divmod(prob.C, c_max)

    compute_cycle = 0
    stall_cycle = 0
    stream_in_cycle = 0
    stream_out_cycle = 0

    # assume only sram stall, i.e., fifos do not stall due to double buffer
    isram_stall_rd, isram_stall_wr = 0, 0
    isram_flag_tile_first_wr = True

    # NOTE: wsram_stall_wr is never updated below, so it is always reported as 0
    wsram_stall_rd, wsram_stall_wr = 0, 0
    wsram_flag_tile_first_wr = True

    osram_stall_rd, osram_stall_wr = 0, 0

    odram_stall = 0

    first_map_cols = 0
    first_map_flag = True
    last_map_rows = num_cols

    first_ififo_map, first_wfifo_map = 0, 0
    first_ififo_flag, first_wfifo_flag = True, True

    # weight stationary forces kcrs as outer loop, and npq as inner loop
    for w in range(int(w_map[0] + 1)):
        # weight from odram to wsram
        k_wsram = k_max if w < w_map[0] else w_map[1]
        if k_wsram == 0:
            # empty remainder tile
            break
        wsram_word_tile_wr = k_wsram * w_sz
        wsram_word_wr += wsram_word_tile_wr
        wsram_cycl_tile_wr = 0

        # only the very first weight tile load is exposed as stream-in latency
        if wsram_flag_tile_first_wr:
            stream_in_cycle += math.ceil(wsram_word_tile_wr * wsram_cfg["byte_per_word"] / odram_bw)
            wsram_flag_tile_first_wr = False

        # dram bytes moved while this weight tile is resident
        odram_byte_io = 0
        # each K is mapped to one column
        col_map = np.divmod(k_wsram, num_cols)
        for k in range(int(col_map[0] + 1)):
            # each column works on one output channel, this is spatial mapping
            map_cols = num_cols if k < col_map[0] else col_map[1]
            if map_cols == 0:
                break

            if first_map_flag:
                # initial fifo streaming cycle
                first_map_cols = num_rows / fanout
                first_map_flag = False

            for n in range(int(n_map[0] + 1)):
                n_now = n_max if n < n_map[0] else n_map[1]
                if n_now == 0:
                    break

                for p in range(int(p_map[0] + 1)):
                    p_now = p_max if p < p_map[0] else p_map[1]
                    if p_now == 0:
                        break

                    # accumulated below but not read anywhere else in this function
                    n_cycle = 0

                    # accumulation in ofifo is output stationary
                    npq_map = np.divmod(n_now * p_now * prob.Q, num_rows)

                    # accumulation both reads and writes osram
                    osram_word_tile_rd = n_now * map_cols * p_now * prob.Q
                    osram_word_rd += osram_word_tile_rd
                    osram_cycl_tile_rd = 0

                    osram_word_tile_wr = n_now * map_cols * p_now * prob.Q
                    osram_word_wr += osram_word_tile_wr
                    osram_cycl_tile_wr = 0

                    # only the last stream out cycle needs to be counted (hence "=" and not "+=")
                    stream_out_cycle = math.ceil(osram_word_tile_rd * osram_cfg["byte_per_word"] / min(odram_bw, osram_bw))

                    # weights are read once when they fit in the double-buffered wfifo,
                    # otherwise once per row mapping of this npq tile
                    wsram_word_tile_rd = w_sz * map_cols * (1 if (ts_w_bd * 2 >= w_sz) else int(npq_map[0] + math.ceil(npq_map[1] / num_rows)))
                    wsram_word_rd += wsram_word_tile_rd
                    wsram_cycl_tile_rd = 0

                    # input rows needed for p_now output rows; overlapping windows share rows when Hstride < R
                    X_now = (prob.Hstride * (p_now - 1) + prob.R) if prob.Hstride < prob.R else (prob.R * p_now)

                    for c in range(int(c_map[0] + 1)):
                        c_now = c_max if c < c_map[0] else c_map[1]
                        if c_now == 0:
                            break

                        isram_word_tile_wr = n_now * c_now * X_now * prob.Y_valid
                        isram_word_wr += isram_word_tile_wr
                        isram_cycl_tile_wr = 0

                        # only the very first input tile load is exposed as stream-in latency
                        if isram_flag_tile_first_wr:
                            stream_in_cycle += math.ceil(isram_word_tile_wr * isram_cfg["byte_per_word"] / min(odram_bw, isram_bw))
                            isram_flag_tile_first_wr = False

                        for npq in range(int(npq_map[0] + 1)):
                            map_rows = (num_rows if npq < npq_map[0] else npq_map[1])
                            if map_rows == 0:
                                break
                            stream_cycle = prob.cycle * c_now * prob.R * prob.S

                            isram_word_tile_rd = map_rows * c_now * prob.R * prob.S
                            isram_word_rd += isram_word_tile_rd
                            isram_cycl_tile_rd = 0

                            # cycles to fill the fifos for the first mapping only
                            if first_ififo_flag:
                                first_ififo_map = math.ceil(map_rows * ts_i_bd * isram_cfg["byte_per_word"] / isram_bw)
                                first_ififo_flag = False

                            if first_wfifo_flag:
                                first_wfifo_map = math.ceil(map_cols * ts_w_bd * wsram_cfg["byte_per_word"] / wsram_bw)
                                first_wfifo_flag = False

                            # compute cycle, this value considers the cycles for both streaming in and out the data, however, map_cols can be perfectly hidden
                            tile_cycle = stream_cycle
                            compute_cycle += tile_cycle

                            # every memory gets the same compute window to hide its transfer
                            isram_cycl_tile_rd += tile_cycle
                            isram_cycl_tile_wr += tile_cycle

                            wsram_cycl_tile_rd += tile_cycle
                            wsram_cycl_tile_wr += tile_cycle

                            osram_cycl_tile_rd += tile_cycle
                            osram_cycl_tile_wr += tile_cycle

                            n_cycle += tile_cycle

                            # stall = transfer cycles that exceed the overlapping compute cycles
                            isram_stall_rd += math.ceil(max(isram_word_tile_rd * isram_cfg["byte_per_word"] / isram_bw - isram_cycl_tile_rd, 0))

                        isram_stall_wr += math.ceil(max(isram_word_tile_wr * isram_cfg["byte_per_word"] / isram_bw - isram_cycl_tile_wr, 0))
                        odram_byte_io += isram_word_tile_wr * isram_cfg["byte_per_word"]

                    wsram_stall_rd += math.ceil(max(wsram_word_tile_rd * wsram_cfg["byte_per_word"] / wsram_bw - wsram_cycl_tile_rd, 0))

                    osram_stall_rd += math.ceil(max(osram_word_tile_rd * osram_cfg["byte_per_word"] / osram_bw - osram_cycl_tile_rd, 0))
                    osram_stall_wr += math.ceil(max(osram_word_tile_wr * osram_cfg["byte_per_word"] / osram_bw - osram_cycl_tile_wr, 0))

                    odram_byte_io += osram_word_tile_rd * osram_cfg["byte_per_word"]

        # dram stall for this weight tile: all dram bytes (inputs, outputs, weights)
        # against the compute cycles spent while the tile was resident
        odram_stall += math.ceil(max((odram_byte_io + wsram_word_tile_wr * wsram_cfg["byte_per_word"]) / odram_bw - wsram_cycl_tile_wr, 0))

    # count for initial streaming latency
    compute_cycle += first_map_cols + last_map_rows
    compute_cycle += first_ififo_map + first_wfifo_map
    compute_cycle += stream_in_cycle + stream_out_cycle

    # the denominator is algebraically 2 * num_rows * num_cols / prob.cycle per cycle, times compute_cycle
    utilization = prob.flops / (num_rows * 2 / (prob.cycle / num_cols) * compute_cycle) * 100.

    compu_result = OrderedDict()
    compu_result["cycle"] = int(compute_cycle)
    compu_result["utilization"] = float(utilization)
    compu_result["flops"] = float(prob.flops)

    ififo_result = copy.deepcopy(compu_result)
    wfifo_result = copy.deepcopy(compu_result)
    ofifo_result = copy.deepcopy(compu_result)

    isram_result = OrderedDict()
    isram_result["access stall"]    = OrderedDict({ "rd": int(isram_stall_rd), 
                                                    "wr": int(isram_stall_wr)})
    isram_result["byte total"]      = OrderedDict({ "rd": int(math.ceil(isram_word_rd * isram_cfg["byte_per_word"])), 
                                                    "wr": int(math.ceil(isram_word_wr * isram_cfg["byte_per_word"]))})
    isram_result["access total"]    = OrderedDict({ "rd": int(math.ceil(isram_word_rd * isram_cfg["byte_per_word"] / isram_cfg["byte_per_brow"])), 
                                                    "wr": int(math.ceil(isram_word_wr * isram_cfg["byte_per_word"] / isram_cfg["byte_per_brow"]))})

    wsram_result = OrderedDict()
    wsram_result["access stall"]    = OrderedDict({ "rd": int(wsram_stall_rd), 
                                                    "wr": int(wsram_stall_wr)})
    wsram_result["byte total"]      = OrderedDict({ "rd": int(math.ceil(wsram_word_rd * wsram_cfg["byte_per_word"])), 
                                                    "wr": int(math.ceil(wsram_word_wr * wsram_cfg["byte_per_word"]))})
    wsram_result["access total"]    = OrderedDict({ "rd": int(math.ceil(wsram_word_rd * wsram_cfg["byte_per_word"] / wsram_cfg["byte_per_brow"])), 
                                                    "wr": int(math.ceil(wsram_word_wr * wsram_cfg["byte_per_word"] / wsram_cfg["byte_per_brow"]))})

    osram_result = OrderedDict()
    osram_result["access stall"]    = OrderedDict({ "rd": int(osram_stall_rd), 
                                                    "wr": int(osram_stall_wr)})
    osram_result["byte total"]      = OrderedDict({ "rd": int(math.ceil(osram_word_rd * osram_cfg["byte_per_word"])), 
                                                    "wr": int(math.ceil(osram_word_wr * osram_cfg["byte_per_word"]))})
    osram_result["access total"]    = OrderedDict({ "rd": int(math.ceil(osram_word_rd * osram_cfg["byte_per_word"] / osram_cfg["byte_per_brow"])), 
                                                    "wr": int(math.ceil(osram_word_wr * osram_cfg["byte_per_word"] / osram_cfg["byte_per_brow"]))})

    odram_result = OrderedDict()
    # dram reads feed isram and wsram; dram writes drain osram
    odram_byte_rd = isram_word_wr * isram_cfg["byte_per_word"] + wsram_word_wr * wsram_cfg["byte_per_word"]
    odram_byte_wr = osram_word_rd * osram_cfg["byte_per_word"]
    # the total dram stall is split between rd and wr in proportion to their byte counts
    odram_result["access stall"]    = OrderedDict({ "rd": 0 if odram_byte_rd == 0 else int(math.ceil(odram_stall * odram_byte_rd / (odram_byte_rd + odram_byte_wr))),
                                                    "wr": 0 if odram_byte_wr == 0 else int(math.ceil(odram_stall * odram_byte_wr / (odram_byte_rd + odram_byte_wr)))})
    odram_result["byte total"]      = OrderedDict({ "rd": int(odram_byte_rd), 
                                                    "wr": int(odram_byte_wr)})
    # NOTE(review): bytes are divided by "bits_per_chip" without the "* 8" used for
    # "page total" below -- confirm the intended unit of an access before changing
    odram_result["access total"]    = OrderedDict({ "rd": int(math.ceil(odram_byte_rd / odram_cfg["bits_per_chip"])), 
                                                    "wr": int(math.ceil(odram_byte_wr / odram_cfg["bits_per_chip"]))})
    odram_result["page total"]      = OrderedDict({ "rd": int(math.ceil(odram_byte_rd * 8 / odram_cfg["bits_per_page"])), 
                                                    "wr": int(math.ceil(odram_byte_wr * 8 / odram_cfg["bits_per_page"]))})

    output["compu"] = compu_result
    output["ififo"] = ififo_result
    output["wfifo"] = wfifo_result
    output["ofifo"] = ofifo_result
    # these components report the same numbers as the compute result; each gets
    # its own copy so that mutating one entry cannot change the others
    for component in ("oaccu", "waccu", "itemp", "osmux", "bpipe"):
        output[component] = copy.deepcopy(compu_result)
    output["isram"] = isram_result
    output["wsram"] = wsram_result
    output["osram"] = osram_result
    output["odram"] = odram_result

    return output

