# -*- coding: utf-8 -*-
"""
Created on Mon Aug 22 13:37:19 2022

@author: Edoardo
"""

import numpy as np
from algorithms.elboher.src.partition import PartitionedLayer

def MergeUpper(W, b):
    """
    Compute an upper bound on a set of affine operations for non-negative inputs.

    For every input x >= 0 (element-wise), the affine map built from the
    column-wise maximum of the weight rows and the maximum bias is >= each
    individual map W[k,:] @ x + b[k].

    Arguments:
        W: weight matrix (row vectors), one row per affine operation
        b: bias vector, one entry per affine operation

    Returns:
        (wUp, bUp): column-wise maximum of W and maximum of b.
    """
    return np.max(W, axis=0), np.max(b)

def MergeLower(W, b):
    """
    Compute a lower bound on a set of affine operations for non-negative inputs.

    For every input x >= 0 (element-wise), the affine map built from the
    column-wise minimum of the weight rows and the minimum bias is <= each
    individual map W[k,:] @ x + b[k].

    Arguments:
        W: weight matrix (row vectors), one row per affine operation
        b: bias vector, one entry per affine operation

    Returns:
        (wLow, bLow): column-wise minimum of W and minimum of b.
    """
    return np.min(W, axis=0), np.min(b)

def ChoosePairToMerge(W, b):
    """
    Brute-force the best pair of neurons to merge (minimum infinity norm).

    The distance between neurons i and j is the infinity norm of the
    difference of their augmented rows [W[i,:], b[i]]. All pairs i < j are
    examined; ties are broken in favour of the first pair in (i, j) order
    (smallest i, then smallest j).

    Arguments:
        W: weight matrix (row vectors), one row per neuron
        b: bias vector, one entry per neuron

    Returns:
        (best_i, best_j) with best_i < best_j.

    Raises:
        ValueError: if fewer than two neurons are given (no pair exists).
    """
    dim = len(b)
    if dim < 2:
        # returning a fake pair such as (0, 0) would make callers merge a
        # neuron with itself and fail later with an obscure error
        raise ValueError(
            "ChoosePairToMerge needs at least two neurons, got %d" % dim)

    # augmented rows [W[k,:], b[k]], built once instead of once per pair
    A = np.column_stack([W, b])

    best_m = None
    best_i = 0
    best_j = 0

    # compare row i against all later rows at once; argmin returns the first
    # minimum, so the (i, j) order and tie-breaking match a nested j-loop
    for i in range(dim - 1):
        m_row = np.max(np.abs(A[i + 1:] - A[i]), axis=1)
        k = int(np.argmin(m_row))
        m = m_row[k]

        if best_m is None or m < best_m:
            best_m = m
            best_i = i
            best_j = i + 1 + k

    return (best_i, best_j)

