# -*- coding: utf-8 -*-
"""
Created on Tue Aug 23 16:28:41 2022

@author: Edoardo
"""

import numpy as np
from algorithms.prabhakar.src.interval import IntervalLayer

def RightAbstractionTest():
    """Test IntervalLayer.RightAbstraction (merging two output rows).

    Merging rows i and j of a (6, 9) weight matrix W and its bias b must
    yield a (5, 9) matrix and a length-5 bias in which:
      * row i bounds rows i and j element-wise from above when the combiner
        is np.maximum, or from below when it is np.minimum;
      * row j is dropped;
      * every other row is copied unchanged, rows after j moving up by one.
    """

    InDim = 9
    OutDim = 6
    W = np.random.normal(1, 2, size=[OutDim, InDim])
    b = np.random.normal(-1, 1, OutDim)

    W24, b24 = IntervalLayer.RightAbstraction(W, b, 2, 4, np.maximum)
    W15, b15 = IntervalLayer.RightAbstraction(W, b, 1, 5, np.minimum)

    # one output row fewer, input dimension unchanged
    assert W24.shape == (OutDim - 1, InDim)
    assert W15.shape == (OutDim - 1, InDim)
    assert len(b24) == OutDim - 1
    assert len(b15) == OutDim - 1

    # Rows not involved in the merge must survive verbatim, otherwise a merge
    # that silently corrupts unrelated neurons would go unnoticed.
    # Merging (2, 4): original rows 0, 1, 3, 5 land at positions 0, 1, 3, 4.
    assert np.allclose(W24[[0, 1, 3, 4], :], W[[0, 1, 3, 5], :])
    assert np.allclose(np.asarray(b24)[[0, 1, 3, 4]], b[[0, 1, 3, 5]])
    # Merging (1, 5): original rows 0, 2, 3, 4 keep their positions.
    assert np.allclose(W15[[0, 2, 3, 4], :], W[[0, 2, 3, 4], :])
    assert np.allclose(np.asarray(b15)[[0, 2, 3, 4]], b[[0, 2, 3, 4]])

    # merged row of W24 is >= both source rows (up to float tolerance)
    assert np.logical_or(W24[2, :] > W[2, :], np.isclose(W24[2, :], W[2, :])).all()
    assert np.logical_or(W24[2, :] > W[4, :], np.isclose(W24[2, :], W[4, :])).all()
    assert np.logical_or(b24[2] > b[2], np.isclose(b24[2], b[2])).all()
    assert np.logical_or(b24[2] > b[4], np.isclose(b24[2], b[4])).all()

    # merged row of W15 is <= both source rows (up to float tolerance)
    assert np.logical_or(W15[1, :] < W[1, :], np.isclose(W15[1, :], W[1, :])).all()
    assert np.logical_or(W15[1, :] < W[5, :], np.isclose(W15[1, :], W[5, :])).all()
    assert np.logical_or(b15[1] < b[1], np.isclose(b15[1], b[1])).all()
    assert np.logical_or(b15[1] < b[5], np.isclose(b15[1], b[5])).all()

def LeftAbstractionTest():
    """Test IntervalLayer.LeftAbstraction (merging two input columns).

    Merging columns i and j of a (6, 9) weight matrix W must yield a (6, 8)
    matrix in which column i bounds columns i and j element-wise from above
    (np.maximum) or below (np.minimum), column j is dropped, and every other
    column is copied unchanged, columns after j moving left by one.
    """

    InDim = 9
    OutDim = 6
    W = np.random.normal(1, 2, size=[OutDim, InDim])

    W24 = IntervalLayer.LeftAbstraction(W, 2, 4, np.maximum)
    W15 = IntervalLayer.LeftAbstraction(W, 1, 5, np.minimum)

    # one input column fewer, output dimension unchanged
    assert W24.shape == (OutDim, InDim - 1)
    assert W15.shape == (OutDim, InDim - 1)

    # Columns before j (other than i) keep their position. Without these
    # checks a merge that corrupted an unrelated column would still pass.
    assert np.allclose(W24[:, [0, 1, 3]], W[:, [0, 1, 3]])
    assert np.allclose(W15[:, [0, 2, 3, 4]], W[:, [0, 2, 3, 4]])

    # columns after j shift left by one
    assert np.allclose(W24[:, 4:], W[:, 5:])
    assert np.allclose(W15[:, 5:], W[:, 6:])

    # merged column of W24 is >= both source columns (up to float tolerance)
    assert np.logical_or(W24[:, 2] > W[:, 2], np.isclose(W24[:, 2], W[:, 2])).all()
    assert np.logical_or(W24[:, 2] > W[:, 4], np.isclose(W24[:, 2], W[:, 4])).all()

    # merged column of W15 is <= both source columns (up to float tolerance)
    assert np.logical_or(W15[:, 1] < W[:, 1], np.isclose(W15[:, 1], W[:, 1])).all()
    assert np.logical_or(W15[:, 1] < W[:, 5], np.isclose(W15[:, 1], W[:, 5])).all()

