import numpy as np
import mnist
import random
from keras.utils import np_utils
import tensorflow as tf
from keras.callbacks import EarlyStopping


def build_lenet(learning_rate):
    """Build and compile a LeNet-5 style CNN for 28x28 single-channel images.

    Architecture (in order):
      Conv2D(6, 5x5, tanh) -> AveragePooling2D(2x2, stride 2)
      -> Conv2D(16, 5x5, tanh) -> AveragePooling2D(2x2, stride 2)
      -> Flatten -> Dense(120, tanh) -> Dense(84, tanh) -> Dense(10, softmax)

    The model is compiled with categorical cross-entropy loss (so labels must be
    one-hot encoded over 10 classes), plain SGD without momentum, and the
    'accuracy' metric.

    Args:
        learning_rate: SGD learning rate.

    Returns:
        A compiled ``tf.keras.Sequential`` model with input shape (28, 28, 1).
    """
    input_shape = (28, 28, 1)
    # Sequential API
    model = tf.keras.Sequential()
    # convolutional layer 1
    model.add(tf.keras.layers.Conv2D(filters=6,
                                     kernel_size=(5, 5),
                                     strides=(1, 1),
                                     activation='tanh',
                                     input_shape=input_shape))
    # average pooling layer 1
    model.add(tf.keras.layers.AveragePooling2D(pool_size=(2, 2),
                                               strides=(2, 2)))
    # convolutional layer 2
    model.add(tf.keras.layers.Conv2D(filters=16,
                                     kernel_size=(5, 5),
                                     strides=(1, 1),
                                     activation='tanh'))
    # average pooling layer 2
    model.add(tf.keras.layers.AveragePooling2D(pool_size=(2, 2),
                                               strides=(2, 2)))
    model.add(tf.keras.layers.Flatten())
    # fully connected (input is already 1-D after Flatten; no second Flatten needed)
    model.add(tf.keras.layers.Dense(units=120, activation='tanh'))
    # fully connected
    model.add(tf.keras.layers.Dense(units=84, activation='tanh'))
    # output layer: softmax probabilities over the 10 digit classes
    model.add(tf.keras.layers.Dense(units=10, activation='softmax'))

    # Use the `learning_rate` keyword: the legacy `lr` alias is ignored (with
    # only a warning) by newer tf.keras optimizers, and `decay` is rejected.
    # The original decay=0.0 was a no-op, so dropping it changes nothing.
    model.compile(loss='categorical_crossentropy',
                  optimizer=tf.keras.optimizers.SGD(learning_rate=learning_rate,
                                                    momentum=0.0),
                  metrics=['accuracy'])

    return model

def train_lenet(model, train_images, train_labels, epochs, batch_size_input, x_val, y_val):
    """Train ``model`` on raw MNIST-style images with early stopping.

    The training images are cast to float32, scaled to [0, 1] and reshaped to
    (N, 28, 28, 1). The labels are one-hot encoded over 10 classes. Training
    stops early when ``val_loss`` fails to improve by at least 1e-4 for 30
    epochs, and the best weights seen are restored.

    Args:
        model: a compiled tf.keras model expecting (28, 28, 1) inputs and
            10-way one-hot targets.
        train_images: array of raw pixel images; presumably uint8 in [0, 255]
            with 28*28 pixels each — TODO confirm with callers.
        train_labels: integer class labels in [0, 9].
        epochs: maximum number of training epochs.
        batch_size_input: mini-batch size (an earlier note says 128 was the
            original value).
        x_val, y_val: validation inputs/targets, already preprocessed (the
            images are not rescaled or reshaped here).

    Returns:
        The same ``model`` object, trained in place.
    """
    x_train = train_images.astype('float32')
    # Normalize values to [0, 1]
    x_train /= 255
    # Transform labels to one-hot encoding
    y_train = np_utils.to_categorical(train_labels, 10)
    # Reshape the dataset into a 4D array (N, H, W, C)
    x_train = x_train.reshape(x_train.shape[0], 28, 28, 1)

    # Use the tf.keras callback so it matches the tf.keras model type. Passing a
    # standalone-`keras` callback to a tf.keras model can fail with attribute
    # errors on some version combinations.
    es = tf.keras.callbacks.EarlyStopping(monitor='val_loss',
                                          min_delta=0.0001,
                                          patience=30,
                                          verbose=0, mode='min',
                                          restore_best_weights=True)

    # `use_multiprocessing` was dropped: it only applies to generator/Sequence
    # input, so it had no effect on these in-memory arrays.
    model.fit(x_train, y_train,
              epochs=epochs,
              batch_size=batch_size_input,
              verbose=0,
              validation_data=(x_val, y_val),
              callbacks=[es])
    return model

def evaluate_lenet(model):
    """Return ``model``'s accuracy on the MNIST test split.

    Test images are cast to float32, scaled to [0, 1] and reshaped to
    (N, 28, 28, 1). Labels are one-hot encoded over 10 classes. The second
    value returned by ``model.evaluate`` is returned. This assumes the model
    was compiled with exactly one metric (accuracy).
    """
    images = mnist.test_images()
    labels = mnist.test_labels()

    inputs = images.astype('float32')
    inputs /= 255
    targets = np_utils.to_categorical(labels, 10)
    inputs = inputs.reshape(inputs.shape[0], 28, 28, 1)

    _unused_loss, accuracy = model.evaluate(inputs, targets, verbose=0)
    return accuracy

def get_random_validate_set(size_validation_per_number, seed):
    """Draw a class-balanced random validation set from the MNIST training split.

    For each digit 0-9 in turn, the training rows of that digit are randomly
    permuted. The first ``size_validation_per_number`` of them are kept, or
    all of them if the digit has fewer. The picks are concatenated in digit
    order.

    Args:
        size_validation_per_number: number of images to take per digit.
        seed: seed passed to Python's *global* ``random`` module (this global
            seeding is a side effect callers may rely on).

    Returns:
        (x_val, y_val):
        - ``x_val`` is a float32 array of shape (N, 28, 28, 1) scaled to
          [0, 1].
        - ``y_val`` is the one-hot encoding (10 classes) of the matching
          labels.
    """
    all_train_images, all_train_labels = mnist.train_images(), mnist.train_labels()
    random.seed(seed)

    picked = []
    for num in range(10):
        # Shuffle a plain list of row indices, NOT the image array itself.
        # random.shuffle swaps via `x[i], x[j] = x[j], x[i]`; on a
        # multi-dimensional ndarray those are views, so the swap duplicates
        # rows instead of permuting them. Shuffling a list of the same length
        # consumes the same random draws, so the same positions are selected.
        indices = np.flatnonzero(all_train_labels == num).tolist()
        random.shuffle(indices)
        picked.extend(indices[:size_validation_per_number])

    # Gather all picks once instead of re-copying with np.append per digit.
    picked = np.asarray(picked, dtype=np.intp)
    x_val = all_train_images[picked].astype('float32')
    x_val /= 255
    y_val = np_utils.to_categorical(all_train_labels[picked], 10)
    x_val = x_val.reshape(x_val.shape[0], 28, 28, 1)
    return x_val, y_val

