from collections import OrderedDict
import numpy as np
from workload import avgpool2d
import math
from utils import bcolors
import copy


def _sram_result(word_rd, word_wr, byte_per_word, byte_per_brow, stall_rd, stall_wr):
    """
    Summarize the traffic of one on-chip sram as an OrderedDict.

    word_rd / word_wr: number of words read / written.
    byte_per_word: bytes per word (may be fractional for bit-packed spikes).
    byte_per_brow: bytes per bank row, used to turn bytes into row accesses.
    stall_rd / stall_wr: accumulated read / write stall cycles.

    Byte and access totals are rounded up to whole units.
    """
    result = OrderedDict()
    result["access stall"]    = OrderedDict({ "rd": int(stall_rd),
                                              "wr": int(stall_wr)})
    result["byte total"]      = OrderedDict({ "rd": int(math.ceil(word_rd * byte_per_word)),
                                              "wr": int(math.ceil(word_wr * byte_per_word))})
    result["access total"]    = OrderedDict({ "rd": int(math.ceil(word_rd * byte_per_word / byte_per_brow)),
                                              "wr": int(math.ceil(word_wr * byte_per_word / byte_per_brow))})
    return result


def dataflow(
    arch_cnfg: OrderedDict,
    prob_cnfg: OrderedDict
):
    """
    Estimate cycle count, utilization and memory traffic of an avgpool2d layer
    on a TPU-like systolic array.

    The loop structure is adapted from the weight stationary conv2d dataflow,
    but pooling has no weights, so all wfifo/wsram/weight-dram traffic is zero.
    All memories are double buffered, so only half of each sram is usable for
    one tile.
    Input format: NXYC
    Output format: NPQK

    Input schedule: NXYC
    Output schedule: NPQK

    Channel last allows the isram and osram fetch to be sequential, reducing
    sram access count. This dataflow leverages bit-packing for high sram
    efficiency: a C-long spike vector occupies ceil(C * byte_per_word) bytes.

    Channels are spread across array columns (channel-level parallelism), at
    most min(num_rows, num_cols) channels per pass.
    Please refer to https://inst.eecs.berkeley.edu/~eecs151/sp20/files/lec18-DNN-QijingHuang.pdf page 29.

    Args:
        arch_cnfg: architecture config. Must contain the component sections
            "compu", "ififo", "ipack", "wfifo", "opack", "ufifo", "uaccu",
            "leaky", "reset", "isram", "wsram", "osram", "usram", "odram" and
            "frequency" (in MHz, matching the 400 MHz -> 32 byte/cycle
            conversion below).
        prob_cnfg: problem config, passed unchanged to ``avgpool2d``.

    Returns:
        OrderedDict keyed by component name. Compute/fifo entries carry
        "cycle", "utilization" and "flops"; sram/dram entries carry
        "access stall", "byte total" and "access total" (rd/wr pairs), and
        odram additionally "page total".

    Raises:
        AssertionError: if a fifo is not double buffered or its instance
            count does not match the array shape.
        KeyError: if a required config key is missing.
    """
    compu_cnfg = arch_cnfg["compu"]
    ififo_cnfg = arch_cnfg["ififo"]
    ipack_cnfg = arch_cnfg["ipack"]  # not modeled here; lookup keeps the section required
    wfifo_cnfg = arch_cnfg["wfifo"]
    opack_cnfg = arch_cnfg["opack"]  # not modeled here; lookup keeps the section required
    ufifo_cnfg = arch_cnfg["ufifo"]
    uaccu_cnfg = arch_cnfg["uaccu"]  # not modeled here; lookup keeps the section required
    leaky_cnfg = arch_cnfg["leaky"]  # not modeled here; lookup keeps the section required
    reset_cnfg = arch_cnfg["reset"]  # not modeled here; lookup keeps the section required
    isram_cnfg = arch_cnfg["isram"]
    wsram_cnfg = arch_cnfg["wsram"]
    osram_cnfg = arch_cnfg["osram"]
    usram_cnfg = arch_cnfg["usram"]
    odram_cnfg = arch_cnfg["odram"]

    num_rows = compu_cnfg["num_instances"][0]
    num_cols = compu_cnfg["num_instances"][1]

    # output shall include the ordered output of all components
    output = OrderedDict()

    prob = avgpool2d("", prob_cnfg)

    # ddr3 bandwidth: 12.8 GB/s, i.e. 32 bytes per cycle at 400 MHz
    odram_bycy = 12.8 * 10**9 / (arch_cnfg["frequency"] * 10**6)

    # isram/osram words are spikes; a C-long spike vector is packed into whole
    # bytes, so this is the effective (possibly fractional) byte cost per spike
    eff_spike_byte_i = math.ceil(prob.C * isram_cnfg["byte_per_word"]) / prob.C
    eff_spike_byte_o = math.ceil(prob.C * osram_cnfg["byte_per_word"]) / prob.C
    u_byte = usram_cnfg["byte_per_word"]

    # isram traffic (with compute, odram). When the whole valid input fits in
    # half of the (double buffered) isram, no dram->isram traffic is counted.
    isram_word_rd, isram_word_wr = 0, 0
    isram_bycy = isram_cnfg["byte_per_brow"] * isram_cnfg["bank_per_sram"]
    large_isram = isram_cnfg["byte_tot_sram"] / eff_spike_byte_i / 2 >= prob.num_inputs_valid

    # wsram traffic: pooling has no weights, so it stays zero
    wsram_word_rd, wsram_word_wr = 0, 0
    wsram_stall_rd, wsram_stall_wr = 0, 0

    # osram traffic (with odram, compute). When all outputs fit in half of the
    # osram, no osram->dram traffic is counted.
    osram_word_rd, osram_word_wr = 0, 0
    osram_bycy = osram_cnfg["byte_per_brow"] * osram_cnfg["bank_per_sram"]
    large_osram = osram_cnfg["byte_tot_sram"] / eff_spike_byte_o / 2 >= prob.num_outputs

    # usram traffic (with odram, compute). usram keeps num_cols * P * Q state
    # words; when they do not fit, state also travels through dram.
    usram_word_rd, usram_word_wr = 0, 0
    usram_bycy = usram_cnfg["byte_per_brow"] * usram_cnfg["bank_per_sram"]
    large_usram = usram_cnfg["byte_tot_sram"] / u_byte / 2 >= (num_cols * prob.P * prob.Q)
    usram_odram_word_rd, usram_odram_word_wr = 0, 0

    # l0-fifo checks: all fifos must be double buffered and match the array shape
    assert ufifo_cnfg["num_instances"][0] == 2, \
        bcolors.FAIL + "The output fifo is not double buffered." + bcolors.ENDC
    assert ufifo_cnfg["num_instances"][1] == num_cols, \
        bcolors.FAIL + "The output fifo count is uneuqal to the array column count." + bcolors.ENDC
    assert ififo_cnfg["num_instances"][1] == 2, \
        bcolors.FAIL + "The input fifo is not double buffered." + bcolors.ENDC
    assert ififo_cnfg["num_instances"][0] == num_rows, \
        bcolors.FAIL + "The input fifo count is unequal to the array row count." + bcolors.ENDC
    assert wfifo_cnfg["num_instances"][0] == 2, \
        bcolors.FAIL + "The weight fifo is not double buffered." + bcolors.ENDC
    assert wfifo_cnfg["num_instances"][1] == num_cols, \
        bcolors.FAIL + "The weight fifo count is unequal to the array column count." + bcolors.ENDC

    # Outputs accumulate in the ufifo, so its depth bounds the n*p*q tile.
    ts_compute = ufifo_cnfg["instance"]["queue"]["depth"]

    # max channels per pass, bounded by the array shape when acting as a transposer
    k_max = min(num_rows, num_cols)

    # n/p tile sizes bounded by half of each sram (double buffering).
    # Assumes all input is fetched from dram; stride/dilation only happens on chip.
    n_max_i = math.floor(isram_cnfg["byte_tot_sram"] / eff_spike_byte_i / 2 / (prob.C * prob.Y_valid) / prob.R)
    n_max_o = math.floor(osram_cnfg["byte_tot_sram"] / eff_spike_byte_o / 2 / (num_cols * prob.Q))
    n_max_u = math.floor(usram_cnfg["byte_tot_sram"] / u_byte / 2 / (num_cols * prob.Q))
    n_max = max(min(n_max_i, n_max_o, n_max_u, prob.N), 1)

    p_max_i = math.floor(isram_cnfg["byte_tot_sram"] / eff_spike_byte_i / 2 / (prob.C * prob.Y_valid * n_max))
    p_max_o = math.floor(osram_cnfg["byte_tot_sram"] / eff_spike_byte_o / 2 / (num_cols * prob.Q * n_max))
    p_max_u = math.floor(usram_cnfg["byte_tot_sram"] / u_byte / 2 / (num_cols * prob.Q))
    p_max = max(min(p_max_i, p_max_o, p_max_u, prob.P), 1)

    n_map = np.divmod(prob.N, n_max)
    p_map = np.divmod(prob.P, p_max)

    compute_cycle = 0
    stream_in_cycle = 0
    stream_out_cycle = 0
    total_util = 0

    # assume only sram/dram stalls; fifos never stall thanks to double buffering
    isram_stall_rd, isram_stall_wr = 0, 0
    osram_stall_rd, osram_stall_wr = 0, 0
    usram_stall_rd, usram_stall_wr = 0, 0
    odram_stall = 0

    first_map_cols = 0
    first_map_flag = True
    first_isram_tile = True
    first_ififo_map = 0
    first_ififo_flag = True

    # each channel is mapped to one column; k_max channels per full pass
    col_map = np.divmod(prob.C, k_max)
    for k in range(int(col_map[0] + 1)):
        map_cols = k_max if k < col_map[0] else col_map[1]
        if map_cols == 0:
            break
        if first_map_flag:
            first_map_cols = map_cols * 2
            first_map_flag = False

        n_current = 0
        for n in range(int(n_map[0] + 1)):
            n_now = n_max if n < n_map[0] else n_map[1]
            if n_now == 0:
                break
            n_current += n_now
            n_remain = prob.N - n_current
            # the first n tile starts from fresh state; later tiles resume it
            is_first_n_tile = n_current == n_now

            for p in range(int(p_map[0] + 1)):
                p_now = p_max if p < p_map[0] else p_map[1]
                if p_now == 0:
                    break

                # outputs accumulate in the ufifo from here on (output stationary)
                npq_map = np.divmod(n_now * p_now * prob.Q, ts_compute)

                # input rows needed for p_now output rows (overlapping windows
                # share rows when stride < kernel height)
                X_now = (prob.Hstride * (p_now - 1) + prob.R) if prob.Hstride < prob.R else (prob.R * p_now)
                isram_word_tile_wr = 0 if large_isram else (n_now * prob.C * X_now * prob.Y_valid)
                isram_word_wr += isram_word_tile_wr

                # only the very first input tile's load is exposed; later loads
                # are hidden by double buffering
                if first_isram_tile:
                    stream_in_cycle += math.ceil(isram_word_tile_wr * eff_spike_byte_i / min(odram_bycy, isram_bycy))
                    first_isram_tile = False

                osram_word_tile_rd = 0 if large_osram else (n_now * map_cols * p_now * prob.Q)
                osram_word_rd += osram_word_tile_rd
                osram_word_tile_wr = n_now * map_cols * p_now * prob.Q
                osram_word_wr += osram_word_tile_wr

                # State is read back for every n tile except the first and
                # saved for every n tile except the last; when all time steps
                # fit in one tile, usram is neither read nor written.
                u_tile_words = map_cols * p_now * prob.Q
                usram_word_tile_rd = 0 if is_first_n_tile else u_tile_words
                usram_word_tile_wr = u_tile_words if n_remain > 0 else 0

                usram_word_rd += usram_word_tile_rd * (1 if large_usram else 2)
                usram_word_wr += usram_word_tile_wr * (1 if large_usram else 2)
                # this tile's usram <-> dram traffic (only when usram is too small)
                usram_odram_tile_rd = usram_word_tile_rd * (0 if large_usram else 1)
                usram_odram_tile_wr = usram_word_tile_wr * (0 if large_usram else 1)
                usram_odram_word_rd += usram_odram_tile_rd
                usram_odram_word_wr += usram_odram_tile_wr

                # overwritten each tile: only the last tile's write-back is exposed
                stream_out_cycle = math.ceil(osram_word_tile_rd * eff_spike_byte_o / min(odram_bycy, osram_bycy))

                p_tile_cycle = 0
                for npq in range(int(npq_map[0] + 1)):
                    stream_cycle = (ts_compute if npq < npq_map[0] else npq_map[1]) * prob.R * prob.S
                    if stream_cycle == 0:
                        break

                    if first_ififo_flag:
                        first_ififo_map = math.ceil(map_cols * stream_cycle * eff_spike_byte_i / isram_bycy)
                        first_ififo_flag = False

                    # channel-level parallelism: one word per mapped column per
                    # cycle, no input reuse; channel-contiguous reads make the
                    # word read overhead negligible
                    isram_word_tile_rd = stream_cycle * map_cols
                    isram_word_rd += isram_word_tile_rd

                    compute_cycle += stream_cycle
                    p_tile_cycle += stream_cycle
                    total_util += (map_cols + 1) * map_cols * stream_cycle

                    isram_stall_rd += math.ceil(max(isram_word_tile_rd * eff_spike_byte_i / isram_bycy - stream_cycle, 0))

                # a memory stalls when its tile traffic needs more cycles than
                # the compute of this tile takes
                isram_stall_wr += math.ceil(max(isram_word_tile_wr * eff_spike_byte_i / isram_bycy - p_tile_cycle, 0))

                osram_stall_rd += math.ceil(max(osram_word_tile_rd * eff_spike_byte_o / osram_bycy - p_tile_cycle, 0))
                osram_stall_wr += math.ceil(max(osram_word_tile_wr * eff_spike_byte_o / osram_bycy - p_tile_cycle, 0))

                usram_stall_rd += math.ceil(max(usram_word_tile_rd * u_byte / usram_bycy - p_tile_cycle, 0))
                usram_stall_wr += math.ceil(max(usram_word_tile_wr * u_byte / usram_bycy - p_tile_cycle, 0))

                # dram traffic of THIS tile only (no weights for pooling)
                odram_byte_io = isram_word_tile_wr * eff_spike_byte_i + osram_word_tile_rd * eff_spike_byte_o
                odram_byte_u = usram_odram_tile_wr * u_byte + usram_odram_tile_rd * u_byte
                odram_stall += math.ceil(max((odram_byte_io + odram_byte_u) / odram_bycy - p_tile_cycle, 0))

    # exposed fill/stream latency: array fill, first ififo fill, first input
    # load and last output write-back (no weight fifo or row drain for pooling)
    compute_cycle += first_map_cols
    compute_cycle += first_ififo_map
    compute_cycle += stream_in_cycle + stream_out_cycle

    util_denom = num_rows * num_cols * compute_cycle + num_cols * compute_cycle
    utilization = total_util / util_denom * 100. if util_denom else 0.0

    compu_result = OrderedDict()
    compu_result["cycle"] = int(compute_cycle)
    compu_result["utilization"] = float(utilization)
    compu_result["flops"] = float(prob.flops)

    # per-output post-processing handles k_max channels per cycle
    per_output_cycle = int(prob.num_outputs / k_max)

    ififo_result = copy.deepcopy(compu_result)
    ipack_result = copy.deepcopy(compu_result)
    wfifo_result = copy.deepcopy(compu_result)
    wfifo_result["cycle"] = 0
    opack_result = copy.deepcopy(compu_result)
    opack_result["cycle"] = per_output_cycle
    ufifo_result = copy.deepcopy(compu_result)
    uaccu_result = copy.deepcopy(compu_result)
    leaky_result = copy.deepcopy(compu_result)
    leaky_result["cycle"] = per_output_cycle
    reset_result = copy.deepcopy(compu_result)
    reset_result["cycle"] = per_output_cycle

    isram_result = _sram_result(isram_word_rd, isram_word_wr, eff_spike_byte_i,
                                isram_cnfg["byte_per_brow"], isram_stall_rd, isram_stall_wr)
    wsram_result = _sram_result(wsram_word_rd, wsram_word_wr, wsram_cnfg["byte_per_word"],
                                wsram_cnfg["byte_per_brow"], wsram_stall_rd, wsram_stall_wr)
    osram_result = _sram_result(osram_word_rd, osram_word_wr, eff_spike_byte_o,
                                osram_cnfg["byte_per_brow"], osram_stall_rd, osram_stall_wr)
    usram_result = _sram_result(usram_word_rd, usram_word_wr, u_byte,
                                usram_cnfg["byte_per_brow"], usram_stall_rd, usram_stall_wr)

    odram_result = OrderedDict()
    odram_byte_rd = isram_word_wr * eff_spike_byte_i + wsram_word_wr * wsram_cnfg["byte_per_word"] + usram_odram_word_wr * u_byte
    odram_byte_wr = osram_word_rd * eff_spike_byte_o + usram_odram_word_rd * u_byte
    # total dram stall is split between rd and wr proportionally to their bytes
    odram_result["access stall"]    = OrderedDict({ "rd": 0 if odram_byte_rd == 0 else int(math.ceil(odram_stall * odram_byte_rd / (odram_byte_rd + odram_byte_wr))),
                                                    "wr": 0 if odram_byte_wr == 0 else int(math.ceil(odram_stall * odram_byte_wr / (odram_byte_rd + odram_byte_wr)))})
    odram_result["byte total"]      = OrderedDict({ "rd": int(odram_byte_rd),
                                                    "wr": int(odram_byte_wr)})
    # NOTE(review): bytes are divided by a *bits* quantity here, while "page
    # total" below converts bytes to bits first; confirm the intended unit.
    odram_result["access total"]    = OrderedDict({ "rd": int(math.ceil(odram_byte_rd / odram_cnfg["bits_per_chip"])),
                                                    "wr": int(math.ceil(odram_byte_wr / odram_cnfg["bits_per_chip"]))})
    odram_result["page total"]      = OrderedDict({ "rd": int(math.ceil(odram_byte_rd * 8 / odram_cnfg["bits_per_page"])),
                                                    "wr": int(math.ceil(odram_byte_wr * 8 / odram_cnfg["bits_per_page"]))})

    output["compu"] = compu_result
    output["ififo"] = ififo_result
    output["ipack"] = ipack_result
    output["wfifo"] = wfifo_result
    output["opack"] = opack_result
    output["ufifo"] = ufifo_result
    output["uaccu"] = uaccu_result
    output["leaky"] = leaky_result
    output["reset"] = reset_result
    output["isram"] = isram_result
    output["wsram"] = wsram_result
    output["osram"] = osram_result
    output["usram"] = usram_result
    output["odram"] = odram_result

    return output

