# -*- coding: utf-8 -*-
"""
Created on Fri Oct 21 14:23:08 2022

@author: Edoardo
"""

import numpy as np

from experiments.motivation.pade_jacobi import compareAbstractions

from experiments.compression.acas_xu_compression import CompressAllAcasXu
from experiments.compression.mnist_fc_compression import CompressAllMnistFc
from experiments.compression.toy_admos_compression import CompressAllToyAdmos

import experiments.distance.acas_xu_distance as acas_dist
import experiments.distance.mnist_fc_distance as mnist_dist
import experiments.distance.toy_admos_distance as toy_dist

def runMotivation():
    """Run the motivational example on the Pade-Jacobi network.

    Prints a banner, then calls ``compareAbstractions`` with the
    64x64 Pade-Jacobi ONNX model, the string
    ``"./experiments/motivation/pade_jacobi_c2_"`` (presumably an output
    path prefix -- TODO confirm in ``compareAbstractions``), and the
    numeric arguments 10, 0.1 and 2.0. Those are forwarded unchanged; their
    meaning is defined by ``compareAbstractions``, which is not visible here.
    """
    banner = ">>> Motivational Example (Pade-Jacobi Network)"
    print(banner)

    model_file = "./networks/pade_jacobi/pade_jacobi_64_64.onnx"
    path_prefix = "./experiments/motivation/pade_jacobi_c2_"
    compareAbstractions(model_file, path_prefix, 10, 0.1, 2.0)

def runCompressionAcasXu():
    """Run the compression experiment for the ACAS Xu networks.

    Prints a banner, then hands three directories to ``CompressAllAcasXu``
    in this order: the ACAS Xu network directory, the ACAS Xu input
    dataset directory, and ``./experiments/compression/``.
    """
    print(">>> Compression Experiment (AcasXu Network)")

    network_dir = "./networks/acas_xu/"
    dataset_dir = "./datasets/acas_xu_inputs/"
    experiment_dir = "./experiments/compression/"
    CompressAllAcasXu(network_dir, dataset_dir, experiment_dir)

def runCompressionMnistFc():
    """Run the compression experiment for the fully-connected MNIST networks.

    Prints a banner, then hands three directories to ``CompressAllMnistFc``
    in this order: the MNIST FC network directory, the MNIST input dataset
    directory, and ``./experiments/compression/``.
    """
    print(">>> Compression Experiment (MnistFc Network)")

    network_dir = "./networks/mnist_fc/"
    dataset_dir = "./datasets/mnist_inputs/"
    experiment_dir = "./experiments/compression/"
    CompressAllMnistFc(network_dir, dataset_dir, experiment_dir)

def runCompressionToyAdmos():
    """Run the compression experiment for the ToyADMOS networks.

    Prints a banner, then hands three directories to ``CompressAllToyAdmos``
    in this order: the ToyADMOS network directory, the ToyADMOS input
    dataset directory, and ``./experiments/compression/``.
    """
    print(">>> Compression Experiment (ToyAdmos Network)")

    network_dir = "./networks/toy_admos/"
    dataset_dir = "./datasets/toy_admos_inputs/"
    experiment_dir = "./experiments/compression/"
    CompressAllToyAdmos(network_dir, dataset_dir, experiment_dir)

def runDistanceAcasXu():
    """Run the distance experiment on a single ACAS Xu network.

    The centroid is the comma-separated sample stored in
    ``acas_xu_input_0.csv``, flattened to a 1-D array. It is passed to
    ``acas_dist.RunComparison`` together with:

    * a sweep of 120 values ``2**k`` for ``k = -20, -19.75, ..., 9.75``;
    * the constant 10000, which is forwarded as-is (its meaning is
      defined by ``RunComparison``, not visible here).
    """
    print(">>> Distance Experiment (AcasXu Network)")

    sample = np.loadtxt("datasets/acas_xu_inputs/acas_xu_input_0.csv",
                        delimiter=",")
    centroid = sample.ravel()

    # Exponent step 0.25 is exactly representable, so arange yields
    # exactly 120 exponents in [-20, 10).
    exponents = np.arange(-20, 10, 0.25)
    scales = 2.0 ** exponents

    acas_dist.RunComparison("networks/acas_xu/ACASXU_run2a_1_1_batch_2000.onnx",
                            "experiments/distance/acas_xu_",
                            centroid,
                            scales,
                            10000)

def runDistanceMnistFc():
    """Run the distance experiment on the 256x6 fully-connected MNIST network.

    The centroid is the comma-separated sample stored in
    ``mnist_input_0.csv``, flattened to a 1-D array. It is passed to
    ``mnist_dist.RunComparison`` together with:

    * a sweep of 120 values ``2**k`` for ``k = -20, -19.75, ..., 9.75``;
    * the constant 10000, which is forwarded as-is (its meaning is
      defined by ``RunComparison``, not visible here).
    """
    print(">>> Distance Experiment (MnistFc Network)")

    sample = np.loadtxt("datasets/mnist_inputs/mnist_input_0.csv",
                        delimiter=",")
    centroid = sample.ravel()

    # Exponent step 0.25 is exactly representable, so arange yields
    # exactly 120 exponents in [-20, 10).
    exponents = np.arange(-20, 10, 0.25)
    scales = 2.0 ** exponents

    mnist_dist.RunComparison("networks/mnist_fc/mnist-net_256x6.onnx",
                             "experiments/distance/mnist_fc_",
                             centroid,
                             scales,
                             10000)

def runDistanceToyAdmos():
    """Run the distance experiment on the ToyADMOS ``ad01_fp32`` network.

    The centroid is the comma-separated sample stored in
    ``toy_admos_input_0.csv``, flattened to a 1-D array. It is passed to
    ``toy_dist.RunComparison`` together with:

    * a sweep of 120 values ``2**k`` for ``k = -20, -19.75, ..., 9.75``;
    * the constant 10000, which is forwarded as-is (its meaning is
      defined by ``RunComparison``, not visible here).
    """
    print(">>> Distance Experiment (ToyAdmos Network)")

    sample = np.loadtxt("datasets/toy_admos_inputs/toy_admos_input_0.csv",
                        delimiter=",")
    centroid = sample.ravel()

    # Exponent step 0.25 is exactly representable, so arange yields
    # exactly 120 exponents in [-20, 10).
    exponents = np.arange(-20, 10, 0.25)
    scales = 2.0 ** exponents

    toy_dist.RunComparison("networks/toy_admos/ad01_fp32.onnx",
                           "experiments/distance/toy_admos_",
                           centroid,
                           scales,
                           10000)

if __name__ == "__main__":

    # Run every experiment, strictly in this order: the motivational
    # example, then the three compression runs, then the three distance
    # runs. Each entry is called with no arguments.
    pipeline = (
        runMotivation,
        runCompressionAcasXu,
        runCompressionMnistFc,
        runCompressionToyAdmos,
        runDistanceAcasXu,
        runDistanceMnistFc,
        runDistanceToyAdmos,
    )
    for step in pipeline:
        step()

    print(">>> Done!")