# -*- coding: utf-8 -*-
"""
Created on Tue Jul 26 14:01:47 2022

@author: Edoardo
"""

import numpy as np
from algorithms.ginnacer.src.center import CenteredLayer
from algorithms.ginnacer.src.merge import InactiveClusters
from algorithms.ginnacer.src.merge import MergedLayer

def InactiveClustersTest(InDim=6, OutDim=15, seed=None):
    """
    Check InactiveClusters on a layer whose neurons are all inactive at xc.

    A random weight matrix W (OutDim x InDim) and a point xc are drawn. The
    bias is set to b = -W @ xc - 1, so W @ xc + b == -1 for every neuron.
    InactiveClusters(W, b, xc) must return one cluster label per neuron,
    together with cluster-level weights Wh and biases bh. The test asserts:

      * each neuron's weights are element-wise <= those of its cluster, and
        its bias is <= the cluster bias (the cluster is an upper bound);
      * every cluster is still inactive at xc, i.e. Wh @ xc + bh <= 0.

    Parameters
    ----------
    InDim : int
        Number of layer inputs (columns of W).
    OutDim : int
        Number of neurons (rows of W).
    seed : int or None
        Seed for the random generator. Pass a fixed value to replay a
        failing run; None (default) draws fresh entropy on each call.

    Raises
    ------
    AssertionError
        If any of the properties above does not hold.
    """
    # Local generator: reproducible when seeded, and leaves the global
    # numpy RNG state untouched.
    rng = np.random.default_rng(seed)

    W = rng.normal(1, 0.5, size=[OutDim, InDim])
    xc = rng.random(InDim)

    # Force every neuron to be inactive at xc: W @ xc + b == -1 < 0.
    b = -W @ xc - 1

    labels, Wh, bh = InactiveClusters(W, b, xc)

    # One label per neuron. Otherwise the per-neuron checks below would
    # silently skip neurons or fail with an opaque IndexError.
    assert len(labels) == OutDim, \
        f"expected {OutDim} labels, got {len(labels)}"

    # Upper-bound correctness: each cluster dominates all of its members.
    for i, cluster in enumerate(labels):
        assert (W[i, :] <= Wh[cluster, :]).all(), \
            f"neuron {i}: weights exceed those of cluster {cluster}"
        assert b[i] <= bh[cluster], \
            f"neuron {i}: bias exceeds that of cluster {cluster}"

    # Cluster inactivity: the merged neurons must also be off at xc.
    yh = Wh @ xc + bh
    assert (yh <= 0).all(), \
        f"clusters active at xc: {np.flatnonzero(yh > 0)}"

def SoundnessTest(NumRep=50, SignedInput=False, InDim=7, OutDim=23,
                  seed=None, verbose=True):
    """
    Check that a MergedLayer soundly bounds the original layer's output.

    A random layer (weights W of shape OutDim x InDim, bias b, point c) is
    wrapped in a MergedLayer. For NumRep random inputs x:
      * yo is the original layer's output with ReLU, computed via
        CenteredLayer.ExecuteOriginal;
      * (yInc, yDec) is the result of layer.Execute(x, x).
    yInc must be >= yo and yDec must be <= yo element-wise, where
    near-equality is accepted via np.isclose.

    Parameters
    ----------
    NumRep : int
        Number of random inputs to test.
    SignedInput : bool
        Passed to MergedLayer. It also selects the input distribution:
        N(0.5, 1) (may be negative) when True, uniform [0, 1) when False.
    InDim : int
        Number of layer inputs.
    OutDim : int
        Number of layer outputs.
    seed : int or None
        Seed for the random generator. Pass a fixed value to replay a
        failing run; None (default) draws fresh entropy on each call.
    verbose : bool
        If True (default), print each input and a table with columns
        (yo, yInc, yDec, layer.labels) for every repetition.

    Raises
    ------
    AssertionError
        If either bound is violated for some input.
    """
    # Local generator: reproducible when seeded, and leaves the global
    # numpy RNG state untouched.
    rng = np.random.default_rng(seed)

    W = rng.normal(1, 2, size=[OutDim, InDim])
    b = rng.normal(-1, 2, OutDim)
    c = rng.random(InDim)

    layer = MergedLayer(W, b, c, SignedInput)

    # Test for soundness with random inputs.
    for _ in range(NumRep):

        if SignedInput:
            x = rng.normal(0.5, 1, size=InDim)
        else:
            x = rng.random(InDim)

        yo = CenteredLayer.ExecuteOriginal(layer, x, WithReLU=True)
        yInc, yDec = layer.Execute(x, x)

        if verbose:
            # Scoped print options: avoids permanently changing numpy's
            # global formatting for the rest of the process.
            with np.printoptions(precision=3, suppress=True):
                print(x)
                print(np.transpose(np.stack([yo, yInc, yDec, layer.labels])))

        assert np.logical_or(yInc > yo, np.isclose(yInc, yo)).all(), \
            f"upper bound violated for x={x}"
        assert np.logical_or(yDec < yo, np.isclose(yDec, yo)).all(), \
            f"lower bound violated for x={x}"

def RunMergeTests():
    """Run every merge.py test, printing a start and an end banner."""
    print(">>> Testing module merge.py")

    InactiveClustersTest()

    # Soundness is checked for both input regimes: unsigned first, then signed.
    for signed in (False, True):
        SoundnessTest(SignedInput=signed)

    print(">>> Done!")


if __name__ == "__main__":
    # Allow running this test module directly as a script.
    RunMergeTests()
