# -*- coding: utf-8 -*-
"""
Created on Tue Aug  9 13:47:02 2022

@author: Edoardo
"""

from algorithms.elboher.src.merge import MergedLayer

def IntHalves(n):
    """
    Split the integer n into two halves that sum exactly to n.

    Returns (h, h + r) where h = floor(n / 2) and r = n mod 2 is 0 or 1,
    so any odd unit goes to the second half and the second half is never
    smaller than the first.

    Examples: 7 -> (3, 4), 8 -> (4, 4), 0 -> (0, 0).
    """
    # divmod keeps the computation in exact integer arithmetic; the former
    # int(n / 2) went through a float and produced a wrong split (r == -1)
    # for integers beyond 2**53.
    h, r = divmod(n, 2)

    return (h, h + r)

def IntQuarters(n):
    """
    Split n into four integer quarters, each clamped to be at least 1.

    n is halved with IntHalves, and each half is halved again, so odd
    units end up in the later quarters. Because every quarter is raised
    to a minimum of 1, the quarters sum to n exactly when n >= 4; for
    smaller n the total exceeds n (e.g. n = 1 -> (1, 1, 1, 1)).

    Returns a 4-tuple of integers.
    """
    lower, upper = IntHalves(n)
    # IntHalves returns 2-tuples, so concatenation yields all four quarters
    # in order: (lower's halves..., upper's halves...).
    quarters = IntHalves(lower) + IntHalves(upper)

    return tuple(max(1, q) for q in quarters)

class ElboherNet():
    """
    Compressed neural network stored as a list of MergedLayer objects.

    Every layer except the last is compressed using the cluster count given
    for it in ClusterList (split four ways by IntQuarters). The final layer
    is left uncompressed and is executed without ReLU.
    """

    def __init__(self, WeightMatrixList, BiasVectorList, ClusterList, SignedInput=False):
        """
        Build one MergedLayer per (weight matrix, bias vector) pair.

        Parameters
        ----------
        WeightMatrixList : list
            One weight matrix per layer.
        BiasVectorList : list
            One bias vector per layer; same length as WeightMatrixList.
        ClusterList : list
            One cluster count per compressed layer, i.e. for every layer
            except the last; must be one shorter than BiasVectorList.
        SignedInput : bool
            Forwarded to the layer that reads the network input (the first
            layer), to let it accept negative inputs. All other layers get
            False.

        Raises
        ------
        AssertionError
            If the list lengths are inconsistent (note: asserts are
            stripped under ``python -O``).
        """
        assert len(WeightMatrixList) == len(BiasVectorList)
        assert len(ClusterList) == len(BiasVectorList) - 1

        num_layers = len(WeightMatrixList)
        self.layers = []

        # Compressed layers: every layer but the last. Only layer 0 sees the
        # raw network input, so only it gets SignedInput. The others are fed
        # the previous layer's ReLU-activated output (see Execute).
        for i in range(num_layers - 1):
            self.layers.append(MergedLayer(WeightMatrixList[i],
                                           BiasVectorList[i],
                                           *IntQuarters(ClusterList[i]),
                                           SignedInput if i == 0 else False))

        # Do not compress the final layer (it runs without ReLU, so it is
        # linear anyway). When it is the only layer, ClusterList is empty
        # and this layer reads the network input directly, so it must honour
        # SignedInput instead of going through the compressed path above.
        n = len(BiasVectorList[-1])
        self.layers.append(MergedLayer(WeightMatrixList[-1],
                                       BiasVectorList[-1],
                                       n, n, n, n,
                                       SignedInput if num_layers == 1 else False))

    def Execute(self, inIncPos, inIncNeg, inDecPos, inDecNeg):
        """
        Run the whole compressed network on one input.

        The four arguments are unpacked into the first layer's Execute call
        (presumably the inc/dec x pos/neg parts of the input expected by
        MergedLayer.Execute -- confirm against merge.py).

        Returns a list of len(self.layers) + 1 entries. Entry 0 is the input
        4-tuple, and entry i + 1 is the value returned by layer i. Every
        layer except the last runs with ReLU; the last runs without it.
        """
        LayerOutput = [(inIncPos, inIncNeg, inDecPos, inDecNeg)]

        for layer in self.layers[:-1]:
            LayerOutput.append(layer.Execute(*LayerOutput[-1], WithReLU=True))

        LayerOutput.append(self.layers[-1].Execute(*LayerOutput[-1], WithReLU=False))

        return LayerOutput
    