# -*- coding: utf-8 -*-
"""
Created on Tue Aug  9 15:28:45 2022

@author: Edoardo
"""

import numpy as np
from algorithms.ginnacer.src.center import CenteredLayer
       
def EquivalenceTest(NumRep=50, InDim=5, OutDim=6, seed=None):
    """Check that CenteredLayer.ExecuteCentered matches ExecuteOriginal.

    Builds a CenteredLayer from random weights ``W`` (shape OutDim x InDim),
    bias ``b`` (length OutDim) and ``c`` (length InDim). Then, for ``NumRep``
    random inputs ``x``, compares ``ExecuteOriginal(x, WithReLU=True)`` with
    ``ExecuteCentered(x)`` using ``np.allclose``.

    Parameters
    ----------
    NumRep : int
        Number of random inputs to check.
    InDim, OutDim : int
        Input and output dimensions of the randomly generated layer.
    seed : int or None
        If given, all random numbers are drawn from a private
        ``np.random.RandomState(seed)`` so that a failure can be reproduced.
        If None (default), NumPy's global random state is used, exactly as
        before this parameter existed.

    Raises
    ------
    AssertionError
        If, for some input, the two execution paths give outputs that are
        not all close.
    """
    # The module-level np.random functions draw from the global RandomState,
    # so this preserves the original behaviour when no seed is supplied.
    rng = np.random if seed is None else np.random.RandomState(seed)

    W = rng.normal(1, 2, size=[OutDim, InDim])
    b = rng.normal(-1, 2, OutDim)
    c = rng.rand(InDim)

    layer = CenteredLayer(W, b, c)

    # test for equivalence with random inputs
    for rep in range(NumRep):
        x = rng.normal(0.5, 1, size=InDim)

        yo = layer.ExecuteOriginal(x, WithReLU=True)
        yc = layer.ExecuteCentered(x)

        # Explicit raise instead of a bare `assert`: assert statements are
        # stripped under `python -O`, which would make this test pass vacuously.
        if not np.allclose(yo, yc):
            raise AssertionError(
                f"Centered and original outputs differ at repetition {rep}: "
                f"x={x}, original={yo}, centered={yc}"
            )

def RunCenterTests():
    """Run every test for center.py, framed by start and finish banners."""
    print(">>> Testing module center.py")

    # Each entry is a zero-argument test callable, run with its defaults.
    for test in (EquivalenceTest,):
        test()

    print(">>> Done!")

if __name__ == "__main__":
    # Script entry point: run this module's tests when executed directly.
    RunCenterTests()
