from collections import OrderedDict
from utils import bcolors
import math


class conv:
    """Problem descriptor for a 2D convolution layer.

    Reads the layer configuration from ``cnfg`` and derives the output
    feature map size together with element and operation counts.

    Required keys in ``cnfg``: ``type`` (must equal ``"conv"``), ``cycle``,
    ``N``, ``X``, ``Y``, ``C``, ``R``, ``S``, ``K``.
    Optional keys (default): ``sparsity`` (0.), ``similarity`` (0.),
    ``Hdilation``/``Wdilation`` (1), ``Hstride``/``Wstride`` (1),
    ``Hpadding``/``Wpadding`` (0).
    """

    def __init__(self, name, cnfg: OrderedDict) -> None:
        """Populate the layer attributes from ``cnfg``.

        Args:
            name: identifier stored as ``self.name``.
            cnfg: configuration mapping, stored as ``self.cnfg``.

        Raises:
            AssertionError: if ``cnfg["type"]`` is not this class's name.
            KeyError: if a required key is missing from ``cnfg``.
        """
        self.name = name
        self.cnfg = cnfg

        layer_type = cnfg["type"]
        assert layer_type == __class__.__name__, (
            bcolors.FAIL + "Problem type <" + layer_type + "> is invalid." + bcolors.ENDC
        )

        # execution cycle for a multiplication/MAC
        self.cycle = cnfg["cycle"]
        self.sparsity = cnfg.get("sparsity", 0.)
        self.similarity = cnfg.get("similarity", 0.)

        # batch, input feature map height, input feature map width, input channel
        self.N, self.X, self.Y, self.C = cnfg["N"], cnfg["X"], cnfg["Y"], cnfg["C"]
        # filter height, filter width, output channel
        self.R, self.S, self.K = cnfg["R"], cnfg["S"], cnfg["K"]

        # dilation, stride and padding of the input along height (H) and width (W)
        self.Hdilation = cnfg.get("Hdilation", 1)
        self.Wdilation = cnfg.get("Wdilation", 1)
        self.Hstride = cnfg.get("Hstride", 1)
        self.Wstride = cnfg.get("Wstride", 1)
        self.Hpadding = cnfg.get("Hpadding", 0)
        self.Wpadding = cnfg.get("Wpadding", 0)

        # Output size: padding is applied on both sides, and the kernel extent
        # is stretched by the dilation before striding.
        h_reach = self.Hdilation * (self.R - 1)
        w_reach = self.Wdilation * (self.S - 1)
        padded_h = self.X + self.Hpadding * 2
        padded_w = self.Y + self.Wpadding * 2
        self.P = int(math.floor((padded_h - h_reach - 1 + self.Hstride) / self.Hstride))  # output height
        self.Q = int(math.floor((padded_w - w_reach - 1 + self.Wstride) / self.Wstride))  # output width

        # Input extent covered by the filter windows. Overlapping windows
        # (stride < filter size) span stride * (out - 1) + filter; otherwise
        # each of the `out` windows contributes `filter` separate positions.
        # NOTE(review): padding and dilation do not enter these expressions -
        # confirm that this is intended.
        if self.Hstride < self.R:
            self.X_valid = self.Hstride * (self.P - 1) + self.R
        else:
            self.X_valid = self.R * self.P
        if self.Wstride < self.S:
            self.Y_valid = self.Wstride * (self.Q - 1) + self.S
        else:
            self.Y_valid = self.S * self.Q

        # Element counts and operation count (2 operations per MAC).
        batch_channels = self.N * self.C
        kernel_area = self.R * self.S
        self.num_inputs = batch_channels * self.X * self.Y
        self.num_inputs_valid = batch_channels * self.X_valid * self.Y_valid
        self.num_weights = self.K * self.C * self.R * self.S
        self.num_outputs = self.N * self.K * self.P * self.Q
        self.flops = 2. * self.N * self.P * self.Q * kernel_area * self.K * self.C


