import subprocess, os, time
from collections import OrderedDict
from utils import bcolors


def cacti_run(
    tech_node_nm: str,
    cnfg_dict: OrderedDict,
    origin_cnfg_file=None,
    target_cnfg_file=None,
    result_file=None
):
    """
    Run CACTI on a DRAM configuration and write its stdout report to ``result_file``.

    Works for DRAM whose size is either calculated to match the bandwidth or
    pre-specified. The generated configuration ``target_cnfg_file`` consists of
    DRAM-specific override lines taken from ``cnfg_dict``, followed by the full
    content of ``origin_cnfg_file``.

    Args:
        tech_node_nm: technology node in nm (any value ``int()`` accepts); it is
            written to the config in um.
        cnfg_dict: must provide "byte_tot_dram", "bits_per_page",
            "bank_per_chip" and "prefetch".
        origin_cnfg_file: base CACTI configuration appended after the overrides.
        target_cnfg_file: path of the generated configuration file.
        result_file: path that receives CACTI's stdout report.

    Relative paths are interpreted against the caller's current directory,
    even though CACTI itself is executed inside ./implement/cacti7/.
    """
    cacti_dir = "./implement/cacti7/"
    tech_node_um = float(int(tech_node_nm) / 1000)

    with open(origin_cnfg_file, "r") as original, open(target_cnfg_file, "w") as target:
        target.write("-size (bytes) " + str(cnfg_dict["byte_tot_dram"]) + "\n")
        target.write("-page size (bits) " + str(cnfg_dict["bits_per_page"]) + "\n")
        target.write("-technology (u) " + str(tech_node_um) + "\n")
        target.write("-UCA bank count " + str(cnfg_dict["bank_per_chip"]) + "\n")
        target.write("-internal prefetch width " + str(cnfg_dict["prefetch"]) + "\n")
        for entry in original:
            target.write(entry)

    if not os.path.exists(os.path.join(cacti_dir, "cacti")):
        # List form WITHOUT shell=True: with shell=True on POSIX only the first
        # list item is the command, so "all" was silently dropped.
        # subprocess.call blocks until make exits, so no sleep is needed afterwards.
        subprocess.call(["make", "all"], cwd=cacti_dir)

    # CACTI runs with cwd=cacti_dir. Resolve the paths against the caller's cwd
    # first, so CACTI reads the config written above. The report then lands
    # where callers (which open it from their own cwd) expect it.
    cfg_path = os.path.abspath(target_cnfg_file)
    rpt_path = os.path.abspath(result_file)

    # Run a uniquely named copy of the binary for each report. Presumably this
    # lets concurrent runs coexist; confirm with the callers. Deriving the name
    # from the full absolute path (minus extension) avoids the empty/colliding
    # names that "./..." or dotted directory names produced before.
    binary_name = "cacti_" + os.path.splitext(rpt_path)[0].replace(os.sep, "_")
    subprocess.call(["cp", "./cacti", "./" + binary_name], cwd=cacti_dir)
    try:
        with open(rpt_path, "w") as rpt:
            subprocess.call(["./" + binary_name, "-infile", cfg_path], stdout=rpt, cwd=cacti_dir)
    finally:
        # Always drop the temporary copy, even if the run failed.
        subprocess.call(["rm", "-rf", "./" + binary_name], cwd=cacti_dir)


def parse_report_line_count(
    report=None,
    index=69
):
    """
    Return True if the CACTI report at ``report`` looks broken (truncated).

    A report counts as broken when it has at most ``index`` lines. The default
    of 69 matches the last (1-based) line that ``parse_report`` reads. A report
    file that does not exist is also treated as broken, so callers can simply
    regenerate it.

    Args:
        report: path to the CACTI report.
        index: minimum number of lines the report must exceed.

    Returns:
        bool: True if the report is missing or has ``<= index`` lines.
    """
    try:
        with open(report, "r") as cacti_out:
            line_count = sum(1 for _ in cacti_out)
    except FileNotFoundError:
        return True
    return line_count <= index


