# -*- coding: utf-8 -*-
"""
Created on Fri Aug 19 10:54:06 2022

@author: Edoardo
"""

import numpy as np
from algorithms.ginnacer.src.center import CenteredLayer

def MatrixSignSplit(M):
    """
    Create two copies of the input by zeroing the negative and positive entries respectively

    The input itself is never modified, and MatPos + MatNeg equals M element-wise.

    Arguments:
        M: input matrix, any numpy array (plain sequences are converted to arrays)
    Returns:
        MatPos: copy of M retaining only the positive elements
        MatNeg: copy of M retaining only the negative elements
    """
    # no-op for numpy arrays (subclasses are preserved), conversion for sequences
    Source = np.asanyarray(M)

    MatPos = Source.copy()
    MatNeg = Source.copy()

    # the masks are built from the untouched source, each copy is edited on its own
    MatPos[Source < 0] = 0
    MatNeg[Source > 0] = 0

    return (MatPos, MatNeg)

class IntervalLayer(CenteredLayer):
    """
    Prepare a centered layer for interval arithmetic operations

    An input interval is described by its upper bound (Inc) and its lower
    bound (Dec). The layer output is evaluated (see Execute) as
        (Wlin @ x + blin) + ReLU(Wpot @ x + bpot)
    on both bounds at once: positive weights are applied to the matching
    bound and negative weights to the opposite one.

    The attributes W, b, Wlin, blin, Wpot and bpot are read from the
    CenteredLayer base class, which is defined elsewhere.
    NOTE(review): presumably CenteredLayer derives them from the weights,
    biases and centroid passed to the constructor -- confirm there.

    Attributes:
        signed: True when negative inputs are handled (ReLU split trick)
        linear: True when no ReLU activation is applied to the output
        WlinInc, WlinDec: linear-part weights for the upper/lower bound
        blinInc, blinDec: linear-part biases for the upper/lower bound
        WpotInc, WpotDec: ReLU-part weights for the upper/lower bound
        bpotInc, bpotDec: ReLU-part biases for the upper/lower bound
    """

    def __init__(self, WeightMatrix, BiasVector, Centroid, SignedInput=False, LinearOutput=False):
        """
        Partition the weights W and biases b, assuming inc/dec inputs and outputs
        When negative inputs are expected, pass SignedInput=True to the first layer
        When no ReLU activations must be applied (e.g. last layer), pass LinearOutput=True

        Arguments:
            WeightMatrix: forwarded unchanged to CenteredLayer
            BiasVector: forwarded unchanged to CenteredLayer
            Centroid: forwarded unchanged to CenteredLayer
            SignedInput: enable the ReLU split trick for negative inputs
            LinearOutput: use the full weights W and biases b as the linear
                part and skip the ReLU part when executing the layer
        """
        super().__init__(WeightMatrix, BiasVector, Centroid)

        self.signed = SignedInput
        self.linear = LinearOutput

        # a linear output layer uses the full affine map as its linear part
        LinWeights = self.W if LinearOutput else self.Wlin
        LinBiases = self.b if LinearOutput else self.blin
        PotWeights = self.Wpot

        # deal with the possibility of negative inputs
        if SignedInput:
            LinWeights = IntervalLayer.SignedInputWeights(LinWeights)
            PotWeights = IntervalLayer.SignedInputWeights(PotWeights)

        self.WlinInc, self.WlinDec = IntervalLayer.IntervalWeights(LinWeights)
        self.blinInc, self.blinDec = IntervalLayer.IntervalBiases(LinBiases)

        # the ReLU part is prepared even for linear layers, so the attributes always exist
        self.WpotInc, self.WpotDec = IntervalLayer.IntervalWeights(PotWeights)
        self.bpotInc, self.bpotDec = IntervalLayer.IntervalBiases(self.bpot)

    @staticmethod
    def SignedInputWeights(WeightMatrix):
        """
        Allow negative inputs with the ReLU split trick

        Since x = ReLU(x) - ReLU(-x), the matrix [W, -W] applied to the
        stacked vector [ReLU(x), ReLU(-x)] gives the same result as W @ x.

        Arguments:
            WeightMatrix: weight matrix with the inputs along axis 1
        Returns:
            SignedWeights: [W, -W] concatenated along axis 1
        """
        return np.concatenate([WeightMatrix, -WeightMatrix], axis=1)

    @staticmethod
    def IntervalWeights(WeightMatrix):
        """
        Partition the weight matrix into upper bound (Inc) and lower bound (Dec)

        Both results have twice the columns of the input and are meant to be
        applied to the stacked vector [upper bound, lower bound].

        Arguments:
            WeightMatrix: weight matrix with the inputs along axis 1
        Returns:
            IncWeights: [positive part, negative part], gives the upper bound
            DecWeights: [negative part, positive part], gives the lower bound
        """
        PosWeights, NegWeights = MatrixSignSplit(WeightMatrix)

        # positive weights keep the bound, negative weights swap it
        IncWeights = np.concatenate([PosWeights, NegWeights], axis=1)
        DecWeights = np.concatenate([NegWeights, PosWeights], axis=1)

        return (IncWeights, DecWeights)

    @staticmethod
    def IntervalBiases(BiasVector):
        """
        Partition the bias vector into upper bound (Inc) and lower bound (Dec)

        The bias shifts both bounds by the same amount, so the very same
        object is returned twice (no copy is made).

        Arguments:
            BiasVector: bias vector of the layer
        Returns:
            IncBiases: BiasVector itself
            DecBiases: BiasVector itself
        """
        return (BiasVector, BiasVector)

    def IntervalInput(self, IncInput, DecInput):
        """
        Generate a correct interval input
        Take care of possible negative values

        Arguments:
            IncInput: upper bound of the input (array or plain sequence)
            DecInput: lower bound of the input (array or plain sequence)
        Returns:
            FullInput: [IncInput, DecInput] for unsigned layers, otherwise
                [ReLU(Inc), ReLU(-Dec), ReLU(Dec), ReLU(-Inc)],
                always concatenated along the first axis
        """
        # sequences are accepted in both modes, arrays pass through untouched
        IncInput = np.asanyarray(IncInput)
        DecInput = np.asanyarray(DecInput)

        if not self.signed:
            return np.concatenate([IncInput, DecInput])

        # bounds of the stacked vector [ReLU(x), ReLU(-x)]:
        # ReLU(-x) is largest at the lower bound of x and smallest at the upper one
        SignedInc = np.maximum(np.concatenate([IncInput, -DecInput]), 0)
        SignedDec = np.maximum(np.concatenate([DecInput, -IncInput]), 0)

        return np.concatenate([SignedInc, SignedDec])

    def Execute(self, IncInput, DecInput):
        """
        Compute the output of the interval layer for a given input

        Arguments:
            IncInput: upper bound of the input
            DecInput: lower bound of the input
        Returns:
            IncOutput: upper bound of the output, flattened to one dimension
            DecOutput: lower bound of the output, flattened to one dimension
        """
        FullInput = self.IntervalInput(IncInput, DecInput)

        # linear part
        IncOutput = np.ravel(self.WlinInc @ FullInput) + np.ravel(self.blinInc)
        DecOutput = np.ravel(self.WlinDec @ FullInput) + np.ravel(self.blinDec)

        # ReLU part
        if not self.linear:
            IncPot = np.ravel(self.WpotInc @ FullInput) + np.ravel(self.bpotInc)
            DecPot = np.ravel(self.WpotDec @ FullInput) + np.ravel(self.bpotDec)

            # np.maximum avoids the NaN produced by (-inf) * 0 in the masked product
            IncOutput = IncOutput + np.maximum(IncPot, 0)
            DecOutput = DecOutput + np.maximum(DecPot, 0)

        return (IncOutput, DecOutput)
