# -*- coding: utf-8 -*-
"""
Created on Wed Aug 24 14:30:12 2022

@author: Edoardo
"""

import numpy as np
from algorithms.fastlin.src.bounds import MatrixSignSplit
from algorithms.fastlin.src.bounds import BoundedLayer

def MatrixSignSplitTest():
    """Check that MatrixSignSplit separates a matrix into positive and negative parts.

    A fixed 3x3 integer matrix is split, and each returned part is compared
    element-wise against a hand-computed expected matrix.
    """
    Original = np.array([[1, 2, -3],
                         [0, 0, 1],
                         [-1, -8, 6]])

    # entries > 0 keep their value, everything else becomes 0
    ExpectedPos = np.array([[1, 2, 0],
                            [0, 0, 1],
                            [0, 0, 6]])
    # entries < 0 keep their value, everything else becomes 0
    ExpectedNeg = np.array([[0, 0, -3],
                            [0, 0, 0],
                            [-1, -8, 0]])

    PosPart, NegPart = MatrixSignSplit(Original)

    assert np.all(PosPart == ExpectedPos), "Wrong positive matrix"
    assert np.all(NegPart == ExpectedNeg), "Wrong negative matrix"

def EqualityTest(NumRep=50):
    """Check that a degenerate input interval reproduces the affine map exactly.

    A BoundedLayer is built from random weights and a random input box. For
    NumRep random points x, the layer is executed with identical upper and
    lower inputs (x, x) and no ReLU; both returned bounds must match
    W @ x + b up to floating-point tolerance.

    Parameters
    ----------
    NumRep : int
        Number of random points to test.
    """
    InDim, OutDim = 7, 23

    W = np.random.normal(1, 2, size=[OutDim, InDim])
    b = np.random.normal(-1, 2, OutDim)

    xmax = np.random.normal(1, 2, InDim)
    # subtracting a squared (hence non-negative) perturbation keeps xmin <= xmax
    xmin = xmax - np.square(np.random.normal(0, 0.25, InDim))

    layer = BoundedLayer(W, b, xmax, xmin)

    for _ in range(NumRep):
        # NOTE: the point is drawn independently of [xmin, xmax], so it may
        # lie outside the box the layer was built with.
        Point = np.random.normal(1, 2, size=InDim)
        Exact = W @ Point + b

        Bounds = layer.Execute(Point, Point, WithReLU=False)

        # with a zero-width input interval both bounds collapse onto the exact output
        for Idx in (0, 1):
            assert np.allclose(Bounds[Idx], Exact)

def IntervalTest(NumRep=50):
    """Check that the layer's output bounds are correctly ordered.

    For NumRep random points x, the layer is executed on the input interval
    whose upper end is x and lower end is x - e, with a random non-negative
    width e and no ReLU. The first returned bound (the upper one) must be
    greater than or approximately equal to the second (the lower one) in
    every output coordinate.

    Parameters
    ----------
    NumRep : int
        Number of random input intervals to test.
    """
    InDim = 7
    OutDim = 23
    W = np.random.normal(1, 2, size=[OutDim, InDim])
    b = np.random.normal(-1, 2, OutDim)

    xmax = np.random.normal(1, 2, InDim)
    # squared noise is non-negative, so xmin <= xmax
    xmin = xmax - (np.random.normal(0, 0.25, InDim) ** 2)

    layer = BoundedLayer(W, b, xmax, xmin)

    # test bound ordering (upper >= lower) on random input intervals
    for _ in range(NumRep):
        x = np.random.normal(1, 2, size=InDim)
        # non-negative interval width, so x - e <= x element-wise
        e = np.random.normal(0, 0.1, InDim) ** 2

        ylin = layer.Execute(x, x - e, WithReLU=False)

        # tolerate tiny floating-point inversions via isclose
        Ordered = (ylin[0] > ylin[1]) | np.isclose(ylin[0], ylin[1])
        assert Ordered.all(), (
            "Upper bound below lower bound at outputs "
            f"{np.flatnonzero(~Ordered)}:\n"
            f"upper={ylin[0]}\nlower={ylin[1]}")

def SoundnessTest(NumRep=50):
    """Check that the ReLU output bounds enclose the exact ReLU output.

    A BoundedLayer is built from random weights and a random input box
    [xmin, xmax]. For NumRep points x sampled uniformly inside that box, the
    layer is executed with identical upper and lower inputs (x, x) and
    WithReLU=True. The first returned bound (upper) must be >= max(W @ x + b, 0)
    and the second (lower) must be <= it, up to floating-point tolerance.
    (Presumably the box [xmin, xmax] is what BoundedLayer uses to build its ReLU
    relaxation, which is why x is sampled inside it -- confirm in bounds.py.)

    Diagnostics (input box, sample and bounds) are included in the assertion
    message and are therefore only formatted when a check fails.

    Parameters
    ----------
    NumRep : int
        Number of random points to test.
    """

    def _Report(Which, x, y, ylin):
        # Build a readable failure message; columns match the old debug prints.
        return (f"Unsound {Which} bound.\n"
                f"[xmin, x, xmax]:\n{np.column_stack([xmin, x, xmax])}\n"
                f"[lower, y, upper]:\n{np.column_stack([ylin[1], y, ylin[0]])}")

    InDim = 7
    OutDim = 23
    W = np.random.normal(1, 2, size=[OutDim, InDim])
    b = np.random.normal(-1, 2, OutDim)

    xmax = np.random.normal(-1, 2, InDim)
    # squared noise is non-negative, so xmin <= xmax
    xmin = xmax - (np.random.normal(0, 1, InDim) ** 2)

    layer = BoundedLayer(W, b, xmax, xmin)

    # test soundness of the bounds at random points inside [xmin, xmax]
    for _ in range(NumRep):
        # e in [0, 1) element-wise, so x lies inside the input box
        e = np.random.rand(InDim)
        x = xmin + e * (xmax - xmin)

        # exact ReLU output
        y = np.maximum(W @ x + b, 0)

        ylin = layer.Execute(x, x, WithReLU=True)

        UpperOk = (ylin[0] > y) | np.isclose(ylin[0], y)
        LowerOk = (ylin[1] < y) | np.isclose(ylin[1], y)

        # the message expression is evaluated only if the assertion fails
        assert UpperOk.all(), _Report("upper", x, y, ylin)
        assert LowerOk.all(), _Report("lower", x, y, ylin)

def RunBoundsTests(Seed=None):
    """Run every test for the bounds module in sequence.

    Parameters
    ----------
    Seed : int or None, optional
        If given, NumPy's global random generator is seeded with it before the
        tests run, so a failing randomized run can be reproduced. The default
        (None) leaves the generator untouched, as before.
    """
    print(">>> Testing module bounds.py")

    if Seed is not None:
        np.random.seed(Seed)

    MatrixSignSplitTest()

    EqualityTest()
    IntervalTest()
    SoundnessTest()

    print(">>> Done!")

if __name__ == "__main__":
    RunBoundsTests()