def parse_report(
    report=None
):
    """
    Parse a CACTI DRAM report and return the figures used for memory power/energy estimation.

    Values are read from fixed 1-based line numbers of the report. For each line,
    the text after the last ':' is parsed as a number.

    Args:
        report: path to the CACTI report.

    Returns:
        tuple: (max_freq [MHz], activate_energy [nJ], energy_rd [nJ],
        energy_wr [nJ], precharge_energy [nJ], leakage_power_closed_page [mW],
        leakage_power_open_page [mW], leakage_power_IO [mW], refresh_power [mW],
        area [mm^2]).

    Raises:
        AssertionError: if the report has 69 lines or fewer (CACTI failure), or
            if line 12 does not name a "Scratch RAM". Raised explicitly so the
            checks survive ``python -O``.
    """
    with open(report, "r") as cacti_out:
        lines = cacti_out.readlines()

    # Check truncation before parsing, so a short or garbled report produces the
    # helpful message rather than an IndexError/ValueError below.
    if len(lines) <= 69:
        raise AssertionError(
            bcolors.FAIL + "Check " + report + " for dram cacti failure." + bcolors.ENDC
        )

    def value_at(line_no):
        # Numeric value after the last ':' on the given 1-based line.
        return float(lines[line_no - 1].strip().split(":")[-1].strip())

    ram_type = lines[11].strip().split(":")[-1].strip()
    if ram_type != "Scratch RAM":
        raise AssertionError("Invalid DRAM type.")

    cycle_time = value_at(59)                  # ns
    activate_energy = value_at(61)             # nJ
    energy_rd = value_at(62)                   # nJ
    energy_wr = value_at(63)                   # nJ
    precharge_energy = value_at(64)            # nJ
    leakage_power_closed_page = value_at(65)   # mW
    leakage_power_open_page = value_at(66)     # mW
    leakage_power_IO = value_at(67)            # mW
    refresh_power = value_at(68)               # mW

    # Line 69 holds "height x width"; their product is the area (mm^2).
    dims = lines[68].strip().split(":")[-1].split("x")
    area = float(dims[0].strip()) * float(dims[1].strip())

    # cycle time in ns -> frequency in MHz
    max_freq = 1 / cycle_time * 1000
    return max_freq, activate_energy, energy_rd, energy_wr, precharge_energy, leakage_power_closed_page, leakage_power_open_page, leakage_power_IO, refresh_power, area


def query(name: str, technology: float, frequency: float, cnfg_dram: OrderedDict, odir: str):
    """
    Characterize a DRAM with CACTI, reusing a cached run when available, and return its costs.

    The generated config and report live at ``odir/name.dram.cacti.{cfg,rpt}``.
    A ``.flag`` file next to the config marks a completed run. CACTI is
    (re)run when the flag is missing, or when the cached report is missing or
    truncated.

    Args:
        name: base name for the generated files.
        technology: technology node in nm, forwarded to ``cacti_run``.
        frequency: currently unused; kept for interface compatibility.
        cnfg_dram: DRAM configuration. "module" == "ddr4" selects ddr4.cfg
            (otherwise ddr3.cfg). "bank_per_chip" scales the power terms. An
            "embedded" entry that is present and not False drops IO leakage.
        odir: output directory.

    Returns:
        tuple: (activate + precharge energy [nJ], read energy [nJ],
        write energy [nJ], open-page power [mW], closed-page power [mW],
        area [mm^2]). Each power term is (page leakage + IO leakage if not
        embedded + refresh power) multiplied by ``cnfg_dram["bank_per_chip"]``.
    """
    # Select the base config from the "module" entry. Previously the literal
    # "module" was compared with "ddr4", so ddr4.cfg could never be chosen.
    cnfg_file_postfix = "/ddr4.cfg" if cnfg_dram.get("module") == "ddr4" else "/ddr3.cfg"
    origin_cnfg_file = os.path.dirname(os.path.abspath(__file__)) + cnfg_file_postfix
    target_cnfg_file = odir + "/" + name + ".dram.cacti.cfg"
    cacti_report = odir + "/" + name + ".dram.cacti.rpt"
    target_cnfg_file_flag = target_cnfg_file + ".flag"

    needs_run = (
        not os.path.exists(target_cnfg_file_flag)
        or not os.path.exists(cacti_report)
        or parse_report_line_count(cacti_report, 69)
    )
    if needs_run:
        # Invalidate any stale flag first, so an interrupted run is retried.
        if os.path.exists(target_cnfg_file_flag):
            os.remove(target_cnfg_file_flag)
        cacti_run(technology, cnfg_dram, origin_cnfg_file, target_cnfg_file, cacti_report)
        # Re-create the flag after every successful run (previously it was lost
        # after a rerun of a broken report, forcing yet another rerun next time).
        with open(target_cnfg_file_flag, "w") as f:
            f.write(target_cnfg_file_flag)

    # IO leakage counts unless the DRAM is marked as embedded.
    IO_power_factor = 1. if cnfg_dram.get("embedded", False) is False else 0.

    bank_count = cnfg_dram["bank_per_chip"]
    max_freq, activate_energy, energy_rd, energy_wr, precharge_energy, leakage_power_closed_page, leakage_power_open_page, leakage_power_IO, refresh_power, area = parse_report(cacti_report)
    return activate_energy + precharge_energy, \
            energy_rd, \
            energy_wr, \
            (leakage_power_open_page   + leakage_power_IO * IO_power_factor + refresh_power) * bank_count, \
            (leakage_power_closed_page + leakage_power_IO * IO_power_factor + refresh_power) * bank_count, \
            area


if __name__ == "__main__":
    # Running this file directly does nothing; its functions are meant to be imported.
    pass

