# -*- coding: utf-8 -*-
"""
Created on Mon Aug 29 11:29:07 2022

@author: Edoardo
"""

import numpy as np
import onnx
from onnx import numpy_helper
import matplotlib.pyplot as plt

from algorithms.ginnacer.src.abstract import AbstractNet
from algorithms.elboher.src.abstract import ElboherNet
from algorithms.prabhakar.src.abstract import PrabhakarNet
from algorithms.fastlin.src.abstract import FastLinNet

def extractAcasXuWeights(FilePath):
    """Load an ONNX model and split its initializers into weight and bias lists.

    Every graph initializer is converted to a NumPy array and transposed
    (transposing a 1-D tensor is a no-op). The resulting list is sliced as
    ``[1::2]`` for the weights and ``[2::2]`` for the biases, so the
    initializer at index 0 is discarded.

    NOTE(review): this layout (one unused leading initializer followed by
    alternating weight/bias tensors) is assumed from the ACAS Xu files this
    script targets. Confirm it before reusing the function on other models.

    Parameters
    ----------
    FilePath : str or file-like
        Passed unchanged to ``onnx.load``.

    Returns
    -------
    tuple(list of numpy.ndarray, list of numpy.ndarray)
        ``(wList, bList)``: presumably the per-layer weight matrices and
        bias vectors, in graph order.
    """
    model = onnx.load(FilePath)

    # One transposed NumPy array per initializer, in graph order.
    tensors = [np.transpose(numpy_helper.to_array(init))
               for init in model.graph.initializer]

    weights, biases = tensors[1::2], tensors[2::2]
    return weights, biases

def _compressionStats(MyNet):
    """Print the per-layer compression of ``MyNet`` and return its cluster counts.

    Every layer except the last is reported. The returned list holds that
    layer's ``stats_inc_clusters`` value for each reported layer, in layer order.
    """
    ClustList = []
    for i, layer in enumerate(MyNet.layers[:-1]):
        relus = layer.stats_inc_neurons
        clusters = layer.stats_inc_clusters
        print("Layer", i + 1,
              "ReLUs", clusters, "/", relus,
              "Size", clusters / relus)
        ClustList.append(clusters)
    return ClustList


def _surfacePoint(Centroid, Dist, normal):
    """Return a random point whose largest deviation from ``Centroid`` is ``Dist``.

    A Gaussian direction is rescaled so that its largest absolute component
    equals ``Dist``. For a positive scalar ``Dist`` this places the point on
    the surface of the hypercube of half-side ``Dist`` around ``Centroid``.
    ``normal`` is a callable with the ``(loc, scale, size)`` signature of
    ``np.random.normal``.
    """
    Vec = normal(0, 1, len(Centroid))
    return Centroid + Vec * Dist / np.max(np.abs(Vec))


def _outputWidth(pot):
    """Return ``max(pot[-1][0] - pot[-1][-1])`` for an abstraction's ``Execute`` result.

    The first and last entries of the final element of ``pot`` are read as the
    upper and lower output bounds. NOTE(review): this follows the original
    script's convention; confirm it against each abstraction's ``Execute``.
    """
    return np.max(pot[-1][0] - pot[-1][-1])


def RunComparison(InputPath, OutputPath, Centroid, DistList, NumSamples, *, Seed=None):
    """Compare the output-bound width of four abstractions around ``Centroid``.

    The network weights are read from ``InputPath``. A GINNACER
    ``AbstractNet`` is built first, and its per-layer cluster counts are
    printed and reused as the budget for the Elboher and Prabhakar
    abstractions. For every distance in ``DistList``, a FastLin abstraction
    is built from the bounds ``Centroid + Dist`` and ``Centroid - Dist``.
    ``NumSamples`` random points with maximum deviation ``Dist`` from
    ``Centroid`` are then pushed through all four abstractions. For each
    method, the largest observed gap between its upper and lower output
    is recorded.

    Parameters
    ----------
    InputPath : str
        ONNX model file, passed to ``extractAcasXuWeights``.
    OutputPath : str
        Prefix for the result files. The CSVs are written to
        ``OutputPath + "<method>.csv"``. The prefix is concatenated, not
        path-joined, so a directory needs a trailing separator.
    Centroid : numpy.ndarray
        Centre of the sampled region. It must support ``len()`` and
        element-wise ``+`` with each ``Dist``.
    DistList : sequence of float
        Distances from ``Centroid`` at which to sample; used as the plot's x axis.
    NumSamples : int
        Number of random points per distance. Must be at least 1.
    Seed : int or None, keyword-only
        If ``None`` (default), samples come from NumPy's global RNG, as
        before. Otherwise an independent ``np.random.default_rng(Seed)``
        generator is used, so runs are reproducible.

    Side effects
    ------------
    Prints progress to stdout and writes ``ginnacer.csv``, ``elboher.csv``,
    ``prabhakar.csv`` and ``fastlin.csv`` (columns: distance, max width) under
    ``OutputPath``. Draws a labelled log-log plot on the current matplotlib
    figure; the figure is neither shown nor saved.

    Raises
    ------
    ValueError
        If ``NumSamples`` is less than 1. The check runs before any model is
        loaded; previously NumPy raised an opaque error only after all the
        work had been done.
    """
    if NumSamples < 1:
        raise ValueError(f"NumSamples must be at least 1, got {NumSamples!r}")

    # Keep the historical global-RNG stream unless a seed is explicitly requested.
    normal = np.random.normal if Seed is None else np.random.default_rng(Seed).normal

    wList, bList = extractAcasXuWeights(InputPath)

    MyNet = AbstractNet(wList, bList, Centroid)
    ClustList = _compressionStats(MyNet)

    # The baselines get the same per-layer cluster budget as MyNet.
    ElboNet = ElboherNet(wList, bList, ClustList, True)
    PrabNet = PrabhakarNet(wList, bList, ClustList, True)

    # One row per method (in Names order), one column per distance.
    Names = ("ginnacer", "elboher", "prabhakar", "fastlin")
    MaxError = np.zeros((len(Names), len(DistList)))

    for i, Dist in enumerate(DistList):
        print("Distance:", Dist)

        # FastLin depends on the input bounds, so it is rebuilt for each distance.
        FastNet = FastLinNet(wList, bList, Centroid + Dist, Centroid - Dist)

        Errors = np.zeros((len(Names), NumSamples))
        for j in range(NumSamples):
            x = _surfacePoint(Centroid, Dist, normal)

            # Every Execute input argument receives the same concrete point x
            # (presumably a degenerate input box -- confirm against each class).
            Errors[0, j] = _outputWidth(MyNet.Execute(x, x))
            Errors[1, j] = _outputWidth(ElboNet.Execute(x, x, x, x))
            Errors[2, j] = _outputWidth(PrabNet.Execute(x, x))
            Errors[3, j] = _outputWidth(FastNet.Execute(x, x))

        MaxError[:, i] = np.max(Errors, axis=1)

    # Save (distance, max width) pairs, one CSV per method.
    for Name, MaxRow in zip(Names, MaxError):
        Data = np.transpose(np.stack([DistList, MaxRow]))
        np.savetxt(OutputPath + Name + ".csv", Data, delimiter=",")

    # Plot all four curves, labelled so they can be told apart.
    for Name, MaxRow in zip(Names, MaxError):
        plt.loglog(DistList, MaxRow, label=Name)
    plt.legend()
