import subprocess, os, time
from collections import OrderedDict
from utils import bcolors


def cacti_run(
    tech_node_nm: str,
    cnfg_dict: OrderedDict,
    origin_cnfg_file=None,
    target_cnfg_file=None,
    result_file=None,
    cacti_dir="./implement/cacti7/"
):
    """
    Run CACTI on an SRAM configuration and write its report to ``result_file``.

    Works for SRAMs whose size is either calculated to match the bandwidth or
    pre-specified. A CACTI config is generated at ``target_cnfg_file``: the
    size / block size / technology / bank count / bus width lines derived from
    ``cnfg_dict`` come first, followed verbatim by every line of
    ``origin_cnfg_file``.

    Args:
        tech_node_nm: technology node in nm (anything ``float()`` accepts,
            e.g. ``45`` or ``"45"``); written to the config in micrometres.
        cnfg_dict: must provide ``"byte_tot_sram"``, ``"byte_per_brow"`` and
            ``"bank_per_sram"``.
        origin_cnfg_file: template config whose lines are appended.
        target_cnfg_file: path of the generated config.
        result_file: path that CACTI's stdout is written to.
        cacti_dir: directory holding the ``cacti`` executable. If the
            executable is missing, it is built there with ``make all``.
            CACTI runs with this directory as its working directory.
    """
    import shutil
    import tempfile

    tech_node_um = float(tech_node_nm) / 1000

    with open(origin_cnfg_file, "r") as original, open(target_cnfg_file, "w") as target:
        target.write("-size (bytes) " + str(cnfg_dict["byte_tot_sram"]) + "\n")
        target.write("-block size (bytes) " + str(cnfg_dict["byte_per_brow"]) + "\n")
        target.write("-technology (u) " + str(tech_node_um) + "\n")
        target.write("-UCA bank count " + str(cnfg_dict["bank_per_sram"]) + "\n")
        target.write("-output/input bus width " + str(cnfg_dict["byte_per_brow"] * 8) + "\n")
        target.writelines(original)

    cacti_exe = os.path.join(cacti_dir, "cacti")
    if not os.path.exists(cacti_exe):
        # List form WITHOUT shell=True: with shell=True only "make" would run and
        # "all" would be dropped. call() blocks until make exits, so no sleep is needed.
        subprocess.call(["make", "all"], cwd=cacti_dir)

    # CACTI runs inside cacti_dir, so give it absolute paths. Otherwise relative
    # paths would resolve against cacti_dir instead of the caller's cwd, which is
    # where the config was just written and where the report is read back.
    cnfg_path = os.path.abspath(target_cnfg_file)
    report_path = os.path.abspath(result_file)

    # Run a private copy of the binary (presumably so concurrent runs do not share
    # one executable). mkstemp gives a collision-free name; the old name derived
    # from result_file degenerated to "cacti_" for relative paths such as "./x.rpt".
    fd, run_exe = tempfile.mkstemp(prefix="cacti_", dir=cacti_dir)
    os.close(fd)
    try:
        shutil.copy(cacti_exe, run_exe)  # copy() also copies the executable bit
        with open(report_path, "w") as report:
            subprocess.call(
                [os.path.abspath(run_exe), "-infile", cnfg_path],
                stdout=report,
                cwd=cacti_dir,
            )
    finally:
        os.remove(run_exe)


def parse_report_line_count(
    report=None,
    index=64
):
    """
    Tell whether a CACTI report looks truncated or broken.

    Args:
        report: path to the CACTI report.
        index: a report with ``index`` lines or fewer is considered broken
            (``parse_report`` reads values up to line 64).

    Returns:
        True if the report is missing or has at most ``index`` lines,
        False otherwise.
    """
    try:
        with open(report, "r") as cacti_out:
            line_count = sum(1 for _ in cacti_out)
    except FileNotFoundError:
        # A report that was never written is as unusable as a truncated one.
        return True
    return line_count <= index