class MergedLayer(PartitionedLayer):
    """
    Class that holds the parameters of a clustered layer.

    The four neuron partitions (IncPos, IncNeg, DecPos, DecNeg) are each
    clustered down to a user-chosen number of clusters. Inc partitions are
    merged with MergeUpper and Dec partitions with MergeLower.
    """

    def __init__(self, WeightMatrix, BiasVector, nIncPos, nIncNeg, nDecPos, nDecNeg, SignedInput=False):
        """
        Cluster neurons with similar activation patterns together.

        Arguments:
            WeightMatrix, BiasVector, SignedInput: forwarded unchanged to
                PartitionedLayer.__init__
            nIncPos, nIncNeg, nDecPos, nDecNeg: number of clusters to keep
                in the corresponding partition
        """
        PartitionedLayer.__init__(self, WeightMatrix, BiasVector, SignedInput)

        # cluster inc neurons (over-approximation)
        self.WIncPosR, self.bIncPosR = MergedLayer.ClusterNeurons(self.WIncPos,
                                                                  self.bIncPos,
                                                                  MergeUpper,
                                                                  nIncPos)
        self.WIncNegR, self.bIncNegR = MergedLayer.ClusterNeurons(self.WIncNeg,
                                                                  self.bIncNeg,
                                                                  MergeUpper,
                                                                  nIncNeg)

        # cluster dec neurons (under-approximation)
        self.WDecPosR, self.bDecPosR = MergedLayer.ClusterNeurons(self.WDecPos,
                                                                  self.bDecPos,
                                                                  MergeLower,
                                                                  nDecPos)
        self.WDecNegR, self.bDecNegR = MergedLayer.ClusterNeurons(self.WDecNeg,
                                                                  self.bDecNeg,
                                                                  MergeLower,
                                                                  nDecNeg)

    @staticmethod
    def ClusterNeurons(W, b, MergeOp, n):
        """
        Cluster similar neurons together, return over- or under-approximated W and b.

        The closest pair (i, j), i < j, chosen by ChoosePairToMerge is merged
        repeatedly until only n clusters remain. Row i is overwritten with
        MergeOp applied to both rows, and row j is removed. The result has one
        row/entry per ORIGINAL neuron, holding the merged weights/bias of the
        cluster that neuron belongs to.

        Arguments:
            W: weight matrix (row vectors), one row per neuron
            b: bias vector, one entry per neuron
            MergeOp: merge function (e.g. MergeUpper / MergeLower) called on
                     a 2-row weight stack and a 2-entry bias vector, returning
                     (weights, bias)
            n: number of clusters to keep

        Returns:
            (Wrec, brec): cluster weights/biases expanded back to len(b) rows.

        Raises:
            ValueError: if n < 1 while there is at least one neuron.
        """
        nNeurons = len(b)

        # an empty partition has nothing to merge
        if nNeurons == 0:
            return (W.copy(), b.copy())

        if n < 1:
            raise ValueError(
                "cannot cluster %d neurons into %s clusters" % (nNeurons, n))

        # labels[k] is the current cluster (row of Wm/bm) of original neuron k;
        # np.arange keeps an integer dtype so labels is always a valid index
        labels = np.arange(nNeurons)
        Wm = W.copy()
        bm = b.copy()

        # merge neurons until only n clusters are left
        while len(bm) > n:
            i, j = ChoosePairToMerge(Wm, bm)

            # override i with the cluster weights and biases
            Wm[i, :], bm[i] = MergeOp(np.stack([Wm[i, :], Wm[j, :]]),
                                      np.array([bm[i], bm[j]]))

            # erase j
            Wm = np.delete(Wm, j, axis=0)
            bm = np.delete(bm, j)

            # neurons of cluster j now belong to cluster i; clusters after j
            # move up one row (i < j, so cluster i itself does not move)
            labels[labels == j] = i
            labels[labels > j] -= 1

        # restore matrix W by duplicating the cluster weights and biases
        return (Wm[labels, :], bm[labels])

    def Execute(self, inIncPos, inIncNeg, inDecPos, inDecNeg, WithReLU=True):
        """
        Execute the compressed layer with a concrete input.

        The four input parts are combined by self.PartitionedInput (inherited
        from PartitionedLayer) and fed through the clustered weights of each
        partition.

        Returns:
            (outIncPos, outIncNeg, outDecPos, outDecNeg) as flattened arrays;
            negative entries are clamped to 0 when WithReLU is True.
        """
        x = self.PartitionedInput(inIncPos, inIncNeg, inDecPos, inDecNeg)

        partitions = ((self.WIncPosR, self.bIncPosR),
                      (self.WIncNegR, self.bIncNegR),
                      (self.WDecPosR, self.bDecPosR),
                      (self.WDecNegR, self.bDecNegR))

        outputs = []
        for W, b in partitions:
            # ravel both terms so column-shaped results and biases add
            # element-wise as flat vectors
            out = np.ravel(W @ x) + np.ravel(b)
            if WithReLU:
                out[out < 0] = 0
            outputs.append(out)

        return tuple(outputs)
    