# -*- coding: utf-8 -*-
"""
Created on Tue Jul 26 14:01:47 2022

@author: Edoardo
"""

import numpy as np
from algorithms.elboher.src.merge import MergeUpper
from algorithms.elboher.src.merge import MergeLower
from algorithms.elboher.src.merge import ChoosePairToMerge
from algorithms.elboher.src.merge import MergedLayer

def MergeUpperTest():
    """Check MergeUpper against NumPy maximum reductions.

    Draws a random weight matrix (6 x 9) and bias vector (length 6), calls
    MergeUpper on them, and asserts that the returned weights match the
    maximum of the matrix over axis 0 and that the returned bias matches
    the maximum of the bias vector.
    """
    n_inputs, n_outputs = 9, 6
    weights = np.random.normal(1, 2, size=[n_outputs, n_inputs])
    biases = np.random.normal(-1, 1, n_outputs)

    merged_w, merged_b = MergeUpper(weights, biases)

    # Reference values are computed directly with array reductions.
    assert np.allclose(merged_w, weights.max(axis=0))
    assert np.allclose(merged_b, biases.max())

def MergeLowerTest():
    """Check MergeLower against NumPy minimum reductions.

    Draws a random weight matrix (6 x 9) and bias vector (length 6), calls
    MergeLower on them, and asserts that the returned weights match the
    minimum of the matrix over axis 0 and that the returned bias matches
    the minimum of the bias vector.
    """
    n_inputs, n_outputs = 9, 6
    weights = np.random.normal(1, 2, size=[n_outputs, n_inputs])
    biases = np.random.normal(-1, 1, n_outputs)

    merged_w, merged_b = MergeLower(weights, biases)

    # Reference values are computed directly with array reductions.
    assert np.allclose(merged_w, weights.min(axis=0))
    assert np.allclose(merged_b, biases.min())

def ChoosePairToMergeTest():
    """Check that ChoosePairToMerge returns a pair with minimal distance.

    The distance between two rows i and j is taken here as the largest
    absolute difference over the weight entries W[i] - W[j] and the bias
    entry b[i] - b[j]. The test asserts that no pair i < j has a distance
    strictly below that of the pair returned by ChoosePairToMerge (up to
    the np.isclose tolerance).
    """
    n_inputs, n_outputs = 9, 6
    weights = np.random.normal(1, 2, size=[n_outputs, n_inputs])
    biases = np.random.normal(-1, 1, n_outputs)

    def pair_distance(i, j):
        # Max-abs difference over the row's weights with its bias appended.
        gap = np.append(weights[i, :] - weights[j, :], biases[i] - biases[j])
        return np.abs(gap).max()

    p, q = ChoosePairToMerge(weights, biases)
    chosen = pair_distance(p, q)

    # triu_indices with k=1 enumerates every pair i < j in row-major order.
    for i, j in zip(*np.triu_indices(n_outputs, k=1)):
        candidate = pair_distance(i, j)
        assert candidate > chosen or np.isclose(candidate, chosen)

def ClusterNeuronsTest(NumRep=50):
    """Check shape, soundness and cluster count of MergedLayer.ClusterNeurons.

    Clusters a random 6 x 9 layer into 3 groups with MergeUpper and into
    4 groups with MergeLower, then asserts that:

    * the returned weights and biases keep the original sizes;
    * upper results are element-wise >= the originals and lower results
      are element-wise <= the originals (up to np.isclose tolerance);
    * the number of distinct weight rows and distinct bias values equals
      the requested number of clusters.

    NOTE(review): NumRep is accepted but not used by this test; the check
    runs once.
    """
    n_inputs, n_outputs = 9, 6
    weights = np.random.normal(1, 2, size=[n_outputs, n_inputs])
    biases = np.random.normal(-1, 1, n_outputs)

    n_upper = int(n_outputs / 2)
    n_lower = n_upper + 1

    up_w, up_b = MergedLayer.ClusterNeurons(weights, biases, MergeUpper, n_upper)
    low_w, low_b = MergedLayer.ClusterNeurons(weights, biases, MergeLower, n_lower)

    # Clustering must not change the layer dimensions.
    for clustered_w in (up_w, low_w):
        assert clustered_w.shape == weights.shape
    for clustered_b in (up_b, low_b):
        assert len(clustered_b) == len(biases)

    # Soundness: upper results dominate the originals, lower results are
    # dominated by them.
    assert np.all((up_w > weights) | np.isclose(up_w, weights))
    assert np.all((low_w < weights) | np.isclose(low_w, weights))
    assert np.all((up_b > biases) | np.isclose(up_b, biases))
    assert np.all((low_b < biases) | np.isclose(low_b, biases))

    # Approximate cluster count: distinct rows / distinct bias values.
    assert np.unique(up_w, axis=0).shape[0] == n_upper
    assert np.unique(low_w, axis=0).shape[0] == n_lower
    assert len(np.unique(up_b)) == n_upper
    assert len(np.unique(low_b)) == n_lower

