from collections import OrderedDict
from utils import bcolors, yaml_load, yaml_overwrite, arch_support_list, arch_attr_list, memory_list, compute_list
import os, importlib, copy


class ProblemPerf:
    """Performance record of a single problem on a given architecture.

    ``self.disc`` holds one ``"overall"`` entry plus one entry per
    architecture component (every key of the architecture whose lower-cased
    name is not listed in ``arch_attr_list``). Component entries are created
    and filled by the ``<type>_dict`` / ``<type>_dict_record`` callables of
    the ``performance`` module.
    """

    def __init__(self, architecture: OrderedDict) -> None:
        """Create the empty record skeleton for *architecture*.

        Raises:
            AssertionError: if a component's ``"type"`` is not in
                ``arch_support_list``.
        """
        self.arch = architecture
        self.disc = OrderedDict()
        # Template for metrics reported both as "arch" and "impl" values.
        self.attr = OrderedDict({"arch": 0, "impl": 0})

        overall = OrderedDict()
        overall["utilization"] = copy.deepcopy(self.attr)
        overall["bandwidth"] = OrderedDict(
            (memory, copy.deepcopy(self.attr)) for memory in memory_list
        )
        for metric in ("cycle", "runtime", "throughput"):
            overall[metric] = copy.deepcopy(self.attr)
        overall["flops"] = 0
        self.disc["overall"] = overall

        for name in self.arch:
            # Architecture-level attributes are not components.
            if name.lower() in arch_attr_list:
                continue
            kind = self.arch[name]["type"]
            assert kind in arch_support_list, \
                bcolors.FAIL + "Unrecognized architecture type: " + kind + bcolors.ENDC
            perf_module = importlib.import_module("performance")
            make_entry = getattr(perf_module, kind + "_dict")
            self.disc[name] = make_entry()

    def record(self, item: str, result: OrderedDict):
        """Store *result* for component *item* via its type-specific recorder.

        Raises:
            AssertionError: if *item* is not a key of the architecture, or
                its ``"type"`` is not in ``arch_support_list``.
        """
        assert item in self.arch, \
            bcolors.FAIL + "Unrecognized architecture item: " + item + bcolors.ENDC
        kind = self.arch[item]["type"]
        assert kind in arch_support_list, \
            bcolors.FAIL + "Unrecognized architecture type: " + kind + bcolors.ENDC
        perf_module = importlib.import_module("performance")
        recorder = getattr(perf_module, kind + "_dict_record")
        recorder(self.disc[item], result)

    def publish(self):
        """Derive the ``"overall"`` metrics from the recorded component entries.

        The "arch" cycle count is the largest compute-component cycle count
        (at least 1); the "impl" cycle count adds the read and write access
        stalls of every memory component on top of it.
        """
        stall_cycles = 0
        compute_cycles = 1
        byte_count = OrderedDict((memory, 0) for memory in memory_list)

        for name in self.arch:
            if name.lower() in arch_attr_list:
                continue
            kind = self.arch[name]["type"]
            entry = self.disc[name]
            if kind in memory_list:
                stalls = entry["access stall"]
                stall_cycles += (stalls["rd"] + stalls["wr"])
                for memory in memory_list:
                    if kind == memory:
                        traffic = entry["byte total"]
                        byte_count[memory] += (traffic["rd"] + traffic["wr"])
            if kind in compute_list:
                compute_cycles = max(compute_cycles, entry["cycle"])

        overall = self.disc["overall"]
        cycle = overall["cycle"]
        runtime = overall["runtime"]
        throughput = overall["throughput"]

        cycle["arch"] = compute_cycles
        cycle["impl"] = compute_cycles + stall_cycles
        # The frequency is scaled by 10**6 so that runtimes come out in seconds.
        runtime["arch"] = cycle["arch"] / (self.arch["frequency"] * 10**6)
        runtime["impl"] = cycle["impl"] / (self.arch["frequency"] * 10**6)
        # Sample per second.
        throughput["arch"] = 1 / runtime["arch"]
        throughput["impl"] = 1 / runtime["impl"]
        # Byte per second.
        for memory in memory_list:
            bandwidth = overall["bandwidth"][memory]
            bandwidth["arch"] = byte_count[memory] / runtime["arch"]
            bandwidth["impl"] = byte_count[memory] / runtime["impl"]

        busy_cycles = 0
        flops = 0
        for name in self.arch:
            if name.lower() in arch_attr_list:
                continue
            if self.arch[name]["type"] not in compute_list:
                continue
            entry = self.disc[name]
            weight = self.arch[name]["weight"]
            busy_cycles += entry["utilization"] * entry["cycle"] * weight
            flops += entry["flops"] * weight

        # Percent.
        overall["utilization"]["arch"] = busy_cycles / cycle["arch"]
        overall["utilization"]["impl"] = busy_cycles / cycle["impl"]
        overall["flops"] = flops


