from collections import OrderedDict
import numpy as np
from workload import spike
import math
from utils import bcolors
import copy


def _sram_result(stall_rd, stall_wr, word_rd, word_wr, byte_per_word, byte_per_brow):
    """Return the "access stall" / "byte total" / "access total" report of one SRAM.

    ``byte_per_word`` may be fractional (an effective per-spike byte count), so
    byte and row-access totals are rounded up to whole units.
    """
    report = OrderedDict()
    report["access stall"] = OrderedDict({"rd": int(stall_rd),
                                          "wr": int(stall_wr)})
    report["byte total"] = OrderedDict({"rd": int(math.ceil(word_rd * byte_per_word)),
                                        "wr": int(math.ceil(word_wr * byte_per_word))})
    report["access total"] = OrderedDict({"rd": int(math.ceil(word_rd * byte_per_word / byte_per_brow)),
                                          "wr": int(math.ceil(word_wr * byte_per_word / byte_per_brow))})
    return report


def _odram_result(byte_rd, byte_wr, stall, odram_cnfg):
    """Return the off-chip DRAM report; ``stall`` is split between rd/wr by byte share."""
    report = OrderedDict()
    report["access stall"] = OrderedDict({
        "rd": 0 if byte_rd == 0 else int(math.ceil(stall * byte_rd / (byte_rd + byte_wr))),
        "wr": 0 if byte_wr == 0 else int(math.ceil(stall * byte_wr / (byte_rd + byte_wr))),
    })
    report["byte total"] = OrderedDict({"rd": int(byte_rd),
                                        "wr": int(byte_wr)})
    # NOTE(review): "access total" divides a *byte* count by ``bits_per_chip``,
    # while "page total" converts bytes to bits (* 8) first; confirm the units.
    report["access total"] = OrderedDict({"rd": int(math.ceil(byte_rd / odram_cnfg["bits_per_chip"])),
                                          "wr": int(math.ceil(byte_wr / odram_cnfg["bits_per_chip"]))})
    report["page total"] = OrderedDict({"rd": int(math.ceil(byte_rd * 8 / odram_cnfg["bits_per_page"])),
                                        "wr": int(math.ceil(byte_wr * 8 / odram_cnfg["bits_per_page"]))})
    return report


