from collections import OrderedDict
import numpy as np
from workload import conv
import math
from utils import bcolors
import copy


def dataflow(
    arch_cfg: OrderedDict,
    prob_cfg: OrderedDict
) -> OrderedDict:
    """
    Model multiple nodes of a weight stationary systolic array.
    All memories are double buffered.

    Tensor format:

    Tensorflow-preferred format: (systolic array)
    Input format: NXYC
    Weight format: RSCK
    Output format: NPQK

    PyTorch-preferred format: (SIMD)
    Input format: NCXY
    Weight format: KCRS
    Output format: NKPQ

    Please refer to https://inst.eecs.berkeley.edu/~eecs151/sp20/files/lec18-DNN-QijingHuang.pdf page 29 for how to schedule GEMM.

    Args:
        arch_cfg: architecture configuration. The keys read here are
            "compu", "ififo", "wfifo", "ofifo", "isram", "wsram", "osram",
            "odram", "noc" and "frequency".
        prob_cfg: problem configuration. It is handed to ``conv`` to build the
            workload, and its "cycle" entry scales the streaming cycles of
            every tile.

    Returns:
        An OrderedDict with one entry per component, in this order: "compu",
        "noc", "ififo", "wfifo", "ofifo", "oaccu", "isram", "wsram", "osram",
        "odram".
        - "compu" holds "cycle", "utilization" and "flops".
        - "ififo", "wfifo", "ofifo" and "oaccu" are independent copies of
          "compu"; "noc" is a copy with its "cycle" replaced by the NoC
          transfer cycles.
        - "isram", "wsram" and "osram" hold "access stall", "byte total" and
          "access total", each split into "rd" and "wr".
        - "odram" holds the same three entries plus "page total".

    Raises:
        AssertionError: if a fifo is not double buffered, or the fifo count
            does not match the corresponding array dimension.
    """

    # intra-image parallelism with weight stationary

    compu_cfg = arch_cfg["compu"]
    ififo_cfg = arch_cfg["ififo"]
    wfifo_cfg = arch_cfg["wfifo"]
    ofifo_cfg = arch_cfg["ofifo"]
    isram_cfg = arch_cfg["isram"]
    wsram_cfg = arch_cfg["wsram"]
    osram_cfg = arch_cfg["osram"]
    odram_cfg = arch_cfg["odram"]

    num_rows = compu_cfg["num_instances"][0]
    num_cols = compu_cfg["num_instances"][1]

    # configuration of NoC
    # the rows and cols define the repetition of single node carat
    # it also determines the split of input, weight and output
    noc_cfg = arch_cfg["noc"]
    # noc bandwidth in byte per cycle
    # (presumably bandwidth is in GB/s and frequency in MHz -- TODO confirm)
    noc_bw = noc_cfg["instance"]["noc"]["bandwidth"] * 10**9 / (arch_cfg["frequency"] * 10**6)
    noc_rows = noc_cfg["num_instances"][0]
    noc_cols = noc_cfg["num_instances"][1]
    num_nodes = noc_rows * noc_cols

    # output shall include the ordered output of all components
    output = OrderedDict()

    # initialize the problem
    prob = conv("", prob_cfg)

    # dram bandwidth in byte per cycle
    odram_bw = odram_cfg["bandwidth"] * 10**9 / (arch_cfg["frequency"] * 10**6)

    # input traffic, always read from odram
    isram_word_rd = 0

    # weight traffic, always read from odram
    wsram_word_rd = 0

    # output traffic, always write to odram
    osram_word_wr = 0

    # analyze tile size based on buffer size
    # 1.1 calculate npq tile size based on ofifo size
    assert ofifo_cfg["num_instances"][0] == 2, \
        bcolors.FAIL + "The output fifo is not double buffered." + bcolors.ENDC
    assert ofifo_cfg["num_instances"][1] == num_cols, \
        bcolors.FAIL + "The output fifo count is unequal to the array column count." + bcolors.ENDC
    # max outputs that can be held for this tile due to limited ofifo depth
    ts_o_bd = ofifo_cfg["instance"]["queue"]["depth"]

    # 1.2 calculate npq tile size based on ififo size
    assert ififo_cfg["num_instances"][1] == 2, \
        bcolors.FAIL + "The input fifo is not double buffered." + bcolors.ENDC
    assert ififo_cfg["num_instances"][0] == num_rows, \
        bcolors.FAIL + "The input fifo count is unequal to the array row count." + bcolors.ENDC
    # max inputs that can be held for this tile due to limited ififo depth
    ts_i_bd = ififo_cfg["instance"]["queue"]["depth"]

    # 1.3 calculate crs tile size based on wfifo size
    assert wfifo_cfg["num_instances"][0] == 2, \
        bcolors.FAIL + "The weight fifo is not double buffered." + bcolors.ENDC
    assert wfifo_cfg["num_instances"][1] == num_cols, \
        bcolors.FAIL + "The weight fifo count is unequal to the array column count." + bcolors.ENDC
    # max weights that can be held for this tile due to limited wfifo depth
    ts_w_bd = wfifo_cfg["instance"]["queue"]["depth"]

    # This is the actual tile size
    ts_compute = ts_o_bd

    # rows occupied by one full crs pass; loop invariant, so computed once
    rows_per_pass = min(num_rows, ts_w_bd)

    # crs are not factorized, i.e., each tile has continuous crs index
    crs_map = np.divmod(prob.C * prob.R * prob.S, rows_per_pass)

    # weight stationary forces kcrs as outer loop
    w_sz = prob.C * prob.R * prob.S
    # assume enough odram size
    k_max = prob.K
    w_map = np.divmod(prob.K, k_max)

    # NOTE(review): true division, so the per-node batch (and n_now below) is
    # a float and may be fractional when N is not a multiple of num_nodes --
    # confirm this is intended.
    n_max = prob.N / num_nodes
    n_map = np.divmod(prob.N / num_nodes, n_max)

    p_max = prob.P
    p_map = np.divmod(prob.P, p_max)

    compute_cycle = 0

    odram_stall = 0

    # column count of the very first mapping and row count of the very last
    # one; both are added to the compute cycles after the loops
    first_map_cols = 0
    first_map_flag = True
    last_map_rows = 0

    # cycles to fetch the very first input / weight tile from odram
    first_ififo_map, first_wfifo_map = 0, 0
    first_ififo_flag, first_wfifo_flag = True, True

    # weight stationary forces kcrs as outer loop
    for w in range(int(w_map[0] + 1)):
        # weight from odram to wsram
        k_wsram = k_max if w < w_map[0] else w_map[1]
        if k_wsram == 0:
            break
        # compute cycles accumulated over this weight tile; the odram and noc
        # transfer time is compared against it to derive the stall
        wsram_cycl_tile_wr = 0

        odram_byte_io = 0
        noc_byte_w = 0
        # each K is mapped to one column
        col_map = np.divmod(k_wsram, num_cols)
        for k in range(int(col_map[0] + 1)):
            # each column works on one output channel, this is spatial mapping
            map_cols = num_cols if k < col_map[0] else col_map[1]
            if map_cols == 0:
                break

            if first_map_flag:
                first_map_cols = map_cols
                first_map_flag = False

            for n in range(int(n_map[0] + 1)):
                n_now = n_max if n < n_map[0] else n_map[1]
                if n_now == 0:
                    break

                for p in range(int(p_map[0] + 1)):
                    p_now = p_max if p < p_map[0] else p_map[1]
                    if p_now == 0:
                        break

                    # becomes output stationary at this point, due to accumulation in fifo
                    npq_map = np.divmod(n_now * p_now * prob.Q, ts_compute)

                    osram_word_tile_wr = n_now * map_cols * p_now * prob.Q
                    osram_word_wr += osram_word_tile_wr

                    # weights are read once when they fit in the two wfifo
                    # buffers, otherwise once per npq tile
                    if ts_w_bd * 2 >= w_sz:
                        weight_passes = 1
                    else:
                        weight_passes = int(npq_map[0] + math.ceil(npq_map[1] / ts_compute))
                    wsram_word_tile_rd = w_sz * map_cols * weight_passes
                    wsram_word_rd += wsram_word_tile_rd

                    for npq in range(int(npq_map[0] + 1)):
                        stream_cycle = (ts_compute if npq < npq_map[0] else npq_map[1]) * prob_cfg["cycle"]
                        if stream_cycle == 0:
                            break

                        # need to decide how many q and how many n
                        # only when prioritizing Q, input can be maximally reused in an im2col manner
                        # for ANN, prioritizing Q leads to dividing by S, i.e.
                        # stream_cycle * C * R * S / S; that reuse is not modeled here
                        isram_word_tile_rd = stream_cycle * prob.C * prob.R * prob.S
                        isram_word_rd += isram_word_tile_rd

                        for crs in range(int(crs_map[0] + 1)):
                            # crs are always continuously scheduled
                            map_rows = rows_per_pass if crs < crs_map[0] else crs_map[1]
                            if map_rows == 0:
                                break

                            last_map_rows = map_rows

                            if first_ififo_flag:
                                first_ififo_map = math.ceil(map_rows * min(stream_cycle, ts_i_bd) * isram_cfg["byte_per_word"] / odram_bw)
                                first_ififo_flag = False

                            if first_wfifo_flag:
                                first_wfifo_map = math.ceil(map_rows * map_cols * wsram_cfg["byte_per_word"] / odram_bw)
                                first_wfifo_flag = False

                            # compute cycle: the cycles for streaming data in
                            # and out (map_rows + map_cols) are not charged
                            # per tile; they are added once after the loops
                            tile_cycle = stream_cycle
                            compute_cycle += tile_cycle

                            wsram_cycl_tile_wr += tile_cycle

                        odram_byte_io += isram_word_tile_rd * isram_cfg["byte_per_word"]

                    # multiply by 2 since the output has to be read out from memory as well
                    odram_byte_io += wsram_word_tile_rd * wsram_cfg["byte_per_word"] + osram_word_tile_wr * osram_cfg["byte_per_word"] * 2
                    noc_byte_w += wsram_word_tile_rd * wsram_cfg["byte_per_word"]

        # stall is the part of the odram and noc transfer time that exceeds
        # the compute cycles of this weight tile, floored at zero
        odram_stall += math.ceil(max(odram_byte_io / odram_bw + noc_byte_w * (num_nodes - 1) / noc_bw / num_nodes - wsram_cycl_tile_wr, 0))

    # count for initial streaming latency
    compute_cycle += first_map_cols + last_map_rows
    compute_cycle += first_ififo_map + first_wfifo_map

    utilization = prob.flops * prob_cfg["cycle"] / (num_rows * num_cols * compute_cycle * 2 + num_cols * compute_cycle) * 100. / num_nodes

    compu_result = OrderedDict()
    compu_result["cycle"] = int(compute_cycle)
    compu_result["utilization"] = float(utilization)
    compu_result["flops"] = float(prob.flops)

    noc_result = copy.deepcopy(compu_result)
    noc_result["cycle"] = int(wsram_word_rd * wsram_cfg["byte_per_word"] * (num_nodes - 1) / noc_bw / num_nodes)

    ififo_result = copy.deepcopy(compu_result)
    wfifo_result = copy.deepcopy(compu_result)
    ofifo_result = copy.deepcopy(compu_result)
    # independent copy, so that mutating one of "ofifo" / "oaccu" in the
    # returned dict does not silently change the other
    oaccu_result = copy.deepcopy(ofifo_result)

    # the literal 1 entries below are constants, not derived from the traffic
    # counters of this function
    isram_result = OrderedDict()
    isram_result["access stall"]    = OrderedDict({ "rd": int(0),
                                                    "wr": int(0)})
    isram_result["byte total"]      = OrderedDict({ "rd": int(math.ceil(isram_word_rd * isram_cfg["byte_per_word"] * num_nodes)),
                                                    "wr": int(1)})
    isram_result["access total"]    = OrderedDict({ "rd": int(1),
                                                    "wr": int(1)})

    wsram_result = OrderedDict()
    wsram_result["access stall"]    = OrderedDict({ "rd": int(0),
                                                    "wr": int(0)})
    wsram_result["byte total"]      = OrderedDict({ "rd": int(math.ceil(wsram_word_rd * wsram_cfg["byte_per_word"])),
                                                    "wr": int(1)})
    wsram_result["access total"]    = OrderedDict({ "rd": int(1),
                                                    "wr": int(1)})

    osram_result = OrderedDict()
    osram_result["access stall"]    = OrderedDict({ "rd": int(0),
                                                    "wr": int(0)})
    osram_result["byte total"]      = OrderedDict({ "rd": int(math.ceil(osram_word_wr * osram_cfg["byte_per_word"] * num_nodes)),
                                                    "wr": int(math.ceil(osram_word_wr * osram_cfg["byte_per_word"] * num_nodes))})
    osram_result["access total"]    = OrderedDict({ "rd": int(1),
                                                    "wr": int(1)})

    odram_result = OrderedDict()
    odram_byte_rd = wsram_word_rd * wsram_cfg["byte_per_word"] + isram_word_rd * isram_cfg["byte_per_word"] * num_nodes
    odram_byte_wr = osram_word_wr * osram_cfg["byte_per_word"] * num_nodes * 2
    # the stall is split between read and write in proportion to their bytes
    odram_result["access stall"]    = OrderedDict({ "rd": 0 if odram_byte_rd == 0 else int(math.ceil(odram_stall * odram_byte_rd / (odram_byte_rd + odram_byte_wr))),
                                                    "wr": 0 if odram_byte_wr == 0 else int(math.ceil(odram_stall * odram_byte_wr / (odram_byte_rd + odram_byte_wr)))})
    odram_result["byte total"]      = OrderedDict({ "rd": int(odram_byte_rd),
                                                    "wr": int(odram_byte_wr)})
    # NOTE(review): a byte count is divided by "bits_per_chip" here without
    # the "* 8" used for "page total" below -- confirm the intended units.
    odram_result["access total"]    = OrderedDict({ "rd": int(math.ceil(odram_byte_rd / odram_cfg["bits_per_chip"])),
                                                    "wr": int(math.ceil(odram_byte_wr / odram_cfg["bits_per_chip"]))})
    odram_result["page total"]      = OrderedDict({ "rd": int(math.ceil(odram_byte_rd * 8 / odram_cfg["bits_per_page"])),
                                                    "wr": int(math.ceil(odram_byte_wr * 8 / odram_cfg["bits_per_page"]))})

    output["compu"] = compu_result
    output["noc"] = noc_result
    output["ififo"] = ififo_result
    output["wfifo"] = wfifo_result
    output["ofifo"] = ofifo_result
    output["oaccu"] = oaccu_result
    output["isram"] = isram_result
    output["wsram"] = wsram_result
    output["osram"] = osram_result
    output["odram"] = odram_result

    return output

