# -*- coding: utf-8 -*-
"""
Created on Fri Oct 21 13:58:49 2022

@author: Edoardo
"""

import numpy as np
import onnx
from onnx import numpy_helper
import matplotlib.pyplot as plt

from algorithms.ginnacer.src.abstract import AbstractNet
from algorithms.elboher.src.abstract import ElboherNet
from algorithms.prabhakar.src.abstract import PrabhakarNet
from algorithms.fastlin.src.abstract import FastLinNet

def extractPadeJacobiWeights(filepath):
    """
    Read the layer parameters of an ONNX model as numpy arrays.

    Every initializer of the model graph is converted to a numpy array and
    transposed. Initializers at even positions are collected as weights,
    those at odd positions as biases, and both lists are then reversed.

    Parameters
    ----------
    filepath : path of the ONNX file, passed to onnx.load

    Returns
    -------
    (wList, bList) : tuple of two lists of numpy arrays
    """

    model = onnx.load(filepath)

    # one transposed numpy array per stored tensor, in storage order
    tensors = [np.transpose(numpy_helper.to_array(initializer))
               for initializer in model.graph.initializer]

    # the stored order is the reverse of the layer order (the original
    # author attributes this to how Keras exports models), so each
    # interleaved slice is flipped
    weights = tensors[0::2][::-1]
    biases = tensors[1::2][::-1]

    return (weights, biases)

def _affineLayer(W, b, y):
    """Return W applied to y plus the bias b broadcast as a column vector."""
    if W.shape[1] == 1:
        # single-input layer: broadcasting also accepts a scalar or a 1-D
        # batch of scalars, which the matrix product would reject
        return W * y + b[:,np.newaxis]
    return W @ np.array(y) + b[:,np.newaxis]

def executeOriginalNN(wList, bList, x):
    """
    Evaluate a fully-connected ReLU network on the input x.

    Every layer except the last computes an affine map followed by a ReLU;
    the last layer is affine only. All layers, including the last one, go
    through the same affine helper, so a network made of a single layer
    also accepts scalar input.

    Parameters
    ----------
    wList : list of 2-D weight arrays, indexed as (outputs, inputs)
    bList : list of 1-D bias arrays, one per layer
    x     : network input; a scalar or 1-D array is accepted when the first
            layer has a single input, otherwise an array of shape
            (inputs, batch)

    Returns
    -------
    2-D numpy array holding the network output, one column per input sample

    Raises
    ------
    IndexError : if wList is empty or bList is shorter than wList
    """

    y = x
    for i, W in enumerate(wList[:-1]):
        y = _affineLayer(W, bList[i], y)
        # ReLU; y is a freshly allocated array, so the caller's x is untouched
        y[y < 0] = 0

    return _affineLayer(wList[-1], bList[-1], y)

def _saveColumns(path, *columns):
    """Write equally long 1-D arrays to path as the columns of a CSV file."""
    np.savetxt(path, np.transpose(np.stack(columns)), delimiter=",")

def _plotBounds(x, y, bounds, xLimit, yLimit):
    """
    Open a new figure showing the reference curve y in black together with
    one pair of curves (rows 0 and 1) per (values, colour) entry of bounds.
    The axes are clipped to [-xLimit, +xLimit] and [-yLimit, +yLimit].
    """
    plt.figure()
    plt.plot(x, y, 'k')
    for values, colour in bounds:
        plt.plot(x, values[0,:], colour, x, values[1,:], colour)
    ax = plt.gca()
    ax.set_xlim([-xLimit, +xLimit])
    ax.set_ylim([-yLimit, +yLimit])

