import LeNet5
import numpy as np
from numpy import genfromtxt

"""
this main trains le-net-5 with the training set from file


"""


# Side length of one square training image, in pixels.
IMAGE_SIDE = 28
PIXELS_PER_IMAGE = IMAGE_SIDE * IMAGE_SIDE
# One training file is read per digit label 0..9.
NUM_CLASSES = 10


def decode_images(pixels, count):
    """Decode ``count`` consecutive 28x28 images from a flat pixel sequence.

    Image ``k`` occupies ``pixels[k * 784:(k + 1) * 784]``. Within an image,
    pixel ``i * 28 + j`` is placed at row ``27 - j``, column ``i`` (a
    90-degree counter-clockwise rotation of the row-major layout). Values are
    truncated toward zero to integers. Any values after the first
    ``count * 784`` are ignored.

    :param pixels: 1-D sequence of pixel values, at least ``count * 784`` long.
    :param count: number of images to decode.
    :return: integer array of shape ``(count, 28, 28)``.
    :raises ValueError: if there are too few values, or a decoded value is NaN
        (which cannot be converted to an integer pixel).
    """
    pixels = np.asarray(pixels)
    needed = count * PIXELS_PER_IMAGE
    if pixels.size < needed:
        raise ValueError(
            'expected at least {} pixel values for {} images, got {}'.format(
                needed, count, pixels.size))
    used = pixels[:needed]
    if np.isnan(used).any():
        raise ValueError('pixel data contains NaN (malformed input file?)')
    # reshape gives stack[k, i, j] == pixel i * 28 + j of image k; rot90 over
    # the two image axes then moves it to [k, 27 - j, i].
    stack = used.reshape(count, IMAGE_SIDE, IMAGE_SIDE).astype(int)
    return np.rot90(stack, k=1, axes=(1, 2))


def load_training_set(train_size_per_number, run_id,
                      data_dir='../training-dataset-natural'):
    """Load the training images and labels of one run from per-digit files.

    Reads ``<data_dir>/MNIST-<n>/tr<n>_number_<digit>_train_<run_id>.csv`` for
    every digit 0..9 (with ``n = train_size_per_number``). The pixel values
    are taken from the last column of each file.

    :param train_size_per_number: number of images stored per digit file.
    :param run_id: run identifier embedded in the file names.
    :param data_dir: directory holding the ``MNIST-<n>`` sub-directories.
    :return: ``(train_images, train_labels)``. The images are a float64 array
        of shape ``(10 * n, 28, 28)``. The labels are an int array of length
        ``10 * n``, grouped by digit in ascending order.
    :raises ValueError: if a file holds too few values or contains NaN.
    """
    per_digit = []
    for digit in range(NUM_CLASSES):
        path = '{}/MNIST-{}/tr{}_number_{}_train_{}.csv'.format(
            data_dir, train_size_per_number, train_size_per_number,
            digit, run_id)
        # Only the last column is used as pixel data.
        pixels = genfromtxt(path)[:, -1]
        per_digit.append(decode_images(pixels, train_size_per_number))
    train_images = np.concatenate(per_digit).astype(np.float64)
    train_labels = np.repeat(
        np.arange(NUM_CLASSES), train_size_per_number).astype('int')
    return train_images, train_labels


def main():
    """Train and save one LeNet-5 model per experiment configuration.

    Sweeps learning rate x training images per digit x run id (1..30). Each
    trained model is saved under ``../models/`` as an ``.h5`` file whose name
    encodes the configuration.
    """
    # max number of epochs
    epochs = 100000
    for learning_rate in [0.1, 0.01, 0.001, 0.0001]:
        for train_size_per_number in [5, 10, 100]:
            for run_id in range(1, 31):
                train_images, train_labels = load_training_set(
                    train_size_per_number, run_id)
                model = LeNet5.build_lenet(learning_rate)
                val_images, val_labels = LeNet5.get_random_validate_set(
                    100, run_id)
                # NOTE(review): 32 is presumably the batch size; confirm
                # against LeNet5.train_lenet's signature.
                model = LeNet5.train_lenet(model, train_images, train_labels,
                                           epochs, 32, val_images, val_labels)
                file_name = ('../models/learningRate_{}_maxEpochs_{}'
                             '_trainingSizePerNumber_{}_runId_{}.h5').format(
                    learning_rate, epochs, train_size_per_number, run_id)
                model.save(file_name)


if __name__ == '__main__':
    main()