# -*- coding: utf-8 -*-
"""
Created on Fri Aug 19 10:33:45 2022

@author: Edoardo
"""

import numpy as np

class CenteredLayer:
    """
    Convert a ReLU layer to inactive canonical form.

    Given a layer y = ReLU(W x + b) and a centroid input, the layer is split
    into a purely linear part (the neurons that are active at the centroid)
    and a ReLU part whose potentials are sign-flipped for those active
    neurons. At the centroid every ReLU of the centered form therefore
    receives a non-positive potential (it is inactive), while the layer's
    input-output map is preserved: ExecuteCentered(x) reproduces
    ExecuteOriginal(x).

    A centroid input is required to compute the activation states.

    Attributes:
        W, b: the original weights and biases.
        xc: the centroid used to compute the activation pattern.
        act: boolean activation pattern at the centroid (potential >= 0).
        Wlin, blin: weights/biases of the linear reconstruction
            (rows of neurons inactive at the centroid are zeroed).
        Wpot, bpot: weights/biases of the now-inactive potentials
            (rows of neurons active at the centroid are negated).
    """

    def __init__(self, WeightMatrix, BiasVector, Centroid):
        """
        Partition the weights and biases around the centroid's activation
        pattern.

        Args:
            WeightMatrix: layer weights W (presumably an (m, n) array --
                confirm against callers).
            BiasVector: layer biases b with one entry per neuron; it is
                raveled before use in the forward passes.
            Centroid: input point at which activation states are computed.
        """
        self.W = WeightMatrix
        self.b = BiasVector
        self.xc = Centroid

        # Activation pattern at the centroid; a potential of exactly zero
        # counts as active.
        self.act = self.ExecuteOriginal(Centroid, WithReLU=False) >= 0

        # Linear reconstruction: keep only the rows of active neurons.
        A = np.diag(self.act)
        self.Wlin = A @ self.W
        self.blin = A @ self.b

        # Now-inactive potentials: sign -1 for active rows, +1 for inactive
        # rows, so every one of these potentials is <= 0 at the centroid.
        S = np.diag(1 - 2 * self.act)
        self.Wpot = S @ self.W
        self.bpot = S @ self.b

    def ExecuteOriginal(self, InputVector, WithReLU=True):
        """
        Execute the original layer.

        Args:
            InputVector: input x to the layer.
            WithReLU: if True (default) return ReLU(W x + b); if False
                return the raw potentials W x + b.

        Returns:
            A 1-D numpy array with one entry per neuron.
        """
        OutputVector = np.ravel(self.W @ InputVector) + np.ravel(self.b)

        if WithReLU:
            # Clamp negative potentials in place; OutputVector is the fresh
            # array produced by the addition above, so no input is modified.
            OutputVector[OutputVector < 0] = 0

        return OutputVector

    def ExecuteCentered(self, InputVector):
        """
        Execute the centered layer with canonical ReLU activation functions.

        Computes (Wlin x + blin) + ReLU(Wpot x + bpot). With z = W x + b, a
        neuron active at the centroid yields z + ReLU(-z) = ReLU(z), and an
        inactive one yields 0 + ReLU(z), so the result reproduces
        ExecuteOriginal(x).

        Args:
            InputVector: input x to the layer.

        Returns:
            A 1-D numpy array with one entry per neuron.
        """
        OutLin = np.ravel(self.Wlin @ InputVector) + np.ravel(self.blin)
        OutPot = np.ravel(self.Wpot @ InputVector) + np.ravel(self.bpot)

        # ReLU of the potentials, written as a mask product.
        OutputVector = OutLin + OutPot * (OutPot >= 0)

        return OutputVector
    