# -*- coding: utf-8 -*-
"""
Created on Tue Jul 26 13:47:30 2022

@author: Edoardo
"""

import numpy as np
from algorithms.ginnacer.src.center import CenteredLayer
from algorithms.ginnacer.src.interval import MatrixSignSplit
from algorithms.ginnacer.src.interval import IntervalLayer

def MatrixSignSplitTest():
    """Check that MatrixSignSplit separates a matrix by the sign of its entries.

    The first returned matrix must hold the positive entries of M, with zeros
    elsewhere. The second must hold the negative entries of M, with zeros
    elsewhere. np.array_equal is used instead of ``(a == b).all()`` so that
    the shapes are checked too; the element-wise comparison could broadcast a
    wrongly-shaped result and still pass.
    """
    M = np.array([[1, 2, -3],
                  [0, 0, 1],
                  [-1, -8, 6]])

    mPos, mNeg = MatrixSignSplit(M)

    expectedPos = np.array([[1, 2, 0], [0, 0, 1], [0, 0, 6]])
    expectedNeg = np.array([[0, 0, -3], [0, 0, 0], [-1, -8, 0]])

    assert np.array_equal(mPos, expectedPos), "Wrong positive matrix"
    assert np.array_equal(mNeg, expectedNeg), "Wrong negative matrix"
    # The two parts together must reconstruct the original matrix exactly.
    assert np.array_equal(mPos + mNeg, M), "Positive and negative parts do not sum to M"

def EquivalenceTest(NumRep=50, SignedInput=False, LinearOutput=False, Seed=None):
    """Check that a degenerate input interval reproduces the original layer.

    A random IntervalLayer is built. For NumRep random inputs x, both outputs
    of ``layer.Execute(x, x)`` must be allclose to
    ``CenteredLayer.ExecuteOriginal(layer, x, WithReLU=not LinearOutput)``.

    Parameters
    ----------
    NumRep : int
        Number of random inputs to try.
    SignedInput : bool
        If True, inputs are drawn from N(0.5, 1) and may be negative.
        Otherwise they are drawn uniformly from [0, 1). The flag is also
        passed to IntervalLayer.
    LinearOutput : bool
        Passed to IntervalLayer. When True the reference output is computed
        without ReLU.
    Seed : int or None
        If given, all random values come from a private
        np.random.RandomState(Seed), which makes the test reproducible.
        If None (the default), the global NumPy random state is used, so a
        prior np.random.seed(...) call still controls the draws.

    Raises
    ------
    AssertionError
        If either output of Execute differs from the reference output.
    """
    # Both np.random and RandomState expose the same normal/rand API.
    rng = np.random if Seed is None else np.random.RandomState(Seed)

    InDim = 7
    OutDim = 11
    W = rng.normal(1, 2, size=[OutDim, InDim])
    b = rng.normal(-1, 2, OutDim)
    c = rng.rand(InDim)

    layer = IntervalLayer(W, b, c, SignedInput, LinearOutput)

    # Test for equivalence with random inputs.
    for _ in range(NumRep):

        if SignedInput:
            x = rng.normal(0.5, 1, size=InDim)
        else:
            x = rng.rand(InDim)

        yo = CenteredLayer.ExecuteOriginal(layer, x, WithReLU=not LinearOutput)
        # Both inputs are the same point, so both outputs must collapse onto yo.
        yInc, yDec = layer.Execute(x, x)

        assert np.allclose(yo, yInc), "Execute(x, x) first output differs from the original layer"
        assert np.allclose(yo, yDec), "Execute(x, x) second output differs from the original layer"

def SoundnessTest(NumRep=50, SignedInput=False, LinearOutput=False, Seed=None, Delta=0.1):
    """Check that the interval outputs bound the original layer output.

    A random IntervalLayer is built. For NumRep random inputs x, the layer is
    run as ``layer.Execute(x + Delta, x)``. Compared with the reference
    ``yo = CenteredLayer.ExecuteOriginal(layer, x, ...)``:

    - the first output must be >= yo, and
    - the second output must be <= yo,

    element-wise, up to np.isclose tolerance.

    Parameters
    ----------
    NumRep : int
        Number of random inputs to try.
    SignedInput : bool
        If True, inputs are drawn from N(0.5, 1) and may be negative.
        Otherwise they are drawn uniformly from [0, 1). The flag is also
        passed to IntervalLayer.
    LinearOutput : bool
        Passed to IntervalLayer. When True the reference output is computed
        without ReLU.
    Seed : int or None
        If given, all random values come from a private
        np.random.RandomState(Seed), which makes the test reproducible.
        If None (the default), the global NumPy random state is used.
    Delta : float
        Non-negative amount added to x for the first Execute argument.
        The default 0.1 matches the historical behaviour.

    Raises
    ------
    ValueError
        If Delta is negative, because x + Delta would then lie below x.
    AssertionError
        If either output violates its bound.
    """
    if Delta < 0:
        raise ValueError("Delta must be non-negative, got %r" % (Delta,))

    # Both np.random and RandomState expose the same normal/rand API.
    rng = np.random if Seed is None else np.random.RandomState(Seed)

    InDim = 7
    OutDim = 11
    W = rng.normal(1, 2, size=[OutDim, InDim])
    b = rng.normal(-1, 2, OutDim)
    c = rng.rand(InDim)

    layer = IntervalLayer(W, b, c, SignedInput, LinearOutput)

    # Test for soundness with random inputs.
    for _ in range(NumRep):

        if SignedInput:
            x = rng.normal(0.5, 1, size=InDim)
        else:
            x = rng.rand(InDim)

        yo = CenteredLayer.ExecuteOriginal(layer, x, WithReLU=not LinearOutput)
        # NOTE(review): the two arguments are presumably the upper and lower
        # input bounds of the interval [x, x + Delta]; confirm against
        # IntervalLayer.Execute.
        yInc, yDec = layer.Execute(x + Delta, x)

        assert np.logical_or(yInc > yo, np.isclose(yInc, yo)).all(), \
            "First Execute output falls below the original layer output"
        assert np.logical_or(yDec < yo, np.isclose(yDec, yo)).all(), \
            "Second Execute output exceeds the original layer output"

def RunIntervalTests():
    """Run every test of this module and print start/end banners.

    MatrixSignSplitTest runs first. EquivalenceTest then runs for all four
    (SignedInput, LinearOutput) combinations, followed by SoundnessTest for
    the same four combinations.
    """
    print(">>> Testing module interval.py")

    MatrixSignSplitTest()

    # SignedInput varies fastest so the call order is
    # (F, F), (T, F), (F, T), (T, T), the same as the explicit sequence.
    flagCombos = [(signed, linear)
                  for linear in (False, True)
                  for signed in (False, True)]

    for signed, linear in flagCombos:
        EquivalenceTest(SignedInput=signed, LinearOutput=linear)

    for signed, linear in flagCombos:
        SoundnessTest(SignedInput=signed, LinearOutput=linear)

    print(">>> Done!")

if __name__ == "__main__":
    # Run the whole interval test suite when this file is executed as a script.
    RunIntervalTests()
