from collections import OrderedDict
from utils import bcolors, yaml_load, yaml_overwrite, arch_support_list, arch_attr_list, onchip_list
import os, importlib, copy


class ProblemCost:
    """Collect the cost of one problem, broken down by architecture component.

    Every entry of ``architecture`` whose lower-cased name is not in
    ``arch_attr_list`` is treated as a component. It must declare a
    ``"type"`` contained in ``arch_support_list``.

    The per-type hooks are looked up by name in the ``cost`` module:

    * ``<type>_dict()`` is called once per component in ``__init__``; its
      result is stored in ``self.disc[<component>]``.
    * ``<type>_dict_record(self.disc[item], numbers)`` is called by ``record``.
    * ``<type>_dict_publish(item, self.disc, performance)`` is called by
      ``publish``.
    """

    def __init__(self, architecture: OrderedDict) -> None:
        self.arch = architecture
        self.disc = OrderedDict()
        self.attr = arch_attr_list
        for name in self.arch:
            if name.lower() in self.attr:
                # Listed attribute (matched case-insensitively): not a component.
                continue
            make_entry = self._cost_hook(name, "_dict")
            self.disc[name] = make_entry()

    def _cost_hook(self, item, suffix):
        """Return the function ``cost.<type><suffix>`` for component ``item``.

        Asserts that the component's ``"type"`` is in ``arch_support_list``
        before importing the ``cost`` module.
        """
        item_type = self.arch[item]["type"]
        assert item_type in arch_support_list, \
            bcolors.FAIL + "Unrecognized architecture type: " + item_type + bcolors.ENDC
        cost_module = importlib.import_module("cost")
        return getattr(cost_module, item_type + suffix)

    def record(self, item: str, numbers: tuple):
        """Pass ``numbers`` to the ``_dict_record`` hook of component ``item``."""
        assert item in self.arch, \
            bcolors.FAIL + "Unrecognized architecture item: " + item + bcolors.ENDC
        update_entry = self._cost_hook(item, "_dict_record")
        update_entry(self.disc[item], numbers)

    def publish(self, performance: OrderedDict):
        """Invoke the ``_dict_publish`` hook of every component with ``performance``."""
        for name in self.arch:
            if name.lower() in self.attr:
                continue
            finalize_entry = self._cost_hook(name, "_dict_publish")
            finalize_entry(name, self.disc, performance)


class WorkloadCost:
    """Aggregate per-problem cost dictionaries into a workload-level summary.

    ``record`` checkpoints each problem's cost dict to
    ``<path>/cost/workloadcost.yaml``. ``publish`` folds every recorded
    problem into the ``"overall"`` section and writes two files:

    * the full dict, back to ``workloadcost.yaml``;
    * ``workloadcost.summary.yaml``, which holds only technology,
      frequency and the overall section.

    In the overall section, area is averaged over problems, energy is
    summed, and power is derived from energy and runtime.
    """

    def __init__(self, architecture: OrderedDict, path: str) -> None:
        self.path = path
        self.arch = architecture
        self.copy = path + "/cost/workloadcost.yaml"
        self.copy_summary = path + "/cost/workloadcost.summary.yaml"
        # Remove outputs left by an earlier run; otherwise ``record`` would
        # load a stale checkpoint.
        for stale in (self.copy, self.copy_summary):
            if os.path.exists(stale):
                os.remove(stale)
        self.disc = OrderedDict()
        self.disc["technology"] = architecture["technology"] # in nm
        self.disc["frequency"] = architecture["frequency"] # in MHz
        # Template for every energy/power bucket.
        self.attr = OrderedDict({"dynamic": 0, "leakage": 0, "total": 0})
        self.disc["overall"] = self._empty_overall()

        # Top-level keys of ``self.disc`` that are not problem entries.
        self.non_prob_keys = ["technology", "frequency", "overall"]
        # Number of ``record`` calls. Kept for callers; ``publish`` averages
        # over the distinct problem entries instead.
        self.num_probs = 0

    def _empty_overall(self) -> OrderedDict:
        """Return a zeroed ``overall`` section with area/energy/power buckets.

        There are buckets for "onchip" and for every type in
        ``arch_support_list``. Energy and power also have a "total" bucket.
        """
        overall = OrderedDict({
            "area": OrderedDict({"onchip": 0}),
            "energy": OrderedDict({"total": copy.deepcopy(self.attr),
                                   "onchip": copy.deepcopy(self.attr)}),
            "power": OrderedDict({"total": copy.deepcopy(self.attr),
                                  "onchip": copy.deepcopy(self.attr)}),
        })
        for arch_type in arch_support_list:
            overall["area"][arch_type] = 0
            overall["energy"][arch_type] = copy.deepcopy(self.attr)
            overall["power"][arch_type] = copy.deepcopy(self.attr)
        return overall

    def record(self, probname, probcostdict: OrderedDict):
        """Store ``probcostdict`` under ``probname`` and rewrite the checkpoint file."""
        self.num_probs += 1
        if os.path.exists(self.copy):
            # load checkpoint workload cost
            self.disc = yaml_load(self.copy)
        self.disc[probname] = probcostdict
        yaml_overwrite(self.copy, self.disc)

    def publish(self, performance: OrderedDict):
        """Aggregate all recorded problems into ``disc["overall"]`` and write YAML.

        For every architecture component whose type is in
        ``arch_support_list``:

        * area is summed per type and then averaged over the recorded
          problems;
        * energy is summed per type and into "total";
        * components whose type is in ``onchip_list`` also feed the
          "onchip" area and energy buckets.

        Power is ``energy / performance["overall"]["runtime"]["impl"] / 10**6``.

        The overall section is rebuilt from zero on every call, so calling
        ``publish`` repeatedly does not double-count.

        Args:
            performance: must provide ``["overall"]["runtime"]["impl"]``.
        """
        overall = self._empty_overall()
        self.disc["overall"] = overall
        prob_names = [key for key in self.disc.keys() if key not in self.non_prob_keys]

        for item in self.arch.keys():
            # Attribute keys are matched case-insensitively.
            if item.lower() in arch_attr_list:
                continue
            item_type = self.arch[item]["type"]
            if item_type not in arch_support_list:
                continue
            is_onchip = item_type in onchip_list
            for prob in prob_names:
                item_cost = self.disc[prob][item]
                overall["area"][item_type] += item_cost["area"] # in mm^2
                if is_onchip:
                    overall["area"]["onchip"] += item_cost["area"] # in mm^2
                for attr in self.attr.keys():
                    energy = item_cost["energy"][attr] # in nJ
                    overall["energy"][item_type][attr] += energy
                    overall["energy"]["total"][attr] += energy
                    if is_onchip:
                        overall["energy"]["onchip"][attr] += energy

        # Average area over the distinct problems actually summed above.
        # Skip the division when nothing was recorded; all sums are zero then.
        num_probs = len(prob_names)
        if num_probs:
            for area_key in overall["area"].keys():
                overall["area"][area_key] /= num_probs # in mm^2

        runtime = performance["overall"]["runtime"]["impl"]
        for power_key, power_bucket in overall["power"].items():
            for attr in self.attr.keys():
                power_bucket[attr] = overall["energy"][power_key][attr] / runtime / 10**6 # in mW

        yaml_overwrite(self.copy, self.disc)
        disc_summary = OrderedDict({"technology": self.disc["technology"],
                                    "frequency": self.disc["frequency"],
                                    "overall": overall})
        yaml_overwrite(self.copy_summary, disc_summary)

