# -*- coding: utf-8 -*-
"""
Created on Wed Aug 24 13:48:13 2022

@author: Edoardo
"""

import numpy as np

"""
Create two copies of the input by zeroing the negative and positive entries respectively
Arguments:
    M: input matrix, any numpy array
Returns:
    MatPos: copy of M retaining only the positive elements
    MatNeg: copy of M retaining only the negative elements
"""
def MatrixSignSplit(M):
    """
    Split M into its non-negative and non-positive parts.

    Both results are independent copies made with M.copy(), so the input is
    left untouched.

    Arguments:
        M: input matrix, any numpy array
    Returns:
        MatPos: copy of M with every negative entry set to 0
        MatNeg: copy of M with every positive entry set to 0
    """
    positive_part, negative_part = M.copy(), M.copy()
    # the masks can be taken from M itself: each copy still equals M entry-wise
    positive_part[M < 0] = 0
    negative_part[M > 0] = 0
    return positive_part, negative_part

"""
Class that holds the parameters of a fastlin layer
"""
class BoundedLayer:
    """
    Fast-Lin style linear relaxation of a single affine + ReLU layer.

    The layer is built from its weights, bias and a box of input bounds
    [InMin, InMax]. It stores the interval weights used to propagate a pair
    of (upper, lower) inputs, the resulting pre-activation bounds
    (ymax, ymin), and the per-neuron slope c and offset d of the ReLU
    linear relaxation.
    """

    def __init__(self, WeightMatrix, BiasVector, InMax, InMin):
        """
        Turn a ReLU layer into a linear interval layer between the given bounds.

        Arguments:
            WeightMatrix: 2-D weight matrix, one row per output neuron
            BiasVector: bias of the layer, one entry per output neuron
            InMax: upper bound on each input
            InMin: lower bound on each input
        """
        Wpos, Wneg = MatrixSignSplit(WeightMatrix)

        # interval weights acting on the stacked input [upper; lower]: the
        # upper bound pairs positive weights with the upper input and negative
        # weights with the lower input, the lower bound does the opposite
        self.Wup = np.concatenate([Wpos, Wneg], axis=1)
        self.Wlow = np.concatenate([Wneg, Wpos], axis=1)
        self.bup = BiasVector
        self.blow = BiasVector

        # propagate the input box onto the pre-activation potentials
        x = np.concatenate([InMax, InMin])
        self.ymax = np.ravel(self.Wup @ x) + np.ravel(self.bup)
        self.ymin = np.ravel(self.Wlow @ x) + np.ravel(self.blow)

        # precompute the ReLU linearization slope and offset
        self.c, self.d = self._ReluRelaxation(self.ymax, self.ymin)

    @staticmethod
    def _ReluRelaxation(ymax, ymin):
        """
        Compute the per-neuron slope and offset of the ReLU linear relaxation.

        For a pre-activation y in [ymin, ymax] the relaxation is
            c * y  <=  relu(y)  <=  c * (y - ymin) + d
        with
            active   (ymin >= 0):        c = 1,                     d = ymin
            inactive (ymax <= 0):        c = 0,                     d = 0
            unstable (ymin < 0 < ymax):  c = ymax / (ymax - ymin),  d = 0

        Arguments:
            ymax: upper bounds on the pre-activations
            ymin: lower bounds on the pre-activations
        Returns:
            c: slope shared by the upper and lower linear bounds
            d: offset of the upper linear bound, i.e. relu(ymin)
        """
        ymax = np.asarray(ymax, dtype=float)
        ymin = np.asarray(ymin, dtype=float)

        # classify by sign rather than by interval width, so zero-width
        # intervals are exact: a point below zero must give c = 0, not 1
        c = np.where(ymin >= 0, 1.0, 0.0)
        unstable = (ymax > 0) & (ymin < 0)
        # ymax > 0 > ymin on these entries, so the denominator is positive
        c[unstable] = ymax[unstable] / (ymax[unstable] - ymin[unstable])

        d = np.maximum(ymin, 0.0)
        return (c, d)

    def OutputBounds(self):
        """
        Return the bounds on the ReLU outputs of this layer.

        Returns:
            OutMax: relu(ymax), upper bound on each output
            OutMin: relu(ymin), lower bound on each output
        """
        OutMax = np.maximum(self.ymax, 0)
        OutMin = np.maximum(self.ymin, 0)
        return (OutMax, OutMin)

    def Execute(self, UpperInput, LowerInput, WithReLU=True):
        """
        Propagate a pair of (upper, lower) inputs through this layer.

        Passing the same concrete point x as both UpperInput and LowerInput
        evaluates the affine map W x + b on it.

        Arguments:
            UpperInput: upper input, combined with the positive weights for yup
            LowerInput: lower input, combined with the positive weights for ylow
            WithReLU: if True, apply the ReLU linear relaxation to the outputs
        Returns:
            yup: upper output
            ylow: lower output
        """
        x = np.concatenate([UpperInput, LowerInput])

        yup = np.ravel(self.Wup @ x) + np.ravel(self.bup)
        ylow = np.ravel(self.Wlow @ x) + np.ravel(self.blow)

        if WithReLU:
            # upper relaxation c * (y - ymin) + d, lower relaxation c * y
            yup = (yup - self.ymin) * self.c + self.d
            ylow = ylow * self.c

        return (yup, ylow)
    