class avgpool2d:
    """Problem descriptor for a 2D average pooling layer.

    Reads the layer configuration from ``cnfg`` and derives the output
    feature map size together with element and operation counts.

    Required keys in ``cnfg``: ``type`` (must equal ``"avgpool2d"``),
    ``cycle``, ``N``, ``X``, ``Y``, ``C``, ``R``, ``S``.
    Optional keys (default): ``sparsity`` (1.), ``Hstride``/``Wstride`` (1),
    ``Hpadding``/``Wpadding`` (0).
    """

    def __init__(self, name, cnfg: OrderedDict) -> None:
        """Populate the layer attributes from ``cnfg``.

        Args:
            name: identifier stored as ``self.name``.
            cnfg: configuration mapping, stored as ``self.cnfg``.

        Raises:
            AssertionError: if ``cnfg["type"]`` is not this class's name.
            KeyError: if a required key is missing from ``cnfg``.
        """
        self.name = name
        self.cnfg = cnfg

        layer_type = cnfg["type"]
        assert layer_type == __class__.__name__, (
            bcolors.FAIL + "Problem type <" + layer_type + "> is invalid." + bcolors.ENDC
        )

        # execution cycle for a multiplication/MAC
        self.cycle = cnfg["cycle"]
        self.sparsity = cnfg.get("sparsity", 1.)

        # batch, input feature map height, input feature map width, input channel
        self.N, self.X, self.Y, self.C = cnfg["N"], cnfg["X"], cnfg["Y"], cnfg["C"]
        # filter height, filter width
        self.R, self.S = cnfg["R"], cnfg["S"]

        # stride and padding of the input along height (H) and width (W)
        self.Hstride = cnfg.get("Hstride", 1)
        self.Wstride = cnfg.get("Wstride", 1)
        self.Hpadding = cnfg.get("Hpadding", 0)
        self.Wpadding = cnfg.get("Wpadding", 0)

        # Output size: padding is applied on both sides of each axis.
        padded_h = self.X + self.Hpadding * 2
        padded_w = self.Y + self.Wpadding * 2
        self.P = int(math.floor((padded_h - self.R + self.Hstride) / self.Hstride))  # output height
        self.Q = int(math.floor((padded_w - self.S + self.Wstride) / self.Wstride))  # output width

        # Input extent covered by the pooling windows. Overlapping windows
        # (stride < filter size) span stride * (out - 1) + filter; otherwise
        # each of the `out` windows contributes `filter` separate positions.
        # NOTE(review): padding does not enter these expressions - confirm
        # that this is intended.
        if self.Hstride < self.R:
            self.X_valid = self.Hstride * (self.P - 1) + self.R
        else:
            self.X_valid = self.R * self.P
        if self.Wstride < self.S:
            self.Y_valid = self.Wstride * (self.Q - 1) + self.S
        else:
            self.Y_valid = self.S * self.Q

        # Element counts and operation count (2 operations per window element).
        batch_channels = self.N * self.C
        window_area = self.R * self.S
        self.num_inputs = batch_channels * self.X * self.Y
        self.num_inputs_valid = batch_channels * self.X_valid * self.Y_valid
        self.num_outputs = batch_channels * self.P * self.Q
        self.flops = 2. * self.N * self.P * self.Q * window_area * self.C


class spike:
    """Problem descriptor for a spike layer.

    Reads the layer configuration from ``cnfg`` and derives element and
    operation counts.

    Required keys in ``cnfg``: ``type`` (must equal ``"spike"``), ``cycle``,
    ``N``, ``X``, ``Y``, ``C``.
    Optional key (default): ``sparsity`` (1.).
    """

    def __init__(self, name, cnfg: OrderedDict) -> None:
        """Populate the layer attributes from ``cnfg``.

        Args:
            name: identifier stored as ``self.name``.
            cnfg: configuration mapping, stored as ``self.cnfg``.

        Raises:
            AssertionError: if ``cnfg["type"]`` is not this class's name.
            KeyError: if a required key is missing from ``cnfg``.
        """
        self.name = name
        self.cnfg = cnfg

        layer_type = cnfg["type"]
        assert layer_type == __class__.__name__, (
            bcolors.FAIL + "Problem type <" + layer_type + "> is invalid." + bcolors.ENDC
        )

        # execution cycle for a multiplication/MAC
        self.cycle = cnfg["cycle"]
        self.sparsity = cnfg.get("sparsity", 1.)

        # batch, input feature map height, input feature map width, input channel
        self.N, self.X, self.Y, self.C = cnfg["N"], cnfg["X"], cnfg["Y"], cnfg["C"]

        # NOTE(review): num_inputs leaves out the batch dimension N while
        # num_outputs includes it - confirm that this asymmetry is intended.
        per_sample = self.C * self.X * self.Y
        self.num_inputs = per_sample
        self.num_outputs = self.N * self.C * self.X * self.Y
        # One operation is counted per output element.
        self.flops = self.num_outputs


class elementwise:
    """Placeholder problem descriptor for an elementwise layer (not implemented)."""

    def __init__(self, name, cnfg: OrderedDict) -> None:
        """Accept ``name`` and ``cnfg`` without storing or validating them."""


class softmax:
    """Placeholder problem descriptor for a softmax layer (not implemented)."""

    def __init__(self, name, cnfg: OrderedDict) -> None:
        """Accept ``name`` and ``cnfg`` without storing or validating them."""