def EqualityTest(NumRep=50, SignedInput=False):
    """Check that an unreduced MergedLayer reproduces the plain ReLU layer.

    Builds a MergedLayer from a random 23 x 7 layer, passing the output
    dimension for each of the four size arguments, and compares the first
    four outputs of Execute with relu(W @ x + b) on random inputs.

    Parameters
    ----------
    NumRep : int
        Number of random input vectors to test.
    SignedInput : bool
        Forwarded to MergedLayer. When true the inputs are drawn from a
        normal distribution (mean 1, std 2); otherwise they are drawn
        uniformly from [0, 1).
    """
    n_inputs, n_outputs = 7, 23
    weights = np.random.normal(1, 2, size=[n_outputs, n_inputs])
    biases = np.random.normal(-1, 2, n_outputs)

    layer = MergedLayer(weights, biases,
                        n_outputs, n_outputs, n_outputs, n_outputs,
                        SignedInput)

    for _ in range(NumRep):
        x = (np.random.normal(1, 2, size=n_inputs) if SignedInput
             else np.random.rand(n_inputs))

        # Reference output of the original layer followed by ReLU.
        expected = np.maximum(weights @ x + biases, 0)

        outputs = layer.Execute(x, x, x, x, WithReLU=True)

        # Each of the four outputs must coincide with the reference.
        for k in range(4):
            assert np.allclose(outputs[k], expected)

def IncDecTest(NumRep=50, SignedInput=False):
    """Check that paired MergedLayer outputs agree when their sizes match.

    Builds a MergedLayer from a random 23 x 7 layer with size arguments
    (5, 5, 6, 6) and asserts, on random inputs, that outputs 0 and 1 of
    Execute are close to each other and that outputs 2 and 3 are close to
    each other.

    NOTE(review): presumably outputs 0/1 and 2/3 are the two variants of
    the upper and lower abstraction respectively; confirm against
    MergedLayer.

    Parameters
    ----------
    NumRep : int
        Number of random input vectors to test.
    SignedInput : bool
        Forwarded to MergedLayer. When true the inputs are drawn from a
        normal distribution (mean 1, std 2); otherwise they are drawn
        uniformly from [0, 1).
    """
    n_inputs, n_outputs = 7, 23
    weights = np.random.normal(1, 2, size=[n_outputs, n_inputs])
    biases = np.random.normal(-1, 2, n_outputs)

    layer = MergedLayer(weights, biases, 5, 5, 6, 6, SignedInput)

    for _ in range(NumRep):
        x = (np.random.normal(1, 2, size=n_inputs) if SignedInput
             else np.random.rand(n_inputs))

        outputs = layer.Execute(x, x, x, x, WithReLU=True)

        # Outputs built with the same size argument must agree pairwise.
        for first, second in ((0, 1), (2, 3)):
            assert np.allclose(outputs[first], outputs[second])

def SoundnessTest(NumRep=50, SignedInput=False):
    """Check that MergedLayer outputs bound the plain ReLU layer.

    Builds a MergedLayer from a random 23 x 7 layer with size arguments
    (5, 6, 8, 3) and asserts, on random inputs, that outputs 0 and 1 of
    Execute are element-wise >= relu(W @ x + b) and that outputs 2 and 3
    are element-wise <= it (up to np.isclose tolerance).

    Parameters
    ----------
    NumRep : int
        Number of random input vectors to test.
    SignedInput : bool
        Forwarded to MergedLayer. When true the inputs are drawn from a
        normal distribution (mean 1, std 2); otherwise they are drawn
        uniformly from [0, 1).
    """
    n_inputs, n_outputs = 7, 23
    weights = np.random.normal(1, 2, size=[n_outputs, n_inputs])
    biases = np.random.normal(-1, 2, n_outputs)

    layer = MergedLayer(weights, biases, 5, 6, 8, 3, SignedInput)

    for _ in range(NumRep):
        x = (np.random.normal(1, 2, size=n_inputs) if SignedInput
             else np.random.rand(n_inputs))

        # Reference output of the original layer followed by ReLU.
        reference = np.maximum(weights @ x + biases, 0)

        outputs = layer.Execute(x, x, x, x, WithReLU=True)

        # The first two outputs must lie above the reference ...
        for upper in (outputs[0], outputs[1]):
            assert np.all((upper > reference) | np.isclose(upper, reference))
        # ... and the last two below it.
        for lower in (outputs[2], outputs[3]):
            assert np.all((lower < reference) | np.isclose(lower, reference))

def RunMergeTests():
    """Run every test of this module, printing a banner before and after.

    The merge-function and clustering tests run once each; the three
    MergedLayer tests run first with unsigned and then with signed inputs.
    Any failing assertion propagates as an AssertionError.
    """
    print(">>> Testing module merge.py")

    # Tests that take no configuration.
    for basic_test in (MergeUpperTest, MergeLowerTest,
                       ChoosePairToMergeTest, ClusterNeuronsTest):
        basic_test()

    # Layer-level tests, each exercised for both input sign conventions.
    for layer_test in (EqualityTest, IncDecTest, SoundnessTest):
        for signed in (False, True):
            layer_test(SignedInput=signed)

    print(">>> Done!")

# Run the whole test suite when this file is executed as a script.
if __name__ == "__main__":
    RunMergeTests()
