# -*- coding: utf-8 -*-
"""
Created on Mon Aug 29 15:39:02 2022

@author: Edoardo
"""

import numpy as np
import onnx
from onnx import numpy_helper
from scipy import stats

from algorithms.ginnacer.src.abstract import AbstractNet

def extractAcasXuWeights(FilePath):
    """Load an ACAS Xu ONNX model and split its initializers into two lists.

    Every initializer tensor of the ONNX graph is converted to a NumPy array
    and transposed (``np.transpose`` leaves 1-D tensors unchanged). The
    initializer at position 0 is skipped; the ones at odd positions
    (1, 3, 5, ...) form the first list and those at even positions
    (2, 4, 6, ...) form the second, presumably weights and biases layer by
    layer.

    Parameters
    ----------
    FilePath : str
        Path of the ``.onnx`` file to load.

    Returns
    -------
    tuple
        ``(wList, bList)`` - two lists of ``numpy.ndarray``.
    """
    model = onnx.load(FilePath)

    tensors = [
        np.transpose(numpy_helper.to_array(initializer))
        for initializer in model.graph.initializer
    ]

    # NOTE(review): index 0 is deliberately dropped; presumably it is not a
    # layer parameter in these ONNX exports - confirm against the model files.
    weights = tensors[1::2]
    biases = tensors[2::2]
    return (weights, biases)

def AcasXuClusters(InputPath, Centroid):
    """Build the abstract network for one ACAS Xu model and collect cluster counts.

    Parameters
    ----------
    InputPath : str
        Path of the ACAS Xu ``.onnx`` file; its parameters are read with
        ``extractAcasXuWeights``.
    Centroid : numpy.ndarray
        Point handed to ``AbstractNet`` together with the weights and biases
        (assumed to be a flat network input vector - confirm against
        ``AbstractNet``).

    Returns
    -------
    list
        ``layer.stats_inc_clusters`` of every layer of the abstract network
        except the last one, in layer order.
    """
    wList, bList = extractAcasXuWeights(InputPath)

    MyNet = AbstractNet(wList, bList, Centroid)

    # compression stats; the final layer is excluded from the statistics
    return [layer.stats_inc_clusters for layer in MyNet.layers[:-1]]

def CompressAllAcasXu(NetPath, CPath, OutPath):
    """Sweep the cluster statistics over all 45 ACAS Xu networks and 10 centroids.

    For every network file ``ACASXU_run2a_<i>_<j>_batch_2000.onnx``
    (i in 1..5, j in 1..9) and every centroid file ``acas_xu_input_<k>.csv``
    (k in 0..9), ``AcasXuClusters`` is evaluated, giving 450 runs. The run
    results are stacked row-wise and reduced along axis 0 to their minimum,
    mode and maximum. Assuming each run yields one count per layer, that is
    one value per layer. The three rows (min, mode, max) are written as
    integers to ``OutPath + "acas_xu_compression_sweep.csv"``.

    Parameters
    ----------
    NetPath : str
        Prefix prepended verbatim to each ONNX file name (include the trailing
        path separator when it is a directory).
    CPath : str
        Prefix prepended verbatim to each centroid CSV file name.
    OutPath : str
        Prefix prepended verbatim to the output CSV file name.
    """
    # all 45 AcasXu ONNX file paths
    InputPaths = [
        NetPath + "ACASXU_run2a_" + str(i) + "_" + str(j) + "_batch_2000.onnx"
        for i in range(1, 6)
        for j in range(1, 10)
    ]

    # centroids derived from the input regions of the SV-COMP safety properties;
    # they do not depend on the network, so each CSV is read once rather than
    # once per network
    Centroids = [
        np.ravel(np.loadtxt(CPath + "acas_xu_input_" + str(k) + ".csv", delimiter=","))
        for k in range(10)
    ]

    # run the compression algorithm over and over
    results = []
    for InputPath in InputPaths:
        print(InputPath)
        for Centroid in Centroids:
            # hand over a fresh copy, as the original per-call load did, so an
            # in-place change inside AbstractNet cannot leak into later runs
            results.append(AcasXuClusters(InputPath, Centroid.copy()))

    # post-process the results: reduce over the runs (axis 0)
    R = np.array(results)
    MinR = np.min(R, axis=0)
    MaxR = np.max(R, axis=0)
    # ravel: older SciPy versions return the mode with a leading length-1 axis
    ModeR = np.ravel(stats.mode(R, axis=0)[0])

    data = np.stack([MinR, ModeR, MaxR])
    np.savetxt(OutPath + "acas_xu_compression_sweep.csv", data, fmt='%i', delimiter=",")
