# -*- coding: utf-8 -*-
"""
Created on Fri Aug 19 11:53:35 2022

@author: Edoardo
"""

import numpy as np
from algorithms.ginnacer.src.interval import IntervalLayer

def InactiveClusters(W, b, xc):
    """
    Greedily cluster a set of affine operations, keeping every merged
    operation non-positive at the centroid.

    Rows are visited in index order. Each row not yet assigned opens a new
    cluster. The cluster then tries to absorb every later unassigned row j:
    the candidate merged operation is the element-wise maximum of the
    cluster's current weights/bias and row j's. The merge is accepted (and
    the cluster's parameters are replaced by the candidate) only when the
    candidate evaluated at ``xc`` is <= 0.

    NOTE(review): if the inputs the layer receives are non-negative, the
    element-wise maximum upper-bounds every member of the cluster. That
    assumption is not checked here - confirm against the caller.

    Arguments:
        W: weight matrix (row vectors); row i is ``W[i]``
        b: bias vector; ``len(b)`` is the number of operations
        xc: centroid vector, used as ``W[i] @ xc``

    Returns:
        (labels, Wh, bh) where
        labels: int array of length ``len(b)``; ``labels[i]`` is the cluster
                index of row i. Clusters are numbered 0..k-1 in the order of
                their lowest-index row.
        Wh:     merged weights, one entry per cluster (k entries)
        bh:     merged biases, one entry per cluster (k entries)
        When ``b`` is empty, Wh/bh are empty slices of W/b, so the trailing
        dimensions (e.g. the column count of W) are preserved.
    """
    dim = len(b)

    # -1 marks a row that has not been assigned to any cluster yet
    labels = np.full(dim, -1, dtype=int)

    Wh = []
    bh = []

    for i in range(dim):
        if labels[i] >= 0:
            continue  # already absorbed by an earlier cluster

        # row i opens a new cluster; ids are contiguous in creation order
        cluster_id = len(Wh)
        labels[i] = cluster_id

        Wm = W[i]
        bm = b[i]

        for j in range(i + 1, dim):
            if labels[j] >= 0:
                continue  # already absorbed

            Wt = np.maximum(Wm, W[j])
            bt = np.maximum(bm, b[j])

            # accept the merge only if it stays inactive at the centroid
            yc = Wt @ xc + bt
            if yc <= 0:
                labels[j] = cluster_id
                Wm = Wt
                bm = bt

        # save the cluster parameters
        Wh.append(Wm)
        bh.append(bm)

    if not Wh:
        # np.array([]) would collapse to shape (0,) and lose the column
        # dimension of W, breaking later matrix products; keep the shape.
        return (labels, np.asarray(W)[:0], np.asarray(b)[:0])

    return (labels, np.array(Wh), np.array(bh))

class MergedLayer(IntervalLayer):
    """
    Interval layer whose Inc ReLU potentials are merged into clusters.

    The Inc potentials (``WpotInc``, ``bpotInc``) are grouped with
    ``InactiveClusters``. A merge is only accepted when the merged potential
    is non-positive at the centroid given to the constructor. A neuron that
    cannot be merged keeps a cluster of its own.
    """

    def __init__(self, WeightMatrix, BiasVector, Centroid, SignedInput=False):
        """
        Build the underlying IntervalLayer, then cluster its Inc potentials.

        Arguments:
            WeightMatrix, BiasVector, Centroid, SignedInput: forwarded to
                ``IntervalLayer.__init__`` together with ``LinearOutput=False``.

        Sets:
            labels, WpotIncH, bpotIncH: output of ``InactiveClusters``
            RpotInc: (neurons x clusters) 0/1 matrix mapping clusters back
                     to the original Inc neurons
            stats_inc_neurons, stats_inc_clusters: sizes before/after merging
        """
        IntervalLayer.__init__(self, WeightMatrix, BiasVector, Centroid, SignedInput, LinearOutput=False)

        # feed the centroid as both the Inc and the Dec input
        centre = IntervalLayer.IntervalInput(self, Centroid, Centroid)

        # cluster the Inc potentials
        self.labels, self.WpotIncH, self.bpotIncH = InactiveClusters(
            self.WpotInc, self.bpotInc, centre)

        n_neurons = len(self.bpotInc)
        n_clusters = len(self.bpotIncH)

        # one-hot rows: row r has a single 1 in the column of its cluster
        self.RpotInc = np.zeros([n_neurons, n_clusters])
        rows = np.arange(len(self.labels))
        self.RpotInc[rows, np.asarray(self.labels, dtype=int)] = 1

        # compression statistics
        self.stats_inc_neurons = n_neurons
        self.stats_inc_clusters = n_clusters

    def Execute(self, IncInput, DecInput):
        """
        Propagate an (Inc, Dec) input pair through the merged layer.

        Returns:
            (IncOutput, DecOutput), both flattened to 1-D arrays.
        """
        x = self.IntervalInput(IncInput, DecInput)

        def affine(weights, bias):
            # flatten both terms so matrix/column-vector inputs give 1-D output
            return np.ravel(weights @ x) + np.ravel(bias)

        inc_linear = affine(self.WlinInc, self.blinInc)
        dec_out = affine(self.WlinDec, self.blinDec)

        # ReLU applied to the merged (cluster-level) Inc potentials only
        pot = affine(self.WpotIncH, self.bpotIncH)
        relu = pot * (pot > 0)

        # scatter each cluster's activation back to its member neurons
        inc_out = inc_linear + self.RpotInc @ relu

        return (inc_out, dec_out)