def dataflow(
    arch_cnfg: OrderedDict,
    prob_cnfg: OrderedDict
):
    """
    Model a spike-generation layer (``workload.spike``) on a weight-stationary,
    TPU-like systolic array and report cycles, utilization and memory traffic.

    SRAM capacities are halved when sizing tiles and batches, modeling double
    buffering. Channels (C) are tiled across the ``num_cols`` array columns and
    every column tile streams all N * X * Y positions. Input data is written
    into wsram from DRAM; isram carries no traffic. Stalls are not modeled and
    are reported as 0.

    Input format: NXYC
    Output format: NXYC

    Input schedule: NXYC
    Output schedule: NXYC

    Channel last allows the isram and osram fetch to be sequential, reducing
    sram access count.

    Args:
        arch_cnfg: architecture config; must contain "frequency" (MHz) and one
            entry per component ("compu", "ififo", ..., "odram").
        prob_cnfg: problem config passed to ``workload.spike``.

    Returns:
        OrderedDict keyed by component name. Compute components ("compu",
        "ififo", "ipack", "wfifo", "opack", "ufifo", "uaccu", "leaky",
        "reset") map to {"cycle", "utilization", "flops"}; ififo, ipack and
        leaky report 0 cycles. SRAMs ("isram", "wsram", "osram", "usram") map
        to {"access stall", "byte total", "access total"}, each split into
        "rd"/"wr"; "odram" additionally has "page total".

    Raises:
        KeyError: if a required config entry is missing.
        AssertionError: (when assertions are enabled) if an osram tile would
            hold zero words.
    """
    # Intra-image parallel with weight stationary.
    # Unused lookups are kept so a config missing any component still raises KeyError here.
    compu_cnfg = arch_cnfg["compu"]
    ififo_cnfg = arch_cnfg["ififo"]
    ipack_cnfg = arch_cnfg["ipack"]
    wfifo_cnfg = arch_cnfg["wfifo"]
    opack_cnfg = arch_cnfg["opack"]
    ufifo_cnfg = arch_cnfg["ufifo"]
    uaccu_cnfg = arch_cnfg["uaccu"]
    leaky_cnfg = arch_cnfg["leaky"]
    reset_cnfg = arch_cnfg["reset"]
    isram_cnfg = arch_cnfg["isram"]
    wsram_cnfg = arch_cnfg["wsram"]
    osram_cnfg = arch_cnfg["osram"]
    usram_cnfg = arch_cnfg["usram"]
    odram_cnfg = arch_cnfg["odram"]

    num_rows = compu_cnfg["num_instances"][0]
    num_cols = compu_cnfg["num_instances"][1]

    # output holds the ordered results of all components
    output = OrderedDict()

    prob = spike("", prob_cnfg)

    # DRAM bandwidth in bytes per cycle: 12.8 GB/s (DDR3), i.e. 32 B/cycle at 400 MHz.
    odram_bycy = 12.8 * 10**9 / (arch_cnfg["frequency"] * 10**6)

    # isram/osram words are spikes. A C-channel spike vector occupies
    # ceil(C * byte_per_word) whole bytes, so use that effective per-spike size.
    eff_spike_byte_i = math.ceil(prob.C * isram_cnfg["byte_per_word"]) / prob.C
    eff_spike_byte_o = math.ceil(prob.C * osram_cnfg["byte_per_word"]) / prob.C

    # Images per batch that fit in half of osram (double buffering), at least one.
    # Time steps are batched together to minimize usram access.
    n_max_o = math.floor(osram_cnfg["byte_tot_sram"] / eff_spike_byte_o / 2 / (prob.C * prob.X * prob.Y))
    n_max = max(n_max_o, 1)
    n_map = np.divmod(prob.N, n_max)

    # Tile channels across the array columns. Every column tile streams all
    # N * X * Y positions, and (num_rows + 1) * map_cols PEs count as busy.
    k_map = np.divmod(prob.C, num_cols)
    compute_cycle = 0
    total_util = 0
    for k in range(int(k_map[0]) + 1):
        map_cols = num_cols if k < k_map[0] else k_map[1]
        if map_cols == 0:
            break
        tile_cycle = prob.N * prob.X * prob.Y
        compute_cycle += tile_cycle
        total_util += (num_rows + 1) * map_cols * tile_cycle

    # SRAM traffic in words. isram is unused: wsram stores the input binary data.
    isram_word_rd = 0
    isram_word_wr = 0
    wsram_word_rd = prob.num_inputs
    wsram_word_wr = prob.num_inputs
    osram_word_rd = prob.num_outputs
    osram_word_wr = prob.num_outputs
    # usram traffic is charged once per batch after the first, i.e. ceil(N / n_max) - 1 times.
    usram_batches_after_first = int(n_map[0] + math.ceil(n_map[1] / n_max) - 1)
    usram_word_rd = prob.num_inputs * usram_batches_after_first
    usram_word_wr = prob.num_inputs * usram_batches_after_first

    # Initial latency, part 1: stream the first wsram tile in from DRAM.
    wsram_word_tile_wr = min(prob.C * prob.X * prob.Y,
                             wsram_cnfg["byte_tot_sram"] / wsram_cnfg["byte_per_word"] / 2,
                             num_cols * prob.X * prob.Y)
    stream_in_cycle = math.ceil(wsram_word_tile_wr * wsram_cnfg["byte_per_word"] / odram_bycy)

    # Initial latency, part 2: wfifo queue depth plus num_rows before the first mapping.
    first_wfifo_map = wfifo_cnfg["instance"]["queue"]["depth"] + num_rows

    # Initial latency, part 3: stream one osram tile out to DRAM, unless half of
    # osram (double buffering) holds every output.
    osram_word_tile_wr = n_max * (prob.C * prob.X * prob.Y)
    # NOTE(review): n_max is clamped to >= 1 above, so this only fails when
    # C * X * Y == 0; it cannot detect an osram smaller than one input.
    assert osram_word_tile_wr > 0, \
        bcolors.FAIL + "osram is smaller than a single input for spike generation." + bcolors.ENDC
    large_osram = osram_cnfg["byte_tot_sram"] / eff_spike_byte_o / 2 >= prob.num_outputs
    # Truthiness instead of `is False`: the comparison may produce a numpy.bool_,
    # for which an identity test against False is never true.
    stream_out_cycle = 0 if large_osram else math.ceil(osram_word_tile_wr * eff_spike_byte_o / odram_bycy)

    compute_cycle += stream_in_cycle + stream_out_cycle + first_wfifo_map

    # Percentage of the (num_rows + 1) x num_cols PEs busy over all compute cycles.
    utilization = total_util / (num_rows * num_cols * compute_cycle + num_cols * compute_cycle) * 100.

    compu_result = OrderedDict()
    compu_result["cycle"] = int(compute_cycle)
    compu_result["utilization"] = float(utilization)
    compu_result["flops"] = float(prob.flops)

    output["compu"] = compu_result
    # The other compute components mirror compu; ififo, ipack and leaky report 0 cycles.
    idle_components = ("ififo", "ipack", "leaky")
    for name in ("ififo", "ipack", "wfifo", "opack", "ufifo", "uaccu", "leaky", "reset"):
        result = copy.deepcopy(compu_result)
        if name in idle_components:
            result["cycle"] = 0
        output[name] = result

    # Stalls are not modeled: every SRAM and DRAM stall count is 0.
    sram_stall = 0
    odram_stall = 0

    output["isram"] = _sram_result(sram_stall, sram_stall, isram_word_rd, isram_word_wr,
                                   eff_spike_byte_i, isram_cnfg["byte_per_brow"])
    output["wsram"] = _sram_result(sram_stall, sram_stall, wsram_word_rd, wsram_word_wr,
                                   wsram_cnfg["byte_per_word"], wsram_cnfg["byte_per_brow"])
    output["osram"] = _sram_result(sram_stall, sram_stall, osram_word_rd, osram_word_wr,
                                   eff_spike_byte_o, osram_cnfg["byte_per_brow"])
    output["usram"] = _sram_result(sram_stall, sram_stall, usram_word_rd, usram_word_wr,
                                   usram_cnfg["byte_per_word"], usram_cnfg["byte_per_brow"])

    # DRAM reads feed isram/wsram writes; DRAM writes drain osram reads.
    odram_byte_rd = isram_word_wr * eff_spike_byte_i + wsram_word_wr * wsram_cnfg["byte_per_word"]
    odram_byte_wr = osram_word_rd * eff_spike_byte_o
    output["odram"] = _odram_result(odram_byte_rd, odram_byte_wr, odram_stall, odram_cnfg)

    return output

