from collections import OrderedDict
import numpy as np
from workload import spike
import math
from utils import bcolors
import copy


def dataflow(
    arch_cnfg: OrderedDict,
    prob_cnfg: OrderedDict
) -> OrderedDict:
    """
    Estimate cycles, utilization and memory traffic of a sparse spike workload.

    All memories are double buffered.
    Channel-wise data schedule as in SpinalFlow paper.
    CSR format for sparse spike encoding/decoding.

    Args:
        arch_cnfg: architecture configuration. The entries read here are
            "frequency" (scaled by 10**6 below, i.e. treated as MHz),
            "spgen" ("num_instances"[0]), "isram" ("byte_tot_sram",
            "byte_per_brow"), "wsram"/"usram" ("byte_per_word",
            "byte_per_brow"), "osram" ("byte_per_brow") and "odram"
            ("bits_per_chip", "bits_per_page").
        prob_cnfg: problem configuration, forwarded unchanged to ``spike``.
            The resulting object must expose N, C, X, Y, sparsity,
            num_inputs, num_outputs and flops.

    Returns:
        An OrderedDict keyed, in order, by "compu", "spgen", "sspar", "ispar",
        "ospar", "ififo", "wfifo", "leaky", "reset", "isram", "wsram",
        "osram", "usram" and "odram". Compute-like entries hold "cycle",
        "utilization" and "flops"; memory entries hold rd/wr pairs for
        "access stall", "byte total" and "access total" (plus "page total"
        for odram).
    """
    spgen_cnfg = arch_cnfg["spgen"]
    isram_cnfg = arch_cnfg["isram"]
    wsram_cnfg = arch_cnfg["wsram"]
    osram_cnfg = arch_cnfg["osram"]
    usram_cnfg = arch_cnfg["usram"]
    odram_cnfg = arch_cnfg["odram"]

    # initialize the problem
    prob = spike("", prob_cnfg)

    # ddr3 bandwidth: 12.8 GB/s, i.e. 32 byte per cycle at 400MHz
    odram_bycy = 12.8 * 10**9 / (arch_cnfg["frequency"] * 10**6)

    # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # #
    # define sparse format requirement
    # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # #
    # an isram word is a spine index or a channel index;
    # effective spike bytes, for calculating the required capacity/bandwidth
    spike_byte_spn = 3
    spike_byte_chn = 2

    # fraction of entries that survive the sparsity and are actually stored
    density = 1 - prob.sparsity

    # bytes of the whole sparse output: channel indices of the non-zero spikes
    # plus (X * Y + 1) spine indices per image
    prob_byte_o = prob.num_outputs * density * spike_byte_chn + prob.N * (prob.X * prob.Y + 1) * spike_byte_spn

    # double buffering leaves half of the isram for one working set
    half_isram = isram_cnfg["byte_tot_sram"] / 2
    large_isram = half_isram >= prob_byte_o

    # max images in a batch according to isram (at least one image per batch)
    n_max_i = math.floor(half_isram / (prob.C * prob.X * prob.Y * density * spike_byte_chn + (prob.X * prob.Y + 1) * spike_byte_spn))
    n_max = max(n_max_i, 1)
    n_map = np.divmod(prob.N, n_max)

    # channels are mapped onto the spike generators in groups of num_rngs
    num_rngs = spgen_cnfg["num_instances"][0]
    k_map = np.divmod(prob.C, num_rngs)

    # The channel tiling is identical for every image and spine, so the
    # per-tile cost is computed once here instead of inside the nested loops.
    tiles = []
    for k in range(int(k_map[0]) + 1):
        map_cols = num_rngs if k < k_map[0] else k_map[1]
        if map_cols == 0:
            # no remainder tile when C is a multiple of num_rngs
            break
        # one cycle for the spine index + cycles for the channel indices of
        # this mapping
        tile_cycle = 1 + map_cols * density
        tiles.append((tile_cycle, map_cols * tile_cycle))

    compute_cycle = 0
    total_util = 0
    for _ in range(prob.N):
        # sparse format can't process time step continuously, as size of each
        # timestep is unknown
        for _ in range(prob.X * prob.Y):
            # accumulate tile by tile (rather than multiplying out) so the
            # floating-point totals match a straightforward nested walk
            for tile_cycle, tile_util in tiles:
                compute_cycle += tile_cycle
                total_util += tile_util

    # sram traffic; this model only accounts for the isram, the remaining
    # srams carry no traffic and no sram/dram stalls are modeled
    isram_byte_rd = 0
    isram_byte_wr = prob_byte_o
    wsram_word_rd, wsram_word_wr = 0, 0
    osram_byte_rd, osram_byte_wr = 0, 0
    usram_word_rd, usram_word_wr = 0, 0
    odram_stall = 0

    stream_in_cycle = 0
    stream_out_cycle = 0
    # `not` instead of `is False`: the comparison above yields a numpy.bool_
    # when the problem sizes are NumPy scalars, and an identity test against
    # False would then never succeed.
    if not large_isram:
        # output does not fit on chip: it is spilled batch by batch
        batch_byte = n_max * prob_byte_o / prob.N
        stream_out_cycle += math.ceil(batch_byte / odram_bycy)
        isram_byte_rd += math.ceil(batch_byte) * int(n_map[0] + math.ceil(n_map[1] / n_max))

    # count for initial streaming latency
    compute_cycle += stream_in_cycle + stream_out_cycle

    # an empty workload takes zero cycles; report 0% instead of dividing by 0
    busy_slots = num_rngs * compute_cycle
    if busy_slots:
        utilization = total_util / busy_slots * 100.
    else:
        utilization = 0.0

    # output shall include the ordered output of all components
    output = OrderedDict()

    compu_result = OrderedDict()
    compu_result["cycle"] = int(compute_cycle)
    compu_result["utilization"] = float(utilization)
    compu_result["flops"] = float(prob.flops)
    output["compu"] = compu_result

    # every compute-like component reports the utilization/flops of compu and
    # only differs in its cycle count
    output["spgen"] = copy.deepcopy(compu_result)
    output["spgen"]["cycle"] = prob.num_outputs / num_rngs
    output["sspar"] = copy.deepcopy(compu_result)
    for name in ("ispar", "ospar", "ififo", "wfifo", "leaky", "reset"):
        output[name] = copy.deepcopy(compu_result)
        output[name]["cycle"] = 0

    output["isram"] = _sram_result(0, 0, isram_byte_rd, isram_byte_wr,
                                   isram_cnfg["byte_per_brow"])
    output["wsram"] = _sram_result(0, 0,
                                   wsram_word_rd * wsram_cnfg["byte_per_word"],
                                   wsram_word_wr * wsram_cnfg["byte_per_word"],
                                   wsram_cnfg["byte_per_brow"])
    output["osram"] = _sram_result(0, 0, osram_byte_rd, osram_byte_wr,
                                   osram_cnfg["byte_per_brow"])
    output["usram"] = _sram_result(0, 0,
                                   usram_word_rd * usram_cnfg["byte_per_word"],
                                   usram_word_wr * usram_cnfg["byte_per_word"],
                                   usram_cnfg["byte_per_brow"])

    odram_byte_rd = prob.num_inputs * wsram_cnfg["byte_per_word"]
    odram_byte_wr = isram_byte_rd
    odram_byte_tot = odram_byte_rd + odram_byte_wr

    odram_result = OrderedDict()
    # the dram stall is split between reads and writes in proportion to traffic
    odram_result["access stall"] = _rd_wr(
        0 if odram_byte_rd == 0 else int(math.ceil(odram_stall * odram_byte_rd / odram_byte_tot)),
        0 if odram_byte_wr == 0 else int(math.ceil(odram_stall * odram_byte_wr / odram_byte_tot)))
    odram_result["byte total"] = _rd_wr(int(odram_byte_rd), int(odram_byte_wr))
    # NOTE(review): bytes are divided by "bits_per_chip" without the "* 8" used
    # for the page count below — confirm the intended unit of an access.
    odram_result["access total"] = _rd_wr(
        int(math.ceil(odram_byte_rd / odram_cnfg["bits_per_chip"])),
        int(math.ceil(odram_byte_wr / odram_cnfg["bits_per_chip"])))
    odram_result["page total"] = _rd_wr(
        int(math.ceil(odram_byte_rd * 8 / odram_cnfg["bits_per_page"])),
        int(math.ceil(odram_byte_wr * 8 / odram_cnfg["bits_per_page"])))
    output["odram"] = odram_result

    return output


def _rd_wr(rd, wr):
    """Pack a read/write pair into the ordered {"rd", "wr"} record used in results."""
    return OrderedDict([("rd", rd), ("wr", wr)])


def _sram_result(stall_rd, stall_wr, byte_rd, byte_wr, byte_per_brow):
    """Summarize one SRAM: stalls, bytes moved and bank-row accesses (bytes / byte_per_brow, rounded up)."""
    result = OrderedDict()
    result["access stall"] = _rd_wr(int(stall_rd), int(stall_wr))
    result["byte total"] = _rd_wr(int(math.ceil(byte_rd)), int(math.ceil(byte_wr)))
    result["access total"] = _rd_wr(int(math.ceil(byte_rd / byte_per_brow)),
                                    int(math.ceil(byte_wr / byte_per_brow)))
    return result