def MergeNeuronPairTest():
    """Test IntervalLayer.MergeNeuronPair on a (9 -> 6 -> 7) layer pair.

    Merging neurons 0 and 3 of the current layer must shrink the current
    layer's output dimension and the next layer's input dimension from 6 to
    5. The group sizes are recorded in ``ccounts`` and the neuron-to-group
    map in ``nlabels``. Every original neuron's outgoing/incoming weights
    and bias must lie inside the interval of the group it was assigned to.
    """

    InDim = 9
    OutDim = 6
    NextDim = 7
    Wcur = np.random.normal(1, 2, size=[OutDim, InDim])
    Wnxt = np.random.normal(1, 2, size=[NextDim, OutDim])
    bcur = np.random.normal(-1, 1, OutDim)
    bnxt = np.random.normal(-1, 1, NextDim)

    # Hand copies to the layers so the reference arrays used in the
    # containment checks below cannot be altered by the code under test.
    CurrLayer = IntervalLayer(Wcur.copy(), bcur.copy())
    NextLayer = IntervalLayer(Wnxt.copy(), bnxt.copy())

    CurrLayer.MergeNeuronPair(0, 3, NextLayer)

    assert CurrLayer.Wup.shape == (OutDim - 1, InDim)
    assert NextLayer.Wup.shape == (NextDim, OutDim - 1)
    assert len(CurrLayer.bup) == OutDim - 1
    assert len(NextLayer.bup) == NextDim

    assert CurrLayer.Wlow.shape == (OutDim - 1, InDim)
    assert NextLayer.Wlow.shape == (NextDim, OutDim - 1)
    assert len(CurrLayer.blow) == OutDim - 1
    assert len(NextLayer.blow) == NextDim

    # Neuron 3 joins group 0; neurons 4 and 5 move down to groups 3 and 4.
    # np.array_equal pins the whole vector and also accepts plain lists,
    # where `(list == list).all()` would raise AttributeError.
    assert np.array_equal(CurrLayer.ccounts, [2, 1, 1, 1, 1])
    assert np.array_equal(CurrLayer.nlabels, [0, 1, 2, 0, 3, 4])

    # each original neuron lies within the interval of its group, in both the
    # current layer (rows) and the next layer (columns)
    for i, c in enumerate(CurrLayer.nlabels):
        assert (Wcur[i, :] <= CurrLayer.Wup[c, :]).all()
        assert (Wcur[i, :] >= CurrLayer.Wlow[c, :]).all()
        assert bcur[i] <= CurrLayer.bup[c]
        assert bcur[i] >= CurrLayer.blow[c]

        assert (Wnxt[:, i] <= NextLayer.Wup[:, c]).all()
        assert (Wnxt[:, i] >= NextLayer.Wlow[:, c]).all()

def RandomCompressTest(NumRep=50):
    """Test IntervalLayer.RandomCompress shrinking a 6-neuron layer to 5.

    The compression is repeated NumRep times, each time on freshly built
    layers from the same random parameters. Each run checks the resulting
    shapes of both layers and the group bookkeeping. It also checks that
    every original neuron lies inside the interval of the group it was
    assigned to.
    """

    InDim = 9
    OutDim = 6
    NextDim = 7
    Target = 5
    Wcur = np.random.normal(1, 2, size=[OutDim, InDim])
    Wnxt = np.random.normal(1, 2, size=[NextDim, OutDim])
    bcur = np.random.normal(-1, 1, OutDim)
    bnxt = np.random.normal(-1, 1, NextDim)

    for _ in range(NumRep):

        # Fresh copies each iteration: the reference arrays are reused across
        # all NumRep runs, so any in-place change by the layer must not leak
        # into later comparisons.
        CurrLayer = IntervalLayer(Wcur.copy(), bcur.copy())
        NextLayer = IntervalLayer(Wnxt.copy(), bnxt.copy())

        CurrLayer.RandomCompress(Target, NextLayer)

        # current layer: fewer outputs, same inputs
        assert len(CurrLayer.bup) == Target
        assert len(CurrLayer.blow) == Target
        assert CurrLayer.Wup.shape == (Target, InDim)
        assert CurrLayer.Wlow.shape == (Target, InDim)

        # next layer: same outputs, fewer inputs
        assert NextLayer.Wup.shape == (NextDim, Target)
        assert NextLayer.Wlow.shape == (NextDim, Target)
        assert len(NextLayer.bup) == NextDim
        assert len(NextLayer.blow) == NextDim

        # every original neuron is labelled and counted exactly once
        assert len(CurrLayer.nlabels) == OutDim
        assert np.sum(CurrLayer.ccounts) == OutDim

        for i, c in enumerate(CurrLayer.nlabels):
            assert (Wcur[i, :] <= CurrLayer.Wup[c, :]).all()
            assert (Wcur[i, :] >= CurrLayer.Wlow[c, :]).all()
            assert bcur[i] <= CurrLayer.bup[c]
            assert bcur[i] >= CurrLayer.blow[c]

            assert (Wnxt[:, i] <= NextLayer.Wup[:, c]).all()
            assert (Wnxt[:, i] >= NextLayer.Wlow[:, c]).all()