def compareAbstractions(InputPath, OutputPath, Window, Granularity, Centroid):
    """
    Compare several abstractions of one network over a sampled input range.

    The network stored at InputPath is evaluated on the grid
    np.arange(-Window, +Window, Granularity), together with a GINNACER
    abstraction, the Elboher and Prabhakar abstractions, and two FastLin
    approximations built around Centroid with offsets of 0.5 and 1.0.
    All results are written to CSV files and drawn in two figures.

    Parameters
    ----------
    InputPath   : path of the model file, passed to extractPadeJacobiWeights
    OutputPath  : string prepended verbatim to every output file name
    Window      : half-width of the sampled input interval
    Granularity : step between consecutive input samples
    Centroid    : scalar reshaped to a (1,1) array and passed to AbstractNet
                  and, shifted by the offsets above, to FastLinNet

    Side effects
    ------------
    Writes original.csv, ginnacer.csv, prabhakar.csv, elboher.csv,
    fastlin_05.csv and fastline_10.csv, prints per-layer compression
    statistics and creates two matplotlib figures (plt.show is not called).
    """

    wList, bList = extractPadeJacobiWeights(InputPath)

    # original neural network
    x = np.arange(-Window, +Window, Granularity)
    y = np.zeros(len(x))
    for i, xx in enumerate(x):
        # the network returns a 2-D array; extract its single entry
        # explicitly, because implicit conversion of an ndim > 0 array to
        # a scalar is deprecated since NumPy 1.25
        y[i] = np.asarray(executeOriginalNN(wList, bList, xx)).item()

    # save original x,y on file
    _saveColumns(OutputPath + "original.csv", x, y)

    # GINNACER abstraction
    Centroid = np.array(Centroid).reshape(1,1)
    MyNet = AbstractNet(wList, bList, Centroid)

    # compression stats
    ClustList = []
    for i, layer in enumerate(MyNet.layers[:-1]):
        relus = layer.stats_inc_neurons
        clusters = layer.stats_inc_clusters
        print("Layer", i + 1,
              "ReLUs", clusters, "/", relus,
              "Size", clusters / relus)
        ClustList.append(clusters)

    # other global abstractions, built with the cluster counts found above
    ElboNet = ElboherNet(wList, bList, ClustList, True)
    PrabNet = PrabhakarNet(wList, bList, ClustList, True)

    # train two local FastLin approximations
    # (arguments are Centroid + offset, then Centroid - offset; presumably
    # upper bound first, lower bound second -- confirm against FastLinNet)
    FastNet05 = FastLinNet(wList, bList, Centroid + 0.5, Centroid - 0.5)
    FastNet10 = FastLinNet(wList, bList, Centroid + 1.0, Centroid - 1.0)

    # rows 0 and 1 hold the two values produced by each abstraction
    # (presumably the lower and upper output bound -- confirm)
    MyError = np.zeros([2, len(x)])
    PrabError = np.zeros([2, len(x)])
    ElboError = np.zeros([2, len(x)])
    Fast05Error = np.zeros([2, len(x)])
    Fast10Error = np.zeros([2, len(x)])

    # compute the abstracted output across the whole range
    for i, xx in enumerate(x):
        xx = np.array(xx).reshape(1,1)

        MyError[0,i], MyError[1,i] = MyNet.Execute(xx, xx)[-1]
        PrabError[0,i], PrabError[1,i] = PrabNet.Execute(xx, xx)[-1]
        # the Elboher network yields four values; only the first and the
        # last are kept
        ElboError[0,i], _, _, ElboError[1,i] = ElboNet.Execute(xx, xx, xx, xx)[-1]
        Fast05Error[0,i], Fast05Error[1,i] = FastNet05.Execute(xx, xx)[-1]
        Fast10Error[0,i], Fast10Error[1,i] = FastNet10.Execute(xx, xx)[-1]

    # save abstract x,y on file
    _saveColumns(OutputPath + "ginnacer.csv", x, MyError[0,:], MyError[1,:])
    _saveColumns(OutputPath + "prabhakar.csv", x, PrabError[0,:], PrabError[1,:])
    _saveColumns(OutputPath + "elboher.csv", x, ElboError[0,:], ElboError[1,:])
    _saveColumns(OutputPath + "fastlin_05.csv", x, Fast05Error[0,:], Fast05Error[1,:])
    # NOTE(review): "fastline_10.csv" is spelled differently from
    # "fastlin_05.csv"; kept unchanged because other tools may read this name
    _saveColumns(OutputPath + "fastline_10.csv", x, Fast10Error[0,:], Fast10Error[1,:])

    # plot global abstractions
    _plotBounds(x, y,
                [(MyError, 'b'), (PrabError, 'r'), (ElboError, 'g')],
                Window, 100)

    # plot local abstractions
    _plotBounds(x, y,
                [(MyError, 'b'), (Fast05Error, 'r'), (Fast10Error, 'g')],
                Window, 1)
