# -*- coding: utf-8 -*-
"""
Created on Tue Aug  9 13:47:02 2022

@author: Edoardo
"""

import numpy as np
from algorithms.ginnacer.src.center import CenteredLayer
from algorithms.ginnacer.src.interval import IntervalLayer
from algorithms.ginnacer.src.merge import MergedLayer

"""
Generate and execute a neural network abstraction
The abstraction is tight around the centroid (zero over-approximation error)
"""
class AbstractNet():
    
    """
    Pass list of weights and biases, plus an input to the network (centroid)
    """
    def __init__(self, WeightList, BiasList, Centroid):
        
        assert len(WeightList) == len(BiasList)
        
        self.layers = [None] * len(WeightList)
        
        # allow negative inputs to the first layer
        self.layers[0] = MergedLayer(WeightList[0],
                                     BiasList[0],
                                     Centroid,
                                     SignedInput=True)
        
        for i in range(1, len(WeightList) - 1):
            
            # propagate the centroid forward
            Centroid = CenteredLayer.ExecuteOriginal(self.layers[i - 1],
                                                     Centroid,
                                                     WithReLU=True)
            
            # create the next compressed hidden layer
            self.layers[i] = MergedLayer(WeightList[i],
                                         BiasList[i],
                                         Centroid,
                                         SignedInput=False)
        
        # the centroid does not matter for the final layer
        Centroid = np.zeros(len(BiasList[-2]))
        
        # do not compress the final layer (it's linear anyway)
        self.layers[-1] = IntervalLayer(WeightList[-1],
                                        BiasList[-1],
                                        Centroid,
                                        SignedInput=False,
                                        LinearOutput=True)
        
    """
    Execute the whole network abstraction
    Returns output intervals for all layers
    """
    def Execute(self, IncInput, DecInput):
        
        LayerOutput = [None] * (len(self.layers) + 1)
        LayerOutput[0] = (IncInput, DecInput)
        
        for i in range(len(self.layers)):
            LayerOutput[i + 1] = self.layers[i].Execute(*LayerOutput[i])
        
        return LayerOutput
