# -*- coding: utf-8 -*-
"""
Created on Tue Aug 23 14:11:21 2022

@author: Edoardo
"""

import numpy as np
from random import sample

"""
Create two copies of the input by zeroing the negative and positive entries respectively
Arguments:
    M: input matrix, any numpy array
Returns:
    MatPos: copy of M retaining only the positive elements
    MatNeg: copy of M retaining only the negative elements
"""
def MatrixSignSplit(M):
    
    MatPos = M.copy()
    MatNeg = M.copy()
    MatPos[MatPos < 0] = 0
    MatNeg[MatNeg > 0] = 0
    
    return (MatPos, MatNeg)

"""
Class that holds the parameters of each layer of the interval network
"""
class IntervalLayer():
    
    """
    When negative inputs are expected, pass SignedInput=True to the first layer
    """
    def __init__(self, WeightMatrix, BiasVector, SignedInput=False):
        
        self.signed = SignedInput
        
        # deal with possible negative inputs
        if self.signed:
            self.Wup = IntervalLayer.SignedInputWeights(WeightMatrix)
            self.Wlow = IntervalLayer.SignedInputWeights(WeightMatrix)
        else:
            self.Wup = WeightMatrix
            self.Wlow = WeightMatrix
        self.bup = BiasVector
        self.blow = BiasVector
        
        # initialize cluster info
        self.ccounts = np.ones(len(BiasVector))
        self.nlabels = np.array([i for i in range(len(BiasVector))])
    
    """
    Allow negative inputs with the ReLU split trick
    """
    def SignedInputWeights(WeightMatrix):
        
        SignedWeights = np.concatenate([WeightMatrix, -WeightMatrix], axis=1)
        
        return SignedWeights
    
    """
    Generate a correct interval input
    Take care of possible negative values
    """
    def IntervalInput(self, UpperInput, LowerInput):
        
        if self.signed:
        
            SignedUpper = np.concatenate([UpperInput, -LowerInput])
            SignedLower = np.concatenate([LowerInput, -UpperInput])
            
            SignedUpper[SignedUpper < 0] = 0
            SignedLower[SignedLower < 0] = 0
            
            FullInput = np.concatenate([SignedUpper, SignedLower])
        
        else:
            FullInput = np.concatenate([UpperInput, LowerInput])
        
        return FullInput
    
    """
    Abstract the input weights of the current layer by merging neurons i and j
    """
    def RightAbstraction(Wcur, bcur, i, j, numpyOp):
        
        assert i < j
        
        Wabs = Wcur.copy()
        babs = bcur.copy()
        
        # replace neuron i with the abstraction {i,j}
        Wabs[i,:] = numpyOp(Wcur[i,:], Wcur[j,:])
        babs[i] = numpyOp(bcur[i], bcur[j])
        
        # erase neuron j
        Wabs = np.concatenate([Wabs[:j,:], Wabs[j+1:,:]], axis=0)
        babs = np.concatenate([babs[:j], babs[j+1:]])
        
        return (Wabs, babs)
    
    """
    Abstract the output weights of the current layer by merging neurons i and j
    """
    def LeftAbstraction(Wnxt, i, j, numpyOp):
        
        assert i < j
        
        Wabs = Wnxt.copy()
        
        # replace neuron i with the abstraction {i,j}
        Wabs[:,i] = numpyOp(Wnxt[:,i], Wnxt[:,j])
        
        # erase neuron j
        Wabs = np.concatenate([Wabs[:,:j], Wabs[:,j+1:]], axis=1)
        
        return Wabs
    
    """
    Abstract the input and output weights of the current layer
    by computing an upper and lower bound on neurons i and j
    """
    def MergeNeuronPair(self, i, j, NextLayer):
        
        self.Wup, self.bup = IntervalLayer.RightAbstraction(self.Wup, self.bup, i, j, np.maximum)
        NextLayer.Wup = IntervalLayer.LeftAbstraction(NextLayer.Wup, i, j, np.maximum)
        
        self.Wlow, self.blow = IntervalLayer.RightAbstraction(self.Wlow, self.blow, i, j, np.minimum)
        NextLayer.Wlow = IntervalLayer.LeftAbstraction(NextLayer.Wlow, i, j, np.minimum)
        
        self.ccounts[i] = self.ccounts[i] + self.ccounts[j]
        self.ccounts = np.concatenate([self.ccounts[:j], self.ccounts[j+1:]])
        
        self.nlabels[self.nlabels == j] = i
        self.nlabels[self.nlabels > j] = self.nlabels[self.nlabels > j] - 1
    
    """
    Abstract the input and output weights of the current layer
    by merging neurons at random until only n are left
    """
    def RandomCompress(self, n, NextLayer):
        
        # keep merging clusters until the given target n is reached
        while len(self.ccounts) > n:
            
            i, j = sample(range(len(self.ccounts)), 2)
            self.MergeNeuronPair(min(i,j), max(i,j), NextLayer)
    
    """
    Execute the compressed layer with a concrete input
    """
    def Execute(self, UpperInput, LowerInput, WithReLU=True):
        
        x = self.IntervalInput(UpperInput, LowerInput)
        
        WupPos, WupNeg = MatrixSignSplit(self.Wup)
        WlowPos, WlowNeg = MatrixSignSplit(self.Wlow)
        
        Wup = np.concatenate([WupPos, WupNeg], axis=1)
        Wlow = np.concatenate([WlowNeg, WlowPos], axis=1)
        
        yup = np.ravel(Wup @ x) + np.ravel(self.bup)
        ylow = np.ravel(Wlow @ x) + np.ravel(self.blow)
        
        yup = yup * self.ccounts
        ylow = ylow * self.ccounts
        
        if WithReLU:
            yup[yup < 0] = 0
            ylow[ylow < 0] = 0
        
        return (yup, ylow)