def EqualityTest(NumRep=50, SignedInput=False):
    """Check that an uncompressed IntervalLayer is exact on point inputs.

    No neurons are merged, so executing the layer on the degenerate interval
    [x, x] must give upper and lower outputs that both equal
    ReLU(W @ x + b) computed directly with NumPy.

    NumRep: number of random inputs to try.
    SignedInput: forwarded to the IntervalLayer constructor. When True the
        inputs are drawn from a normal distribution (so they may be
        negative); otherwise they are drawn uniformly from [0, 1).
    """
    n_in, n_out = 9, 15
    weights = np.random.normal(1, 2, size=[n_out, n_in])
    bias = np.random.normal(-1, 1, n_out)

    layer = IntervalLayer(weights, bias, SignedInput)

    def draw_input():
        # input distribution must match the sign assumption given to the layer
        if not SignedInput:
            return np.random.rand(n_in)
        return np.random.normal(1, 2, size=n_in)

    for _ in range(NumRep):
        point = draw_input()
        expected = np.maximum(weights @ point + bias, 0)

        bounds = layer.Execute(point, point, WithReLU=True)

        for side in (bounds[0], bounds[1]):
            assert np.allclose(side, expected)

def SoundnessTest(NumRep=50, SignedInput=False):
    """Check that a compressed layer pair over-approximates the exact network.

    A (9 -> 15) layer followed by a (15 -> 7) layer is compressed to 8 hidden
    neurons with RandomCompress. For random point inputs x, the pair is then
    executed on the interval [x, x]:
      * hidden neurons that were never merged must yield upper == lower;
      * merged hidden neurons must yield upper >= lower;
      * the final [lower, upper] must contain the exact output
        Wnxt @ ReLU(Wcur @ x + bcur) + bnxt, within float tolerance.

    NumRep: number of random inputs to try.
    SignedInput: forwarded to the first layer's constructor. When True the
        inputs are drawn from a normal distribution; otherwise they are
        drawn uniformly from [0, 1).
    """
    n_in, n_hidden, n_out = 9, 15, 7
    # draw order (W1, W2, b1, b2) kept fixed on purpose
    W1 = np.random.normal(1, 2, size=[n_hidden, n_in])
    W2 = np.random.normal(1, 2, size=[n_out, n_hidden])
    b1 = np.random.normal(-1, 1, n_hidden)
    b2 = np.random.normal(-1, 1, n_out)

    hidden = IntervalLayer(W1, b1, SignedInput)
    output = IntervalLayer(W2, b2)

    hidden.RandomCompress(8, output)

    def draw_input():
        if not SignedInput:
            return np.random.rand(n_in)
        return np.random.normal(1, 2, size=n_in)

    for _ in range(NumRep):
        x = draw_input()

        exact = W2 @ np.maximum(W1 @ x + b1, 0) + b2

        # index 0 carries the upper bound, index 1 the lower bound
        hidden_bounds = hidden.Execute(x, x, WithReLU=True)
        out_bounds = output.Execute(*hidden_bounds, WithReLU=False)

        merged = hidden.ccounts > 1
        single = hidden.ccounts == 1
        hi_h, lo_h = hidden_bounds[0], hidden_bounds[1]
        assert np.allclose(hi_h[single], lo_h[single])
        assert (hi_h[merged] >= lo_h[merged]).all()

        hi_o, lo_o = out_bounds[0], out_bounds[1]
        assert ((hi_o > exact) | np.isclose(hi_o, exact)).all()
        assert ((lo_o < exact) | np.isclose(lo_o, exact)).all()

def RunIntervalTests():
    """Run every interval.py test in a fixed order, with start/end banners."""

    print(">>> Testing module interval.py")

    # structural tests on abstraction / merging / compression
    for structural_test in (RightAbstractionTest, LeftAbstractionTest,
                            MergeNeuronPairTest, RandomCompressTest):
        structural_test()

    # numerical tests, each with unsigned then signed inputs
    for numerical_test in (EqualityTest, SoundnessTest):
        for signed in (False, True):
            numerical_test(SignedInput=signed)

    print(">>> Done!")

if __name__ == "__main__":
    # Executed only when this file is run as a script, not when imported.
    RunIntervalTests()
