# -*- coding: utf-8 -*-
"""
Created on Tue Aug  9 13:47:02 2022

@author: Edoardo
"""

from algorithms.prabhakar.src.interval import IntervalLayer

class PrabhakarNet:
    """Hold the parameters of a compressed neural network.

    Each (weight matrix, bias vector) pair is wrapped in an ``IntervalLayer``.
    Every layer except the last is then compressed in place with
    ``IntervalLayer.RandomCompress``, using a cluster count derived from
    ``ClusterList``.

    Attributes:
        layers: list of ``IntervalLayer`` objects, one per weight matrix,
            in execution order.
    """

    def __init__(self, WeightMatrixList, BiasVectorList, ClusterList, SignedInput=False):
        """Build the layers from weights and biases, then compress them.

        Args:
            WeightMatrixList: one weight matrix per layer.
            BiasVectorList: one bias vector per layer; must have the same
                length as ``WeightMatrixList``.
            ClusterList: one cluster count per compressed layer, i.e. exactly
                one fewer entry than there are layers (the last layer is not
                compressed).
            SignedInput: forwarded to the first layer only; every later layer
                is built with ``False``.

        Raises:
            ValueError: if no layers are given or the list lengths are
                inconsistent.
        """
        # Explicit checks rather than ``assert`` so that validation still
        # happens when Python runs with optimisations (-O) enabled.
        nLayers = len(WeightMatrixList)
        if nLayers == 0:
            raise ValueError("WeightMatrixList must contain at least one layer")
        if len(BiasVectorList) != nLayers:
            raise ValueError(
                "WeightMatrixList and BiasVectorList must have the same length "
                "(got {} and {})".format(nLayers, len(BiasVectorList)))
        if len(ClusterList) != nLayers - 1:
            raise ValueError(
                "ClusterList must have one entry per compressed layer, i.e. "
                "{} entries (got {})".format(nLayers - 1, len(ClusterList)))

        # Only the first layer honours SignedInput. Later layers receive the
        # outputs of layers executed with WithReLU=True (see Execute), so they
        # are always built with False.
        self.layers = [
            IntervalLayer(Weights, Biases, SignedInput if i == 0 else False)
            for i, (Weights, Biases) in enumerate(zip(WeightMatrixList, BiasVectorList))
        ]

        # Compress every layer except the last. RandomCompress is also handed
        # the following layer -- presumably so it can adapt that layer to the
        # merged neurons; confirm in IntervalLayer.
        for i, nClusters in enumerate(ClusterList):
            # ceil(nClusters / 2) with exact integer arithmetic (no float
            # round-trip); int() keeps the argument a plain Python int.
            nAdjClusters = int((nClusters + 1) // 2)
            self.layers[i].RandomCompress(nAdjClusters, self.layers[i + 1])

    def Execute(self, UpperInput, LowerInput):
        """Execute the whole compressed network.

        Args:
            UpperInput: first input argument for the first layer (presumably
                the upper bound of the input interval).
            LowerInput: second input argument for the first layer (presumably
                the lower bound of the input interval).

        Returns:
            A list of length ``len(self.layers) + 1``. Element 0 is the pair
            ``(UpperInput, LowerInput)``; element ``k`` is the value returned
            by ``self.layers[k - 1].Execute``. Every layer except the last is
            executed with ``WithReLU=True``, the last one with
            ``WithReLU=False``.
        """
        LayerOutput = [(UpperInput, LowerInput)]
        lastIndex = len(self.layers) - 1
        for i, layer in enumerate(self.layers):
            # Each layer consumes the (upper, lower) pair produced by the
            # previous one; only the output layer skips the ReLU.
            LayerOutput.append(layer.Execute(*LayerOutput[-1], WithReLU=(i != lastIndex)))
        return LayerOutput
    