def parse_report(
    report=None
):
    """
    Extract block size, frequency, energy, leakage and area from a CACTI report.

    Values are read from fixed 1-based line numbers; each value is the text
    after the last ':' on its line.

    Args:
        report: path to the CACTI report.

    Returns:
        (block_sz_bytes, max_freq [MHz], energy_per_access_rd [nJ],
         energy_per_access_wr [nJ], leakage_power [mW], total_area [mm^2])

    Raises:
        AssertionError: if the report has 64 lines or fewer (CACTI failure),
            or if line 12 does not name a "Scratch RAM". Explicit raises are
            used so the checks survive ``python -O``.
    """
    with open(report, "r") as cacti_out:
        lines = cacti_out.readlines()

    # Check for truncation before parsing, so a failed CACTI run produces this
    # message instead of a ValueError from float() or an unbound variable.
    # This is the same criterion as parse_report_line_count(report, 64).
    if len(lines) <= 64:
        raise AssertionError(
            bcolors.FAIL + "Check " + report + " for sram cacti failure." + bcolors.ENDC
        )

    def _field(line_no):
        # Text after the last ':' on the given 1-based line.
        return lines[line_no - 1].strip().split(":")[-1].strip()

    ram_type = _field(12)
    if ram_type != "Scratch RAM":
        raise AssertionError("Invalid SRAM type.")

    # unit: byte
    block_sz_bytes = float(_field(2))
    # number of banks; multiplies the per-bank leakage below
    bank = float(_field(50))
    # unit: ns
    cycle_time = float(_field(59))
    # unit: nJ
    dynamic_energy_rd = float(_field(60))
    # unit: nJ
    dynamic_energy_wr = float(_field(61))
    # unit: mW
    leakage_power_bank = float(_field(62))
    # unit: mW
    gate_leakage_power_bank = float(_field(63))
    # "height x width"; unit: mm^2 for the product
    dims = _field(64).split("x")
    area = float(dims[0].strip()) * float(dims[1].strip())

    # MHz (1 / ns = GHz, * 1000 -> MHz)
    max_freq = 1 / cycle_time * 1000
    # mW
    leakage_power = (leakage_power_bank + gate_leakage_power_bank) * bank
    return block_sz_bytes, max_freq, dynamic_energy_rd, dynamic_energy_wr, leakage_power, area


def query(name: str, technology: float, frequency: float, cnfg_sram: OrderedDict, odir: str):
    """
    Return CACTI SRAM estimates for ``cnfg_sram``, running CACTI only when needed.

    The generated config and report are cached in ``odir`` as
    ``<name>.sram.cacti.cfg`` and ``<name>.sram.cacti.rpt``. A
    ``<name>.sram.cacti.cfg.flag`` file marks that CACTI has been run for them.
    CACTI is (re)run when the flag is missing, the report is missing, or the
    report has 64 lines or fewer.

    Args:
        name: prefix for the cached files.
        technology: technology node in nm, forwarded to ``cacti_run``.
        frequency: not used by this function; kept for interface compatibility.
        cnfg_sram: SRAM configuration forwarded to ``cacti_run``.
        odir: directory for the generated config, report and flag files.

    Returns:
        (energy_per_access_rd [nJ], energy_per_access_wr [nJ],
         leakage_power [mW], total_area [mm^2]) as parsed by ``parse_report``.
    """
    origin_cnfg_file = os.path.dirname(os.path.abspath(__file__)) + "/sram.cfg"
    target_cnfg_file = odir + "/" + name + ".sram.cacti.cfg"
    cacti_report = odir + "/" + name + ".sram.cacti.rpt"
    target_cnfg_file_flag = target_cnfg_file + ".flag"

    needs_run = (
        not os.path.exists(target_cnfg_file_flag)
        or not os.path.exists(cacti_report)
        or parse_report_line_count(cacti_report, 64)
    )
    if needs_run:
        # Drop stale artifacts before rerunning.
        for stale in (cacti_report, target_cnfg_file_flag):
            if os.path.exists(stale):
                os.remove(stale)
        cacti_run(technology, cnfg_sram, origin_cnfg_file, target_cnfg_file, cacti_report)
        # Recreate the flag after EVERY run. Previously a rerun for a broken
        # report deleted the flag without restoring it, so every later call
        # reran CACTI.
        with open(target_cnfg_file_flag, "w") as flag:
            flag.write(target_cnfg_file_flag)

    _, _, energy_per_access_rd, energy_per_access_wr, leakage_power, total_area = parse_report(cacti_report)
    return energy_per_access_rd, energy_per_access_wr, leakage_power, total_area


if __name__ == "__main__":
    # Running this file directly does nothing; the functions above are used via import.
    pass

