# -*- coding: utf-8 -*-
"""
Created on Mon Aug 29 21:22:20 2022

@author: Edoardo
"""

import numpy as np
import onnx
from onnx import numpy_helper
from scipy import stats

from algorithms.ginnacer.src.abstract import AbstractNet

def extractMnistFcWeights(FilePath, LayerNames=None):
    """Load the weight and bias tensors of a fully connected ONNX network.

    Parameters
    ----------
    FilePath : str
        Path to the ``.onnx`` file to read with ``onnx.load``.
    LayerNames : sequence of str, optional
        Initializer name prefixes of the layers to extract, in network order.
        For each prefix ``p`` the initializers ``p + ".weight"`` and
        ``p + ".bias"`` are read. Defaults to
        ``["layers.0", "layers.2", ..., "layers.12"]`` (seven layers).

    Returns
    -------
    tuple (wList, bList)
        Two lists of ``numpy.ndarray`` (converted with
        ``numpy_helper.to_array``), one entry per prefix in ``LayerNames``,
        in that order.

    Raises
    ------
    ValueError
        If a required ``.weight`` or ``.bias`` initializer is not present in
        the model graph.
    """
    if LayerNames is None:
        # only even indices: presumably the odd-indexed modules carry no
        # parameters (e.g. activations) -- confirm against the exporter
        LayerNames = ["layers.0", "layers.2", "layers.4", "layers.6",
                      "layers.8", "layers.10", "layers.12"]

    # load the MNIST fully connected ONNX model
    nn = onnx.load(FilePath)

    # index initializers by name once instead of a linear search per lookup;
    # setdefault keeps the first occurrence, matching list.index semantics
    Initializers = {}
    for init in nn.graph.initializer:
        Initializers.setdefault(init.name, init)

    def _lookup(key):
        """Return the initializer called *key* or raise a descriptive ValueError."""
        try:
            return Initializers[key]
        except KeyError:
            raise ValueError(
                "initializer {!r} not found in {!r}".format(key, FilePath)
            ) from None

    # convert all weights and biases to numpy matrices
    wList = []
    bList = []
    for name in LayerNames:
        wList.append(numpy_helper.to_array(_lookup(name + ".weight")))
        bList.append(numpy_helper.to_array(_lookup(name + ".bias")))

    return (wList, bList)

def MnistFcClusters(InputPath, Centroid):
    """Build the abstract (clustered) network for one centroid and collect its stats.

    Parameters
    ----------
    InputPath : str
        Path to the MNIST fully connected ONNX model; its weights and biases
        are read with ``extractMnistFcWeights``.
    Centroid :
        Input point passed straight to ``AbstractNet`` (the caller in this
        file supplies a flattened numpy array).

    Returns
    -------
    list
        ``layer.stats_inc_clusters`` for every layer of the abstract network
        except the last one, in layer order. Presumably a per-layer cluster
        count -- TODO confirm against ``AbstractNet``.
    """
    wList, bList = extractMnistFcWeights(InputPath)

    MyNet = AbstractNet(wList, bList, Centroid)

    # compression stats; the final layer is skipped (presumably the output
    # layer is not clustered -- confirm in AbstractNet)
    return [layer.stats_inc_clusters for layer in MyNet.layers[:-1]]

def CompressAllMnistFc(NetPath, CPath, OutPath, NumInputs=1000):
    """Run the compression for many input centroids and save min/mode/max stats.

    For each ``i`` in ``range(NumInputs)`` the centroid file
    ``CPath + "mnist_input_<i>.csv"`` is loaded, flattened with ``np.ravel``,
    and passed to ``MnistFcClusters`` together with the model
    ``NetPath + "mnist-net_256x6.onnx"``. The collected results are reduced
    element-wise over all runs to their minimum, mode and maximum. These are
    written as three integer rows (min, mode, max) to
    ``OutPath + "mnist_fc_compression_sweep.csv"``.

    Parameters
    ----------
    NetPath, CPath, OutPath : str
        Path prefixes. They are concatenated directly with the file names, so
        a directory prefix must end with a path separator.
    NumInputs : int, optional
        Number of centroid files to process, indices ``0 .. NumInputs-1``
        (default 1000).

    Raises
    ------
    ValueError
        If ``NumInputs`` is smaller than 1, since there would be nothing to
        summarise.
    """
    if NumInputs < 1:
        raise ValueError("NumInputs must be at least 1, got {}".format(NumInputs))

    # ONNX file path
    InputPath = NetPath + "mnist-net_256x6.onnx"

    # run the compression algorithm once per centroid file
    results = []
    for i in range(NumInputs):
        print(i)  # progress indicator

        CentroidPath = CPath + "mnist_input_" + str(i) + ".csv"
        Centroid = np.ravel(np.loadtxt(CentroidPath, delimiter=","))

        results.append(MnistFcClusters(InputPath, Centroid))

    # post-process the results: axis 0 indexes the runs; presumably every run
    # returns one scalar per layer, so R is 2-D -- confirm
    R = np.array(results)
    MinR = np.min(R, axis=0)
    MaxR = np.max(R, axis=0)
    # np.ravel normalises the mode's shape: older SciPy returns (1, n) while
    # SciPy >= 1.11 (keepdims=False by default) returns (n,)
    ModeR = np.ravel(stats.mode(R)[0])

    data = np.stack([MinR, ModeR, MaxR])
    np.savetxt(OutPath + "mnist_fc_compression_sweep.csv", data, fmt='%i', delimiter=",")
