# -*- coding: utf-8 -*-
"""
Created on Tue Aug 30 13:27:40 2022

@author: Edoardo
"""

import numpy as np
import onnx
from onnx import numpy_helper
from scipy import stats

from algorithms.ginnacer.src.abstract import AbstractNet

def extractToyAdmosWeights(FilePath):
    """Load the ToyAdmos ONNX model and extract its dense-layer weights and biases.

    Parameters
    ----------
    FilePath : str
        Path to the ONNX model file.

    Returns
    -------
    (wList, bList) : tuple of two lists of numpy.ndarray
        ``wList[k]`` is the transpose of the initializer named
        ``WeightsNames[k]``; ``bList[k]`` is the initializer named
        ``BiasNames[k]``. Both lists have 10 entries. The bias names follow
        ``dense``, ``dense_1``, ..., ``dense_9``; the weight names are opaque
        constant-folded identifiers, presumably ordered to pair index-by-index
        with the biases (TODO confirm against the model graph).

    Raises
    ------
    ValueError
        If an expected initializer name is not present in the model.
    """
    # load the ToyAdmos ONNX model
    nn = onnx.load(FilePath)

    WeightsNames = ['const_fold_opt__129',
                    'const_fold_opt__121',
                    'const_fold_opt__126',
                    'const_fold_opt__127',
                    'const_fold_opt__120',
                    'const_fold_opt__124',
                    'const_fold_opt__123',
                    'const_fold_opt__122',
                    'const_fold_opt__125',
                    'const_fold_opt__128']
    BiasNames = ['functional_1/dense/BiasAdd/ReadVariableOp/resource',
                 'functional_1/dense_1/BiasAdd/ReadVariableOp/resource',
                 'functional_1/dense_2/BiasAdd/ReadVariableOp/resource',
                 'functional_1/dense_3/BiasAdd/ReadVariableOp/resource',
                 'functional_1/dense_4/BiasAdd/ReadVariableOp/resource',
                 'functional_1/dense_5/BiasAdd/ReadVariableOp/resource',
                 'functional_1/dense_6/BiasAdd/ReadVariableOp/resource',
                 'functional_1/dense_7/BiasAdd/ReadVariableOp/resource',
                 'functional_1/dense_8/BiasAdd/ReadVariableOp/resource',
                 'functional_1/dense_9/BiasAdd/ReadVariableOp/resource']

    # Index the graph initializers by name once, so each lookup is O(1)
    # instead of a linear list scan. setdefault keeps the FIRST initializer
    # with a given name, matching list.index semantics if names repeat.
    initializers = {}
    for init in nn.graph.initializer:
        initializers.setdefault(init.name, init)

    def _to_array(name):
        # Convert one named initializer to a numpy array, failing loudly
        # (same exception type as list.index) if it is missing.
        if name not in initializers:
            raise ValueError(
                f"initializer {name!r} not found in ONNX model {FilePath!r}")
        return numpy_helper.to_array(initializers[name])

    # Weights are transposed on extraction, presumably to the (out, in)
    # orientation AbstractNet expects -- TODO confirm against AbstractNet.
    wList = [np.transpose(_to_array(name)) for name in WeightsNames]
    # Biases are used as stored.
    bList = [_to_array(name) for name in BiasNames]

    return (wList, bList)

def ToyAdmosClusters(InputPath, Centroid):
    """Build the abstract ToyAdmos network around one input and collect per-layer stats.

    Parameters
    ----------
    InputPath : str
        Path to the ToyAdmos ONNX model; its weights and biases are read with
        ``extractToyAdmosWeights``.
    Centroid : numpy.ndarray
        Input point passed to ``AbstractNet`` as its centroid.

    Returns
    -------
    list
        ``layer.stats_inc_clusters`` for every layer of the abstract network
        except the last, in layer order (presumably an integer cluster count
        per layer -- the caller saves these with an integer format).
    """
    wList, bList = extractToyAdmosWeights(InputPath)

    MyNet = AbstractNet(wList, bList, Centroid)

    # Compression stats: the final layer is deliberately excluded (presumably
    # the output layer, which is not compressed -- TODO confirm in AbstractNet).
    return [layer.stats_inc_clusters for layer in MyNet.layers[:-1]]

def CompressAllToyAdmos(NetPath, CPath, OutPath, NumInputs=1000):
    """Run the ToyAdmos compression over many inputs and save min/mode/max stats.

    For each ``i`` in ``range(NumInputs)``, the file
    ``CPath + "toy_admos_input_<i>.csv"`` is loaded, flattened, and used as
    the centroid for ``ToyAdmosClusters``. The per-layer results are then
    summarised across all inputs.

    Parameters
    ----------
    NetPath : str
        Prefix prepended verbatim to ``"ad01_fp32.onnx"`` (include a trailing
        path separator if it is a directory).
    CPath : str
        Prefix prepended verbatim to each input CSV file name.
    OutPath : str
        Prefix prepended verbatim to ``"toy_admos_compression_sweep.csv"``.
    NumInputs : int, optional
        Number of input files to process. Defaults to 1000, the previously
        hard-coded value.

    Raises
    ------
    ValueError
        If ``NumInputs`` is less than 1; the min/mode/max summary needs at
        least one result.

    Notes
    -----
    The output is a 3-row integer CSV. Row 0 holds the per-layer minimum,
    row 1 the per-layer mode, and row 2 the per-layer maximum.
    """
    if NumInputs < 1:
        raise ValueError(f"NumInputs must be at least 1, got {NumInputs}")

    # ONNX file path (plain concatenation: NetPath is used as a prefix)
    InputPath = NetPath + "ad01_fp32.onnx"

    # run the compression algorithm once per input file
    results = []
    for i in range(NumInputs):
        print(i)  # progress indicator

        CentroidPath = CPath + "toy_admos_input_" + str(i) + ".csv"
        Centroid = np.ravel(np.loadtxt(CentroidPath, delimiter=","))

        # NOTE(review): ToyAdmosClusters re-reads the ONNX model on every
        # iteration. Hoisting the load out of the loop would be much faster,
        # but is only safe if AbstractNet does not mutate its weight/bias
        # arrays -- confirm before changing.
        clusters = ToyAdmosClusters(InputPath, Centroid)
        results.append(clusters)

    # post-process the results: one row per input, one column per layer
    R = np.array(results)
    MinR = np.min(R, axis=0)
    MaxR = np.max(R, axis=0)
    # stats.mode returns shape (1, L) on older SciPy (keepdims behaviour) and
    # (L,) on SciPy >= 1.11; ravel normalises both to (L,).
    ModeR = np.ravel(stats.mode(R, axis=0)[0])

    data = np.stack([MinR, ModeR, MaxR])
    np.savetxt(OutPath + "toy_admos_compression_sweep.csv", data, fmt='%i', delimiter=",")