class WorkloadPerf:
    """Aggregate per-problem performance records into a workload report.

    The full record is checkpointed to ``<path>/performance/workloadperf.yaml``
    and the workload-level summary (technology, frequency and the
    ``"overall"`` entry) to ``<path>/performance/workloadperf.summary.yaml``.
    """

    def __init__(self, architecture: OrderedDict, path: str, required: bool) -> None:
        """Prepare the report files and the empty ``"overall"`` skeleton.

        Args:
            architecture: architecture description; its ``"technology"`` (nm)
                and ``"frequency"`` (MHz) entries are copied into the report.
            path: directory that contains the ``performance`` sub-directory.
            required: when exactly ``True`` both report files must already
                exist; for any other value existing report files are deleted.

        Raises:
            AssertionError: if ``required`` is ``True`` and a report file is
                missing.
        """
        self.path = path
        self.arch = architecture
        self.copy = path + "/performance/workloadperf.yaml"
        self.copy_summary = path + "/performance/workloadperf.summary.yaml"
        if required is True:
            assert os.path.exists(self.copy), bcolors.FAIL + self.copy + " does not exist." + bcolors.ENDC
            assert os.path.exists(self.copy_summary), bcolors.FAIL + self.copy_summary + " does not exist." + bcolors.ENDC
        else:
            # Start from a clean slate: drop reports left by a previous run.
            if os.path.exists(self.copy):
                os.remove(self.copy)
            if os.path.exists(self.copy_summary):
                os.remove(self.copy_summary)
        # Template for metrics reported both as "arch" and "impl" values.
        self.attr = OrderedDict({"arch": 0, "impl": 0})

        self.disc = OrderedDict()
        self.disc["technology"] = architecture["technology"] # in nm
        self.disc["frequency"] = architecture["frequency"] # in MHz
        self.disc["overall"] = OrderedDict({ "flops": 0,
                                        "flops per sec": 0,
                                        "total byte": OrderedDict(),
                                        "flops per byte": OrderedDict(),
                                        "cycle": copy.deepcopy(self.attr),
                                        "utilization": copy.deepcopy(self.attr),
                                        "bandwidth": OrderedDict(),
                                        "runtime": copy.deepcopy(self.attr),
                                        "throughput": copy.deepcopy(self.attr),
                                        })
        for memory in memory_list:
            self.disc["overall"]["total byte"][memory] = 0
            self.disc["overall"]["flops per byte"][memory] = 0
            self.disc["overall"]["bandwidth"][memory] = copy.deepcopy(self.attr)
        # Top-level keys of self.disc that are not per-problem records.
        self.non_prob_keys = ["technology", "frequency", "overall"]

    def record(self, probname, probperfdict: OrderedDict):
        """Add (or replace) the record of problem *probname* and checkpoint it.

        If a checkpoint file exists it is loaded first, so the in-memory
        record is replaced by the checkpointed one before the new problem is
        inserted.
        """
        if os.path.exists(self.copy):
            # load checkpoint workload performance
            self.disc = yaml_load(self.copy)
        self.disc[probname] = probperfdict
        yaml_overwrite(self.copy, self.disc)

    @staticmethod
    def _ratio(numerator, denominator):
        """Return ``numerator / denominator``, or 0 when the denominator is 0."""
        return numerator / denominator if denominator else 0

    def publish(self):
        """Compute the workload-level ``"overall"`` metrics and write both files.

        Totals are rebuilt from the per-problem records on every call, so
        publishing twice, or publishing after ``record`` reloaded an already
        published checkpoint, does not count any problem twice. Ratios whose
        denominator is zero (no problem recorded, or a memory type without
        traffic) are reported as 0 instead of raising ``ZeroDivisionError``.
        """
        overall = self.disc["overall"]

        # Accumulate in locals so stale totals stored in `overall` are ignored.
        cycle_arch = 0
        cycle_impl = 0
        total_util = 0
        total_flops = 0
        total_byte = OrderedDict((memory, 0) for memory in memory_list)

        for prob in self.disc.keys():
            if prob in self.non_prob_keys:
                continue
            prob_overall = self.disc[prob]["overall"]
            cycle_arch += prob_overall["cycle"]["arch"]
            cycle_impl += prob_overall["cycle"]["impl"]
            for memory in memory_list:
                # Recover the byte count from bandwidth * runtime.
                total_byte[memory] += prob_overall["runtime"]["arch"] * prob_overall["bandwidth"][memory]["arch"]
            # Utilization is weighted by the cycles each problem takes.
            total_util += prob_overall["cycle"]["arch"] * prob_overall["utilization"]["arch"]
            total_flops += prob_overall["flops"]

        overall["cycle"]["arch"] = cycle_arch
        overall["cycle"]["impl"] = cycle_impl
        overall["flops"] = total_flops

        # A zero frequency is a configuration error and is left to raise.
        hertz = self.arch["frequency"] * 10**6
        overall["runtime"]["arch"] = cycle_arch / hertz # in second
        overall["runtime"]["impl"] = cycle_impl / hertz # in second

        overall["throughput"]["arch"] = self._ratio(1, overall["runtime"]["arch"]) # sample per second
        overall["throughput"]["impl"] = self._ratio(1, overall["runtime"]["impl"]) # sample per second

        overall["utilization"]["arch"] = self._ratio(total_util, cycle_arch) # percent
        overall["utilization"]["impl"] = self._ratio(total_util, cycle_impl) # percent

        for memory in memory_list:
            overall["total byte"][memory] = total_byte[memory]
            overall["bandwidth"][memory]["arch"] = self._ratio(total_byte[memory], overall["runtime"]["arch"]) # byte per second
            overall["bandwidth"][memory]["impl"] = self._ratio(total_byte[memory], overall["runtime"]["impl"]) # byte per second
            overall["flops per byte"][memory] = self._ratio(total_flops, total_byte[memory]) # in flops per byte

        overall["flops per sec"] = self._ratio(total_flops, overall["runtime"]["impl"]) # in flops per second

        yaml_overwrite(self.copy, self.disc)
        disc_summary = OrderedDict({"technology": self.disc["technology"],
                                    "frequency": self.disc["frequency"],
                                    "overall": overall})
        yaml_overwrite(self.copy_summary, disc_summary)
