from collections import OrderedDict
import numpy as np
from workload import avgpool2d
import math
from utils import bcolors
import copy


def dataflow(
    arch_cnfg: OrderedDict,
    prob_cnfg: OrderedDict
):
    """
    Estimate cycles, stalls and memory traffic of a 2D average-pooling layer.

    The workload is built with ``avgpool2d("", prob_cnfg)``. Channels are mapped spatially
    onto the array columns (at most ``num_cols`` channels per pass), and the N and P
    dimensions are tiled so that one tile fits in half of the isram, osram and usram
    (the other half is reserved for double buffering).

    Input format: NXYC
    Output format: NPQK

    Input schedule: NXYC
    Output schedule: NPQK

    Channel last allows the isram and osram fetch to be sequential, reducing sram access count.
    Spikes are stored in a sparse format: a spine index (3 bytes) per spatial position plus a
    channel index (2 bytes) per non-zero entry.

    The loop structure follows the channel-parallel matrix-vector formulation described in
    https://inst.eecs.berkeley.edu/~eecs151/sp20/files/lec18-DNN-QijingHuang.pdf page 29.

    Args:
        arch_cnfg: architecture description. It must provide the sections "compu", "spgen",
            "sspar", "ispar", "ospar", "ififo", "wfifo", "leaky", "reset", "isram", "wsram",
            "osram", "usram" and "odram", plus "frequency" (in MHz, as it is scaled by 10**6
            to derive the dram bytes per cycle).
        prob_cnfg: problem description, forwarded unchanged to ``avgpool2d``.

    Returns:
        An OrderedDict with one entry per component, in the order compu, spgen, sspar, ispar,
        ospar, ififo, wfifo, leaky, reset, isram, wsram, osram, usram, odram.
        Compute-like components report "cycle", "utilization" and "flops"; memories report
        "access stall", "byte total" and "access total" (odram additionally "page total"),
        each split into "rd" and "wr".

    Raises:
        AssertionError: if the input or weight fifo configuration is not double buffered or
            its instance count does not match ``num_cols``.
        KeyError: if a required section or field is missing from ``arch_cnfg``.
    """

    # All sections are looked up eagerly, including the ones that are not used below,
    # so that an incomplete architecture config fails here with a KeyError.
    compu_cnfg = arch_cnfg["compu"]
    spgen_cnfg = arch_cnfg["spgen"]
    sspar_cnfg = arch_cnfg["sspar"]
    ispar_cnfg = arch_cnfg["ispar"]
    ospar_cnfg = arch_cnfg["ospar"]
    ififo_cnfg = arch_cnfg["ififo"]
    wfifo_cnfg = arch_cnfg["wfifo"]
    leaky_cnfg = arch_cnfg["leaky"]
    reset_cnfg = arch_cnfg["reset"]
    isram_cnfg = arch_cnfg["isram"]
    wsram_cnfg = arch_cnfg["wsram"]
    osram_cnfg = arch_cnfg["osram"]
    usram_cnfg = arch_cnfg["usram"]
    odram_cnfg = arch_cnfg["odram"]

    num_rows = compu_cnfg["num_instances"][0]
    num_cols = compu_cnfg["num_instances"][1]

    # output shall include the ordered output of all components
    output = OrderedDict()

    # initialize the problem
    prob = avgpool2d("", prob_cnfg)

    # ddr3 bandwidth: 12.8 GB/s, i.e., 32 bytes per cycle at 400 MHz
    odram_bycy = 12.8 * 10**9 / (arch_cnfg["frequency"] * 10**6)

    # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # #
    # define sparse format requirement
    # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # #
    # isram/osram word is a spine index and a channel index
    # effective spike byte, for calculating required bandwidth
    spike_byte_spn = 3
    spike_byte_chn = 2
    # the spine array is accessed twice per spine (start and end address)
    spn_access_per_spn = 2

    # isram traffic total
    # with compute, odram
    isram_byte_rd, isram_byte_wr = 0, 0
    isram_bycy = isram_cnfg["byte_per_brow"] * isram_cnfg["bank_per_sram"]
    prob_byte_i = prob.num_inputs_valid * (1 - prob.sparsity) * spike_byte_chn + prob.N * (prob.X * prob.Y + 1) * spike_byte_spn
    # the whole input fits in one half of the isram, so no per-tile refill is charged
    large_isram = isram_cnfg["byte_tot_sram"] / 2 >= prob_byte_i

    # wsram traffic total (never incremented in this dataflow, reported as zero)
    # with odram, compute
    wsram_word_rd, wsram_word_wr = 0, 0
    wsram_bycy = wsram_cnfg["byte_per_brow"] * wsram_cnfg["bank_per_sram"]

    # osram traffic total, output is stationary in the accumulation fifo, so transferred once
    # with odram, compute
    osram_byte_rd, osram_byte_wr = 0, 0
    osram_bycy = osram_cnfg["byte_per_brow"] * osram_cnfg["bank_per_sram"]
    prob_byte_o = prob.num_outputs * (1 - prob.sparsity) * spike_byte_chn + prob.N * (prob.P * prob.Q + 1) * spike_byte_spn
    large_osram = osram_cnfg["byte_tot_sram"] / 2 >= prob_byte_o

    # usram traffic total
    # with odram, compute
    usram_word_rd, usram_word_wr = 0, 0
    usram_bycy = usram_cnfg["byte_per_brow"] * usram_cnfg["bank_per_sram"]
    large_usram = usram_cnfg["byte_tot_sram"] / usram_cnfg["byte_per_word"] / 2 >= (num_cols * prob.P * prob.Q)
    # words exchanged between usram and odram when the usram is too small
    usram_odram_word_rd, usram_odram_word_wr = 0, 0

    # l0-fifo
    # 1. analyze max tile size
    # hint: the tile size is bounded by input and output buffer size.
    # insight: for any compute logic, we can specify such constraints.

    # 1.2 calculate npq tile size based on ififo size
    assert ififo_cnfg["num_instances"][1] == 2, \
        bcolors.FAIL + "The input fifo is not double buffered." + bcolors.ENDC
    # NOTE(review): the check compares against num_cols while the message mentions the row
    # count; confirm which of the two is intended.
    assert ififo_cnfg["num_instances"][0] == num_cols, \
        bcolors.FAIL + "The input fifo count is unequal to the array row count." + bcolors.ENDC
    # max output that can be held for this tile due to limited ififo depth
    # (not used further below; the lookup is kept so a missing field still raises)
    ts_i_bd = ififo_cnfg["instance"]["queue"]["depth"]

    # 1.3 calculate crs tile size based on wfifo size
    assert wfifo_cnfg["num_instances"][0] == 2, \
        bcolors.FAIL + "The weight fifo is not double buffered." + bcolors.ENDC
    assert wfifo_cnfg["num_instances"][1] == num_cols, \
        bcolors.FAIL + "The weight fifo count is unequal to the array column count." + bcolors.ENDC
    # max output that can be held for this tile due to limited wfifo depth
    ts_w_bd = wfifo_cnfg["instance"]["queue"]["depth"]

    # The tile size in this design is actually 1, as the arithmetic is done channel-wise
    ts_compute = 1

    # max k value bounded by num_cols
    k_max = num_cols
    w_map = np.divmod(prob.C, k_max)

    # largest N tile for which one output row (P == 1) fits in half of each sram
    n_max_i = math.floor(isram_cnfg["byte_tot_sram"] / 2 / (prob.C * prob.Y_valid * (1 - prob.sparsity) * spike_byte_chn + (prob.Y_valid + 1) * spike_byte_spn) / prob.R)
    n_max_o = math.floor(osram_cnfg["byte_tot_sram"] / 2 / (num_cols * prob.Q * (1 - prob.sparsity) * spike_byte_chn + (prob.Q + 1) * spike_byte_spn))
    n_max_u = math.floor(usram_cnfg["byte_tot_sram"] / 2 / (num_cols * prob.Q * usram_cnfg["byte_per_word"]))
    n_max = max(min(n_max_i, n_max_o, n_max_u, prob.N), 1)

    # largest P tile given the chosen N tile
    p_max_i = math.floor(isram_cnfg["byte_tot_sram"] / 2 / ((prob.C * prob.Y_valid * (1 - prob.sparsity) * spike_byte_chn + (prob.Y_valid + 1) * spike_byte_spn) * n_max))
    p_max_o = math.floor(osram_cnfg["byte_tot_sram"] / 2 / ((num_cols * prob.Q * (1 - prob.sparsity) * spike_byte_chn + (prob.Q + 1) * spike_byte_spn) * n_max))
    p_max_u = math.floor(usram_cnfg["byte_tot_sram"] / 2 / (num_cols * prob.Q * usram_cnfg["byte_per_word"]))
    p_max = max(min(p_max_i, p_max_o, p_max_u, prob.P), 1)

    # (number of full tiles, size of the remainder tile)
    n_map = np.divmod(prob.N, n_max)
    p_map = np.divmod(prob.P, p_max)

    compute_cycle = 0
    stall_cycle = 0
    stream_in_cycle = 0
    stream_out_cycle = 0

    # assume only sram stall, i.e., fifos do not stall due to double buffer
    isram_stall_rd, isram_stall_wr = 0, 0
    isram_flag_tile_first_wr = True

    wsram_stall_rd, wsram_stall_wr = 0, 0

    osram_stall_rd, osram_stall_wr = 0, 0

    usram_stall_rd, usram_stall_wr = 0, 0

    odram_stall = 0

    # array fill/drain latencies; they stay zero in this dataflow
    first_map_cols = 0
    last_map_rows = 0

    total_util = 0

    first_ififo_map, first_wfifo_map = 0, 0
    first_wfifo_flag = True

    # channel chunks form the outer loop
    for w in range(int(w_map[0] + 1)):
        # number of channels handled in this chunk
        k_wsram = k_max if w < w_map[0] else w_map[1]
        if k_wsram == 0:
            break

        # channel dimension is actually split into multiple chunks, each with a maximum size of num_cols
        # this increases the overhead of the spine index
        # each K is mapped to one column
        col_map = np.divmod(k_wsram, num_cols)
        for k in range(int(col_map[0] + 1)):
            # each column works on one output channel, this is spatial mapping
            map_cols = num_cols if k < col_map[0] else col_map[1]
            if map_cols == 0:
                break

            n_current = 0
            for n in range(int(n_map[0] + 1)):
                n_now = n_max if n < n_map[0] else n_map[1]
                if n_now == 0:
                    break
                # N processed so far (including this tile) and N still to come
                n_current += n_now
                n_remain = prob.N - n_current

                for p in range(int(p_map[0] + 1)):
                    p_now = p_max if p < p_map[0] else p_map[1]
                    if p_now == 0:
                        break

                    # compute cycles spent in this (n, p) tile
                    n_cycle = 0

                    # current tile with sparse computation
                    npq_map = np.divmod(n_now * p_now * prob.Q, ts_compute)

                    # input rows needed for p_now output rows: windows overlap when stride < R
                    X_now = (prob.Hstride * (p_now - 1) + prob.R) if prob.Hstride < prob.R else (prob.R * p_now)
                    isram_byte_tile_wr = 0 if large_isram else \
                                        (n_now * (prob.C * X_now * prob.Y_valid * (1 - prob.sparsity) * spike_byte_chn + (X_now * prob.Y_valid + 1) * spike_byte_spn))
                    isram_byte_wr += isram_byte_tile_wr

                    isram_cycl_tile_wr = 0

                    # only the very first tile's fill is exposed; later fills overlap with compute
                    if isram_flag_tile_first_wr is True:
                        stream_in_cycle += math.ceil(isram_byte_tile_wr / min(odram_bycy, isram_bycy))
                        isram_flag_tile_first_wr = False

                    osram_byte_tile_rd = 0 if large_osram else \
                                        (n_now * (map_cols * p_now * prob.Q * (1 - prob.sparsity) * spike_byte_chn + (p_now * prob.Q + 1) * spike_byte_spn))
                    osram_byte_rd += osram_byte_tile_rd
                    osram_cycl_tile_rd = 0

                    osram_byte_tile_wr = (n_now * (map_cols * p_now * prob.Q * (1 - prob.sparsity) * spike_byte_chn + (p_now * prob.Q + 1) * spike_byte_spn))
                    osram_byte_wr += osram_byte_tile_wr
                    osram_cycl_tile_wr = 0

                    # if the total time step fits in a single N tile, there is no need to read
                    # or write usram at all: state is only read back when an earlier N tile
                    # exists, and only written when a later N tile follows.
                    # (The read condition was previously `n_current == n_now`, which charged
                    # the read on the first tile only and contradicted the rule above.)
                    usram_word_tile_rd = map_cols * p_now * prob.Q * (1 if n_current > n_now else 0)
                    usram_word_tile_wr = map_cols * p_now * prob.Q * (1 if n_remain > 0 else 0)

                    # a usram that cannot hold the full state doubles its own traffic and
                    # adds the same amount of usram<->odram traffic
                    usram_word_rd += usram_word_tile_rd * (1 if large_usram else 2)
                    usram_word_wr += usram_word_tile_wr * (1 if large_usram else 2)
                    usram_odram_word_tile_rd = usram_word_tile_rd * (0 if large_usram else 1)
                    usram_odram_word_tile_wr = usram_word_tile_wr * (0 if large_usram else 1)
                    usram_odram_word_rd += usram_odram_word_tile_rd
                    usram_odram_word_wr += usram_odram_word_tile_wr

                    usram_cycl_tile_rd = 0
                    usram_cycl_tile_wr = 0

                    # overwritten on every tile, so the value of the last tile is what gets
                    # added to the total after the loop nest
                    stream_out_cycle = math.ceil(osram_byte_tile_rd / min(odram_bycy, osram_bycy))

                    for npq in range(int(npq_map[0] + 1)):
                        stream_cycle = (ts_compute if npq < npq_map[0] else npq_map[1])
                        if stream_cycle == 0:
                            break
                        # accessing the input for each output is sequential
                        # NOTE(review): divides by (1 - prob.sparsity); a sparsity of exactly 1 fails here.
                        tile_cycle = stream_cycle * prob.R * prob.S * (1 - prob.sparsity) * (1 + spn_access_per_spn / (1 - prob.sparsity))
                        # a tile takes at least map_cols cycles
                        tile_cycle_overhead = max(map_cols - tile_cycle, 0)

                        # here we apply channel-level parallelism for matrix-vector multiplication, the read amount depends on the ififo size
                        # no input reuse is assumed here
                        isram_cycl_tile_rd = 0
                        # as the channel dim is continuous, the word read overhead is negligible
                        # the spine array is accessed twice by the hardware implementation
                        # to retrieve the start and end address of the current spine
                        isram_byte_tile_rd = prob.C * prob.R * prob.S * (1 - prob.sparsity) * spike_byte_chn + prob.R * prob.S * spn_access_per_spn * spike_byte_spn
                        isram_byte_rd += isram_byte_tile_rd

                        # the wfifo fill latency is charged once for the whole run
                        if first_wfifo_flag is True:
                            first_wfifo_map = ts_w_bd
                            first_wfifo_flag = False

                        # compute cycle
                        # in this design, the double buffer latency of srams cannot be perfectly hidden, due to extra cycles for spn access
                        tile_cycle = tile_cycle + tile_cycle_overhead
                        compute_cycle += tile_cycle

                        isram_cycl_tile_rd += tile_cycle
                        isram_cycl_tile_wr += tile_cycle

                        osram_cycl_tile_rd += tile_cycle
                        osram_cycl_tile_wr += tile_cycle

                        usram_cycl_tile_rd += tile_cycle
                        usram_cycl_tile_wr += tile_cycle

                        n_cycle += tile_cycle

                        total_util += map_cols * tile_cycle

                        # stall whenever the transfer takes longer than the compute it overlaps with
                        isram_stall_rd += math.ceil(max(isram_byte_tile_rd / isram_bycy - isram_cycl_tile_rd, 0))

                    isram_stall_wr += math.ceil(max(isram_byte_tile_wr / isram_bycy - isram_cycl_tile_wr, 0))

                    osram_stall_rd += math.ceil(max(osram_byte_tile_rd / osram_bycy - osram_cycl_tile_rd, 0))
                    osram_stall_wr += math.ceil(max(osram_byte_tile_wr / osram_bycy - osram_cycl_tile_wr, 0))

                    usram_stall_rd += math.ceil(max(usram_word_tile_rd * usram_cnfg["byte_per_word"] / usram_bycy - usram_cycl_tile_rd, 0))
                    usram_stall_wr += math.ceil(max(usram_word_tile_wr * usram_cnfg["byte_per_word"] / usram_bycy - usram_cycl_tile_wr, 0))

                    # dram traffic of this tile only; it is compared against this tile's
                    # compute cycles, so the running usram<->odram totals must not be used here
                    odram_byte_io = isram_byte_tile_wr + osram_byte_tile_rd
                    odram_byte_u = (usram_odram_word_tile_wr * usram_cnfg["byte_per_word"] + usram_odram_word_tile_rd * usram_cnfg["byte_per_word"])
                    odram_stall += math.ceil(max((odram_byte_io + odram_byte_u) / odram_bycy - n_cycle, 0))

    # count for initial streaming latency
    compute_cycle += first_map_cols + last_map_rows
    compute_cycle += first_ififo_map + first_wfifo_map
    compute_cycle += stream_in_cycle + stream_out_cycle

    # an empty workload performs no cycles; report zero utilization instead of dividing by zero
    if compute_cycle == 0:
        utilization = 0.
    else:
        utilization = total_util / (num_rows * num_cols * compute_cycle) * 100.

    compu_result = OrderedDict()
    compu_result["cycle"] = int(compute_cycle)
    compu_result["utilization"] = float(utilization)
    compu_result["flops"] = float(prob.flops)

    # the remaining compute-like components reuse the compu result and only override "cycle"
    spgen_result = copy.deepcopy(compu_result)
    spgen_result["cycle"] = 0
    sspar_result = copy.deepcopy(compu_result)
    sspar_result["cycle"] = 0
    ispar_result = copy.deepcopy(compu_result)
    ospar_result = copy.deepcopy(compu_result)
    ospar_result["cycle"] = int(prob.num_outputs / min(num_cols, k_max))
    ififo_result = copy.deepcopy(compu_result)
    ififo_result["cycle"] = 0
    wfifo_result = copy.deepcopy(compu_result)
    leaky_result = copy.deepcopy(compu_result)
    leaky_result["cycle"] = int(prob.num_outputs / min(num_cols, k_max))
    reset_result = copy.deepcopy(compu_result)
    reset_result["cycle"] = int(prob.num_outputs / min(num_cols, k_max))

    isram_result = OrderedDict()
    isram_result["access stall"]    = OrderedDict({ "rd": int(isram_stall_rd), 
                                                    "wr": int(isram_stall_wr)})
    isram_result["byte total"]      = OrderedDict({ "rd": int(isram_byte_rd), 
                                                    "wr": int(isram_byte_wr)})
    isram_result["access total"]    = OrderedDict({ "rd": int(isram_byte_rd / isram_cnfg["byte_per_brow"]), 
                                                    "wr": int(isram_byte_wr / isram_cnfg["byte_per_brow"])})

    wsram_result = OrderedDict()
    wsram_result["access stall"]    = OrderedDict({ "rd": int(wsram_stall_rd), 
                                                    "wr": int(wsram_stall_wr)})
    wsram_result["byte total"]      = OrderedDict({ "rd": int(math.ceil(wsram_word_rd * wsram_cnfg["byte_per_word"])), 
                                                    "wr": int(math.ceil(wsram_word_wr * wsram_cnfg["byte_per_word"]))})
    wsram_result["access total"]    = OrderedDict({ "rd": int(math.ceil(wsram_word_rd * wsram_cnfg["byte_per_word"] / wsram_cnfg["byte_per_brow"])), 
                                                    "wr": int(math.ceil(wsram_word_wr * wsram_cnfg["byte_per_word"] / wsram_cnfg["byte_per_brow"]))})

    osram_result = OrderedDict()
    osram_result["access stall"]    = OrderedDict({ "rd": int(osram_stall_rd), 
                                                    "wr": int(osram_stall_wr)})
    osram_result["byte total"]      = OrderedDict({ "rd": int(osram_byte_rd), 
                                                    "wr": int(osram_byte_wr)})
    osram_result["access total"]    = OrderedDict({ "rd": int(osram_byte_rd / osram_cnfg["byte_per_brow"]), 
                                                    "wr": int(osram_byte_wr / osram_cnfg["byte_per_brow"])})

    usram_result = OrderedDict()
    usram_result["access stall"]    = OrderedDict({ "rd": int(usram_stall_rd), 
                                                    "wr": int(usram_stall_wr)})
    usram_result["byte total"]      = OrderedDict({ "rd": int(math.ceil(usram_word_rd * usram_cnfg["byte_per_word"])), 
                                                    "wr": int(math.ceil(usram_word_wr * usram_cnfg["byte_per_word"]))})
    usram_result["access total"]    = OrderedDict({ "rd": int(math.ceil(usram_word_rd * usram_cnfg["byte_per_word"] / usram_cnfg["byte_per_brow"])), 
                                                    "wr": int(math.ceil(usram_word_wr * usram_cnfg["byte_per_word"] / usram_cnfg["byte_per_brow"]))})

    odram_result = OrderedDict()
    # a dram read feeds an sram write and vice versa
    odram_byte_rd = isram_byte_wr + wsram_word_wr * wsram_cnfg["byte_per_word"] + usram_odram_word_wr * usram_cnfg["byte_per_word"]
    odram_byte_wr = osram_byte_rd + usram_odram_word_rd * usram_cnfg["byte_per_word"]
    # the total dram stall is split between rd and wr in proportion to their byte counts
    odram_result["access stall"]    = OrderedDict({ "rd": 0 if odram_byte_rd == 0 else int(math.ceil(odram_stall * odram_byte_rd / (odram_byte_rd + odram_byte_wr))),
                                                    "wr": 0 if odram_byte_wr == 0 else int(math.ceil(odram_stall * odram_byte_wr / (odram_byte_rd + odram_byte_wr)))})
    odram_result["byte total"]      = OrderedDict({ "rd": int(odram_byte_rd), 
                                                    "wr": int(odram_byte_wr)})
    # NOTE(review): bytes are divided by "bits_per_chip" without the * 8 used for
    # "page total" below; confirm the intended unit of this field.
    odram_result["access total"]    = OrderedDict({ "rd": int(math.ceil(odram_byte_rd / odram_cnfg["bits_per_chip"])), 
                                                    "wr": int(math.ceil(odram_byte_wr / odram_cnfg["bits_per_chip"]))})
    odram_result["page total"]      = OrderedDict({ "rd": int(math.ceil(odram_byte_rd * 8 / odram_cnfg["bits_per_page"])), 
                                                    "wr": int(math.ceil(odram_byte_wr * 8 / odram_cnfg["bits_per_page"]))})

    output["compu"] = compu_result
    output["spgen"] = spgen_result
    output["sspar"] = sspar_result
    output["ispar"] = ispar_result
    output["ospar"] = ospar_result
    output["ififo"] = ififo_result
    output["wfifo"] = wfifo_result
    output["leaky"] = leaky_result
    output["reset"] = reset_result
    output["isram"] = isram_result
    output["wsram"] = wsram_result
    output["osram"] = osram_result
    output["usram"] = usram_result
    output["odram"] = odram_result

    return output

