# -*- coding: utf-8 -*-
"""
Created on Tue Jul 19 15:54:39 2022

@author: Edoardo
"""

import numpy as np

def MatrixSignSplit(M):
    """
    Split the input into its positive and negative parts.

    Two copies of the input are made: in the first every negative entry is
    set to zero, in the second every positive entry is set to zero, so that
    MatPos + MatNeg == M element-wise.

    Arguments:
        M: input matrix; any numpy array or array-like (list, scalar, ...).
           ndarray subclasses are preserved. M itself is never modified.
    Returns:
        MatPos: copy of M retaining only the non-negative elements
        MatNeg: copy of M retaining only the non-positive elements
    """
    # asanyarray lets plain lists/scalars through while keeping ndarray
    # subclasses intact; .copy() guarantees the caller's data is untouched.
    MatPos = np.asanyarray(M).copy()
    MatNeg = MatPos.copy()
    MatPos[MatPos < 0] = 0
    MatNeg[MatNeg > 0] = 0

    return (MatPos, MatNeg)

class PartitionedLayer:
    """
    Hold the parameters of one partitioned layer.

    The weights W and biases b are partitioned into Pos/Neg and Inc/Dec
    neurons. When negative inputs are expected, pass SignedInput=True to the
    first layer.
    """

    def __init__(self, WeightMatrix, BiasVector, SignedInput=False):
        """
        Build the partitioned weight matrices and bias vectors.

        Arguments:
            WeightMatrix: layer weights; the split/concatenation below works
                column-wise (axis=1), so presumably shape (outputs, inputs)
                -- TODO confirm against callers.
            BiasVector: layer biases, shared (same object) by all four
                partitions.
            SignedInput: if True, the input may contain negative values and
                the ReLU split trick of SignedInputWeights is applied.
        """
        self.W = WeightMatrix
        self.b = BiasVector

        self.signed = SignedInput

        # With signed inputs each input column is paired with its negation,
        # doubling the width of every input slice (see SignedInputWeights).
        if SignedInput:
            BaseW = PartitionedLayer.SignedInputWeights(self.W)
        else:
            BaseW = self.W

        WeightPos, WeightNeg = MatrixSignSplit(BaseW)
        WeightZero = np.zeros(BaseW.shape)

        # The four column blocks line up, in order, with the IncPos, IncNeg,
        # DecPos and DecNeg slices concatenated by PartitionedInput:
        #   Inc rows see  Pos * IncPos + Neg * DecNeg
        #   Dec rows see  Neg * IncNeg + Pos * DecPos
        # NOTE(review): the *Pos and *Neg variants are currently identical.
        self.WIncPos = np.concatenate([WeightPos, WeightZero, WeightZero, WeightNeg], axis=1)
        self.WIncNeg = np.concatenate([WeightPos, WeightZero, WeightZero, WeightNeg], axis=1)
        self.WDecPos = np.concatenate([WeightZero, WeightNeg, WeightPos, WeightZero], axis=1)
        self.WDecNeg = np.concatenate([WeightZero, WeightNeg, WeightPos, WeightZero], axis=1)

        # All partitions alias the same bias object (no copies are made).
        self.bIncPos = self.b
        self.bIncNeg = self.b
        self.bDecPos = self.b
        self.bDecNeg = self.b

    @staticmethod
    def SignedInputWeights(WeightMatrix):
        """
        Allow negative inputs with the ReLU split trick.

        Returns [W, -W] concatenated along axis 1, so that for any split of
        the input into p and n:  [W, -W] @ [p; n] == W @ (p - n).
        Declared static so it can be called both on the class and on an
        instance.
        """
        SignedWeights = np.concatenate([WeightMatrix, -WeightMatrix], axis=1)

        return SignedWeights

    def PartitionedInput(self, inIncPos, inIncNeg, inDecPos, inDecNeg):
        """
        Generate a correct interval input, taking care of negative values.

        Unsigned layer: the four partitions are concatenated (axis 0) as is.
        Signed layer: the Inc slices become [inIncPos, -inDecNeg] and the Dec
        slices [inDecPos, -inIncNeg], each clipped at zero (negative entries
        set to 0); IncPos/IncNeg (and DecPos/DecNeg) are therefore identical.
        The caller's arrays are never modified.

        Returns:
            FullInput: the concatenation [IncPos, IncNeg, DecPos, DecNeg].
        """
        if not self.signed:
            return np.concatenate([inIncPos, inIncNeg, inDecPos, inDecNeg])

        # np.concatenate allocates fresh arrays, so clipping in place is safe.
        sInc = np.concatenate([inIncPos, -inDecNeg])
        sDec = np.concatenate([inDecPos, -inIncNeg])
        sInc[sInc < 0] = 0
        sDec[sDec < 0] = 0

        # The outer concatenate copies again, so reusing sInc/sDec is safe.
        FullInput = np.concatenate([sInc, sInc, sDec, sDec])

        return FullInput