from collections import OrderedDict
import numpy as np
from workload import conv
import math
from utils import bcolors, round2power
import copy


def dataflow(
    arch_cfg: OrderedDict,
    prob_cfg: OrderedDict
) -> OrderedDict:
    """
    Analytical model of the multi-node cris architecture from
    "Computation Reuse in DNNs by Exploiting Input Similarity".

    The compute kernel is a 1-d vector array, and the vector dimension maps K.
    At every cycle, one input is multiplied with vector size weights.
    Input/output SRAMs are double buffered, weight eDRAM is not double buffered.
    The noc is a ring, and output channels mapped to each mac.

    Dataflow:
    Output stationary at vector compute kernel level.
    Weight stationary at on-chip memory level.

    Tensor format:
    Input format: XYCN or NXYC
    Weight format: RSCK
    Output format: PQKN or NPQK

    The function works in two phases:
    1. A tile search over (K, C, N) candidates (K and N are halved, C is
       doubled at every step) that keeps the candidate with the highest
       weighted utilization whose weight/output tiles fit in half of the
       corresponding SRAM.
    2. A walk over the K / NPQ / C tiles that accumulates compute cycles,
       SRAM word traffic, NoC bytes and write/read stalls.

    NOTE(review): the tile search sizes the weight tile against half of
    wsram, although the description above says it is not double buffered.

    Args:
        arch_cfg: architecture description. Keys read here: "frequency",
            "compu" ("num_instances"), "noc" ("num_instances" and
            ["instance"]["noc"]["bandwidth"]), "isram" / "wsram" / "osram"
            ("byte_per_word", "byte_tot_sram", "byte_per_brow") and "odram"
            ("bandwidth", "bits_per_chip", "bits_per_page").
            Bandwidths are scaled by 10**9 / (frequency * 10**6), i.e.
            presumably GB/s and MHz -- TODO confirm against the config files.
        prob_cfg: problem description handed to ``conv``; the "cycle" entry
            is additionally read here to scale SRAM bandwidth and utilization.

    Returns:
        OrderedDict with the entries "compu", "noc", "reuse", "isram",
        "wsram", "osram" and "odram". The compute-like entries carry
        "cycle", "utilization" and "flops"; the memory entries carry
        "access stall", "byte total" and "access total" (plus "page total"
        for "odram"), each split into "rd" and "wr".

    Raises:
        AssertionError: if the tile search finds no feasible tiling.
    """

    # inter-image with weight stationary

    compu_cfg = arch_cfg["compu"]

    num_rows = compu_cfg["num_instances"][0]
    num_cols = compu_cfg["num_instances"][1]

    # configuration of NoC
    # the rows and cols define the repetition of single node carat
    # it also determines the split of input, weight and output
    noc_cfg = arch_cfg["noc"]
    # noc bandwidth converted to byte per cycle
    noc_bw = noc_cfg["instance"]["noc"]["bandwidth"] * 10**9 / (arch_cfg["frequency"] * 10**6)
    noc_rows = noc_cfg["num_instances"][0]
    noc_cols = noc_cfg["num_instances"][1]
    num_nodes = noc_rows * noc_cols

    # output shall include the ordered output of all components
    output = OrderedDict()

    # initialize the problem
    prob = conv("", prob_cfg)

    # fraction of inputs that still has to be computed: the first image is
    # always computed, the remaining N - 1 only for their non-similar part
    similarity = (1 + (prob.N - 1) * (1 - prob.similarity)) / prob.N

    # analyze tile size based on buffer size

    osram_cfg = arch_cfg["osram"]
    odram_cfg = arch_cfg["odram"]

    # dram bandwidth in byte per cycle
    odram_bw = arch_cfg["odram"]["bandwidth"] * 10**9 / (arch_cfg["frequency"] * 10**6)

    # osram traffic total, output is stationary in the accumulation fifo, so transferred once
    # with odram, compute
    osram_word_rd, osram_word_wr = 0, 0
    # osram_bw = osram_cfg["byte_per_brow"] * osram_cfg["bank_per_sram"]
    osram_bw = osram_cfg["byte_per_word"] * num_cols / prob_cfg["cycle"]

    # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # #
    # k is mapped to width, i.e., num_cols
    # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # #

    # in this schedule, we assume isram and wsram are identical, same data width, same total size
    isram_cfg = arch_cfg["isram"]
    wsram_cfg = arch_cfg["wsram"]

    # sram traffic
    # with compute, odram
    isram_word_rd, isram_word_wr = 0, 0
    # isram_bw = isram_cfg["byte_per_brow"] * isram_cfg["bank_per_sram"]
    isram_bw = isram_cfg["byte_per_word"] * num_rows / prob_cfg["cycle"]

    # wsram traffic total
    # with odram, compute
    wsram_word_rd, wsram_word_wr = 0, 0
    # wsram_bw = wsram_cfg["byte_per_brow"] * wsram_cfg["bank_per_sram"]
    wsram_bw = wsram_cfg["byte_per_word"] * num_cols / prob_cfg["cycle"]

    compute_cycle = 0

    noc_bytes = 0

    # assume only sram stall, i.e., fifos do not stall due to double buffer
    isram_stall_rd, isram_stall_wr = 0, 0
    wsram_stall_rd, wsram_stall_wr = 0, 0
    osram_stall_rd, osram_stall_wr = 0, 0
    odram_stall = 0

    # search k_max
    # index_ratio is the share of isram and osram set aside for the quantized index.
    # NOTE(review): the original note said 10%, but the value is 0.0, so
    # nothing is reserved at the moment.
    index_ratio = 0.0
    # usable words per sram: half of the capacity is available per tile
    isram_word = isram_cfg["byte_tot_sram"] / isram_cfg["byte_per_word"] / 2
    isram_word = isram_word / (1 + index_ratio)

    wsram_word = wsram_cfg["byte_tot_sram"] / wsram_cfg["byte_per_word"] / 2

    osram_word = osram_cfg["byte_tot_sram"] / osram_cfg["byte_per_word"] / 2
    osram_word = osram_word / (1 + index_ratio / (osram_cfg["byte_per_word"] / isram_cfg["byte_per_word"]))

    total_util_max = 0

    # fall back to the untiled problem until the search finds something better
    n_max_util = prob.N
    k_max_util = prob.K
    c_max_util = prob.C
    p_max_util = prob.P
    q_max_util = prob.Q
    x_max_util = prob.X
    y_max_util = prob.Y

    success_flag = False

    # this tile search prioritizes the compute utilization, and it also checks memory utilization
    # prefer to large K and N, and small C
    max_k_per_node = max(math.floor(prob.K / num_nodes), 1)
    k_appr = round2power(max_k_per_node, mode='ceil')
    k_loops = int(math.ceil(math.log(max_k_per_node, 2))) + 1
    for k_idx in range(k_loops):
        # largest C whose RS*C*K weight tile fits in half of wsram
        c_appr_max = math.floor(wsram_cfg["byte_tot_sram"] / wsram_cfg["byte_per_word"] / 2 / k_appr / (prob.R * prob.S))
        c_appr_max = round2power(min(c_appr_max, prob.C))
        c_appr = 1
        c_loops = int(math.ceil(math.log(c_appr_max / c_appr, 2))) + 1
        for c_idx in range(c_loops):
            n_appr = round2power(prob.N)
            n_loops = int(math.ceil(math.log(prob.N, 2))) + 1
            for n_idx in range(n_loops):
                # NOTE(review): both `continue` statements below skip the
                # `n_appr = n_appr / 2` update at the end of this loop body,
                # so once a candidate is rejected here the remaining
                # iterations re-test the same n_appr. Left unchanged because
                # the schedule below always uses n_this = prob.N -- confirm
                # the intended behavior before changing it.
                pq_appr = osram_cfg["byte_tot_sram"] / osram_cfg["byte_per_word"] / 2 / n_appr / k_appr
                if pq_appr < 1:
                    continue
                pq_appr = round2power(min(pq_appr, prob.P * prob.Q))
                # assume pq is square
                p_appr = round2power(min(math.sqrt(pq_appr), prob.P))
                q_appr = round2power(min(pq_appr / p_appr, prob.Q))

                xy_appr = isram_cfg["byte_tot_sram"] / isram_cfg["byte_per_word"] / 2 / n_appr / c_appr
                if xy_appr < 1:
                    continue
                xy_appr = min(xy_appr, prob.X * prob.Y)

                # array utilization
                # k map to col
                k2col_col_ratio = (max_k_per_node / 2**k_idx) / num_cols
                k2col_col_util = k2col_col_ratio / math.ceil(k2col_col_ratio)

                # # c to col
                # k2col_row_ratio = (prob.C / 2**c_idx) / num_rows
                # k2col_row_util = k2col_row_ratio / math.ceil(k2col_row_ratio)
                # array_util = k2col_col_util * k2col_row_util
                array_util = k2col_col_util

                # memory utilization
                nxyc_appr = n_appr * xy_appr * c_appr
                rsck_appr = prob.R * prob.S * c_appr * k_appr
                npqk_appr = n_appr * pq_appr * k_appr

                isram_util = nxyc_appr / isram_word
                wsram_util = rsck_appr / wsram_word
                osram_util = npqk_appr / osram_word

                # isram should hold at least one XY
                isram_util_bad = (isram_util == 0)
                wsram_util_bad = (wsram_util > 1) or (wsram_util == 0)
                osram_util_bad = (osram_util > 1) or (osram_util == 0)
                # a candidate is rejected as soon as any sram is unusable
                xsram_util_bad = isram_util_bad or wsram_util_bad or osram_util_bad
                xsram_util = (isram_util + wsram_util + osram_util) / 3

                # weighted total utilization: the array utilization dominates,
                # the memory utilization only breaks ties
                total_util = array_util + xsram_util * 0.05

                if (total_util > total_util_max) and (not xsram_util_bad):
                    success_flag = True
                    n_max_util = n_appr
                    k_max_util = k_appr
                    c_max_util = c_appr
                    p_max_util = p_appr
                    q_max_util = q_appr
                    x_max_util = (p_appr - 1) * prob.Hstride + prob.R
                    y_max_util = (q_appr - 1) * prob.Wstride + prob.S
                    total_util_max = total_util

                n_appr = n_appr / 2
            c_appr = c_appr * 2
        k_appr = k_appr / 2

    # raised explicitly (rather than via `assert`) so that the check is not
    # stripped under `python -O`; the exception type is unchanged
    if not success_flag:
        raise AssertionError("tile search found no tiling that fits the on-chip SRAMs")

    # nodes beyond the number of output channels stay idle
    if prob.K < num_nodes:
        num_nodes_active = prob.K
    else:
        num_nodes_active = num_nodes

    k_max = k_max_util
    # k_map[0] full K tiles per node, followed by one remainder tile of k_map[1]
    k_map = np.divmod(max_k_per_node, k_max)
    for k in range(int(k_map[0] + 1)):
        k_this = k_max if k < k_map[0] else k_map[1]
        if k_this == 0:
            break

        # bandwidth derating when the K tile is narrower than the array width;
        # depends on k_this only, so it is computed once per K tile
        under_util_wosram_bw = math.ceil(num_cols / k_this)

        # NOTE(review): the schedule always processes the full batch; the
        # n_max_util found by the tile search is not used -- confirm intent.
        n_this = prob.N
        p_max = p_max_util
        q_max = q_max_util
        npq_max = n_this * p_max * q_max
        npq_map = np.divmod(prob.N * prob.P * prob.Q, npq_max)
        for npq in range(int(npq_map[0] + 1)):
            npq_this = npq_max if npq < npq_map[0] else npq_map[1]
            if npq_this == 0:
                break

            pq_this = min(npq_this / n_this, prob.P * prob.Q)
            # assume pq is square
            p_this = min(round2power(math.sqrt(pq_this)), prob.P)
            q_this = min(pq_this / p_this, prob.Q)
            # output tile size in words
            osram_this = n_this * k_this * pq_this

            # input extent needed to produce a p_this x q_this output tile
            x_this = (p_this - 1) * prob.Hstride + prob.R
            y_this = (q_this - 1) * prob.Wstride + prob.S

            # per output tile: input + weight bytes fetched from dram and the
            # compute cycles spent over all its C tiles
            odram_byte_iw = 0
            compute_cycle_c_delta = 0

            c_max = c_max_util
            c_map = np.divmod(prob.C, c_max)
            for c in range(int(c_map[0] + 1)):
                c_this = c_max if c < c_map[0] else c_map[1]
                if c_this == 0:
                    break
                # every tile here is a tile that can fit in memory

                # get compute cycle
                # num_rows is 1, ignored
                under_util_isram_bw = math.ceil(num_rows / c_this)
                vector_map = math.ceil(c_this / num_rows) * math.ceil(k_this / max(round2power(num_cols / (prob.R * prob.S)), 1))

                # compute pipeline according to memory dependence
                #    QD	wR	oR	MAC	oW
                # 1. QD: get quantization and difference
                # 2. wR: read weight from wsram, and write new index in osram (ignored by coalescing)
                # 3. oR: read output from osram
                # 4. MAC: MAC done
                # 5. oW: write output to osram

                # to avoid conflicts, true dual-port SRAM is needed to ensure 1 mac per cycle without any stall.
                # if not true dual-port SRAM, stall will increase the cycle count per mac to 2.
                # check out cris_baseline_pipeline.xlsx for details

                # in this model, we assume 1 mac per cycle
                # also there exists an overhead of reading input, which is 1/num_cols.
                # the final cycle per mac is 1+1/num_cols
                cycle_overhead = 1 + 1/num_cols

                # compute cycles need to consider similarity
                num_inputs = npq_this * vector_map
                # the extra col_fold is due to conflict in accessing oSRAM from storing output and reading index

                compute_cycle_c = num_inputs * similarity * cycle_overhead
                compute_cycle += compute_cycle_c
                compute_cycle_c_delta += compute_cycle_c

                # get memory access during compute
                # for each input element
                # two oSRAM writes: one index and vector outputs
                # two oSRAM read: one index and vector outputs
                # one iSRAM read: one input
                # one wSRAM read: vector weights

                # read and write during compute
                osram_word_wr += num_inputs * (1 + num_cols * similarity)
                osram_word_rd += num_inputs * (1 + num_cols * similarity)

                # last read from oSRAM to DRAM
                osram_word_rd += osram_this

                # for every c tile, weights are written from dram to sram only once, but read multiple times
                wsram_this = c_this * (prob.R * prob.S) * k_this
                wsram_word_wr += wsram_this
                wsram_word_rd += num_inputs * num_cols * similarity

                # no read stall
                # write stall: cycles the tile fill exceeds the compute it overlaps with
                wsram_bytes = wsram_this * wsram_cfg["byte_per_word"]
                wsram_stall_wr += math.ceil(max(wsram_bytes / min((wsram_bw / under_util_wosram_bw), noc_bw) - compute_cycle_c, 0))
                noc_bytes += wsram_bytes * (num_nodes_active / 2 * (num_nodes_active / 2 + 1) / 2) * 2

                # input is written from dram to sram only once, but read multiple times
                isram_this = n_this * x_this * y_this * c_this
                isram_word_wr += isram_this
                isram_word_rd += num_inputs * similarity

                # no read stall
                isram_bytes = isram_this * isram_cfg["byte_per_word"]
                isram_stall_wr += math.ceil(max(isram_bytes / min((isram_bw / under_util_isram_bw), noc_bw) - compute_cycle_c, 0))
                noc_bytes += isram_bytes * num_nodes_active

                odram_byte_iw += wsram_bytes + isram_bytes

            # output tile size in bytes, shared by the sram, noc and dram terms below
            osram_bytes = osram_this * osram_cfg["byte_per_word"]

            # rd stall, odram read osram
            # no wr stall
            osram_stall_rd += math.ceil(max(osram_bytes / min((osram_bw / under_util_wosram_bw), noc_bw) - compute_cycle_c_delta, 0))
            noc_bytes += osram_bytes * num_nodes_active

            # dram moves the output tile plus the input/weight tiles; all terms
            # are in bytes because odram_bw is in byte per cycle
            odram_stall += math.ceil(max((osram_bytes + odram_byte_iw) / odram_bw - compute_cycle_c_delta, 0))

    # percentage of the peak throughput of all nodes (2 flops per mac)
    utilization = prob.flops * prob_cfg["cycle"] / (num_rows * num_cols * compute_cycle * 2 * num_nodes) * 100.

    compu_result = OrderedDict()
    compu_result["cycle"] = int(compute_cycle)
    compu_result["utilization"] = float(utilization)
    compu_result["flops"] = float(prob.flops)

    # noc and reuse reuse the compute record; only the noc cycle is replaced
    noc_result = copy.deepcopy(compu_result)
    noc_result["cycle"] = int(noc_bytes / noc_bw / num_nodes)

    reuse_result = copy.deepcopy(compu_result)

    # per-node word counts are scaled by the number of active nodes below
    isram_result = OrderedDict()
    isram_result["access stall"]    = OrderedDict({ "rd": int(isram_stall_rd), 
                                                    "wr": int(isram_stall_wr)})
    isram_result["byte total"]      = OrderedDict({ "rd": int(math.ceil(isram_word_rd * isram_cfg["byte_per_word"] * num_nodes_active)), 
                                                    "wr": int(math.ceil(isram_word_wr * isram_cfg["byte_per_word"] * num_nodes_active))})
    isram_result["access total"]    = OrderedDict({ "rd": int(math.ceil(isram_word_rd * isram_cfg["byte_per_word"] * num_nodes_active / isram_cfg["byte_per_brow"])), 
                                                    "wr": int(math.ceil(isram_word_wr * isram_cfg["byte_per_word"] * num_nodes_active / isram_cfg["byte_per_brow"]))})

    wsram_result = OrderedDict()
    wsram_result["access stall"]    = OrderedDict({ "rd": int(wsram_stall_rd), 
                                                    "wr": int(wsram_stall_wr)})
    wsram_result["byte total"]      = OrderedDict({ "rd": int(math.ceil(wsram_word_rd * wsram_cfg["byte_per_word"] * num_nodes_active)), 
                                                    "wr": int(math.ceil(wsram_word_wr * wsram_cfg["byte_per_word"] * num_nodes_active))})
    wsram_result["access total"]    = OrderedDict({ "rd": int(math.ceil(wsram_word_rd * wsram_cfg["byte_per_word"] * num_nodes_active / wsram_cfg["byte_per_brow"])), 
                                                    "wr": int(math.ceil(wsram_word_wr * wsram_cfg["byte_per_word"] * num_nodes_active / wsram_cfg["byte_per_brow"]))})

    osram_result = OrderedDict()
    osram_result["access stall"]    = OrderedDict({ "rd": int(osram_stall_rd), 
                                                    "wr": int(osram_stall_wr)})
    osram_result["byte total"]      = OrderedDict({ "rd": int(math.ceil(osram_word_rd * osram_cfg["byte_per_word"] * num_nodes_active)), 
                                                    "wr": int(math.ceil(osram_word_wr * osram_cfg["byte_per_word"] * num_nodes_active))})
    osram_result["access total"]    = OrderedDict({ "rd": int(math.ceil(osram_word_rd * osram_cfg["byte_per_word"] * num_nodes_active / osram_cfg["byte_per_brow"])), 
                                                    "wr": int(math.ceil(osram_word_wr * osram_cfg["byte_per_word"] * num_nodes_active / osram_cfg["byte_per_brow"]))})

    odram_result = OrderedDict()
    # dram reads feed the isram/wsram writes; dram writes are the full output tensor
    odram_byte_rd = isram_word_wr * isram_cfg["byte_per_word"] * num_nodes_active + wsram_word_wr * wsram_cfg["byte_per_word"] * num_nodes_active
    odram_byte_wr = prob.N * prob.P * prob.Q * prob.K * osram_cfg["byte_per_word"]
    # the single dram stall count is split between rd and wr by byte share
    odram_result["access stall"]    = OrderedDict({ "rd": 0 if odram_byte_rd == 0 else int(math.ceil(odram_stall * odram_byte_rd / (odram_byte_rd + odram_byte_wr))),
                                                    "wr": 0 if odram_byte_wr == 0 else int(math.ceil(odram_stall * odram_byte_wr / (odram_byte_rd + odram_byte_wr)))})
    odram_result["byte total"]      = OrderedDict({ "rd": int(odram_byte_rd), 
                                                    "wr": int(odram_byte_wr)})
    # NOTE(review): bytes are divided by "bits_per_chip" without the `* 8`
    # used for "page total" below -- confirm the intended unit.
    odram_result["access total"]    = OrderedDict({ "rd": int(math.ceil(odram_byte_rd / odram_cfg["bits_per_chip"])), 
                                                    "wr": int(math.ceil(odram_byte_wr / odram_cfg["bits_per_chip"]))})
    odram_result["page total"]      = OrderedDict({ "rd": int(math.ceil(odram_byte_rd * 8 / odram_cfg["bits_per_page"])), 
                                                    "wr": int(math.ceil(odram_byte_wr * 8 / odram_cfg["bits_per_page"]))})

    output["compu"] = compu_result
    output["noc"]   = noc_result
    output["reuse"] = reuse_result
    output["isram"] = isram_result
    output["wsram"] = wsram_result
    output["osram"] = osram_result
    output["odram"] = odram_result

    return output

