from pathlib import Path

from tensorflow import keras
from keras.layers import Dense, Input, Normalization
from keras.models import Model
from keras.utils import to_categorical
import numpy as np
from scipy import io as sio

from sklearn.model_selection import train_test_split

# Load features and labels from the local MATLAB file.
config = sio.loadmat(
    Path("awg_mnist_010324.mat"), squeeze_me=True
)

# Flatten every sample into a 1-D feature vector (axis 0 indexes samples).
# NOTE(review): assumes samples are stored along the first axis of "features"
# — confirm for this .mat file.
X = config["features"]
X = X.reshape((X.shape[0], -1))

# One-hot encode the labels. The raw label values are first mapped onto
# 0..K-1 so the encoding is correct even when the stored labels are not
# already 0-based and contiguous (e.g. MATLAB-style 1..10); for labels that
# are already 0..K-1 this mapping is the identity.
y = config["label"]
label_values, label_index = np.unique(np.ravel(y), return_inverse=True)
classes = label_values.size
y = to_categorical(label_index, num_classes=classes)

# First 60,000 samples form the training split, the remainder the test split;
# shuffle=False preserves the order stored in the file.
X_train, X_test, y_train, y_test = train_test_split(
    X, y, train_size=60_000, shuffle=False
)


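# Superseded: feature scaling is done inside pipeline() by a Normalization
# layer adapted on that call's training features only.
# scaler = StandardScaler()  #  MinMaxScaler(feature_range=(0, 1))
# scaler.fit(X_train)
# X_train = scaler.transform(X_train)
# X_test = scaler.transform(X_test)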


def pipeline(cX_train, cX_test, cy_train, cy_test, L2, *, epochs=20, batch_size=64):
    """Train and evaluate a linear softmax classifier with L2 weight decay.

    The model is a ``Normalization`` layer (statistics adapted on
    ``cX_train`` only) followed by a single ``Dense`` layer that emits logits.
    It is trained with Adam on categorical cross-entropy.

    Args:
        cX_train: 2-D training features, shape (n_samples, n_features).
        cX_test: Evaluation features with the same number of columns.
        cy_train: One-hot training labels, shape (n_samples, n_classes).
        cy_test: One-hot evaluation labels with the same number of columns.
        L2: L2 regularization factor applied to the Dense kernel.
        epochs: Number of training epochs (default 20).
        batch_size: Mini-batch size used during training (default 64).

    Returns:
        Tuple ``(train_loss, train_accuracy, test_loss, test_accuracy)`` as
        reported by ``model.evaluate`` on the training and evaluation data.
    """
    n_features = cX_train.shape[1]
    # Output width comes from the labels actually passed in, not from a
    # module-level constant, so the function stands on its own.
    n_classes = cy_train.shape[-1]

    input_ = Input(shape=(n_features,))

    # Adapt only on the training features so no statistics leak from the
    # evaluation data.
    normalizer = Normalization()
    normalizer.adapt(cX_train)
    normalized = normalizer(input_)

    # No activation: the layer outputs logits, matching from_logits=True below.
    logits = Dense(
        n_classes,
        kernel_regularizer=keras.regularizers.l2(L2),
    )(normalized)
    model = Model(inputs=input_, outputs=logits)

    model.compile(
        optimizer=keras.optimizers.Adam(),
        loss=keras.losses.CategoricalCrossentropy(
            from_logits=True,
        ),
        metrics=["accuracy"],
    )

    model.fit(
        cX_train,
        cy_train,
        epochs=epochs,
        batch_size=batch_size,
    )

    train_loss, train_accuracy = model.evaluate(cX_train, cy_train)
    test_loss, test_accuracy = model.evaluate(cX_test, cy_test)
    return train_loss, train_accuracy, test_loss, test_accuracy


# Select the L2 strength on a validation split carved out of the training
# data: the first 80% is used for fitting, the remaining 20% for scoring.
X_fit, X_val, y_fit, y_val = train_test_split(
    X_train, y_train, train_size=0.8, shuffle=False
)

# Candidate regularization factors, from strongest (1e-1) to weakest (1e-10).
L2s = np.asarray([1e-1, 1e-2, 1e-3, 1e-4, 1e-5, 1e-6, 1e-7, 1e-8, 1e-9, 1e-10])

# NOTE: despite the name, this holds accuracies on the validation split
# (X_val / y_val), not on the held-out test split.
test_accuracy = np.empty_like(L2s)

for idx, strength in enumerate(L2s):
    print(strength)
    scores = pipeline(X_fit, X_val, y_fit, y_val, strength)
    test_accuracy[idx] = scores[-1]
    print(test_accuracy[idx])

# Refit on the full training split with the best-scoring factor (np.argmax
# picks the first, i.e. strongest, on ties) and report on the test split.
best_L2 = L2s[np.argmax(test_accuracy)]
print(pipeline(X_train, X_test, y_train, y_test, best_L2))


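# NOTE: stale snippet. train_loss, train_accuracy and test_loss are never
# assigned at module level, and the module-level test_accuracy holds the
# validation accuracies from the L2 sweep. Unpack the final pipeline(...)
# result before re-enabling this, and check the output filename (this script
# loads awg_mnist_010324.mat, not Fashion-MNIST).
# np.savez(
#     Path.cwd() / "results_fashion_MNIST.npz",
#     train_loss=train_loss,
#     train_accuracy=train_accuracy,
#     test_loss=test_loss,
#     test_accuracy=test_accuracy,
# )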
