#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
This code accompanies the article:
Jacobs, R. A. & Xu, C. (2019). Can multisensory training aid visual
learning?: A computational investigation. Journal of Vision, in press.

As described in the article, we implemented a beta variational autoencoder
(\beta-VAE) that received both visual and haptic signals regarding the
shapes of objects. The implementation here is a slight variant of the
implementation described by Louis Tiao in his web post titled "Implementing
Variational Autoencoders in Keras: Beyond the Quickstart Tutorial".

In this code, the haptic data are the GraspIt! joint angles for
each Fribble. Recall that GraspIt! has 16 joints and that each Fribble was
grasped 24 times, meaning that there are 384 values. The dimensionality
of these values was then reduced via PCA to 200 components (accounting
for more than 99% of the variance in the haptic values). Each low
dimensional value has been normalized so that it has a mean of zero
and a variance of one. 

The visual data items were created as follows. First, there are two images
of each Fribble, the original image and a flipped (left-right) image. These
images were then presented to VGG16, and we extracted the output of the
convolution base (7 X 7 X 512 = 25088 values). Given the values of the 
convolution base for each image of each Fribble (2 images X 891 Fribbles),
we then did PCA to reduce the dimensionality to 200 (accounting for more
than 97% of the variance in the convolution base values). Each of the values
in the low-dimension space was then normalized to have a mean of zero and
a variance of one.

For each Fribble, there are 2 data items:
-- original image and haptic data
-- flipped image and haptic data

For each data item, the target labels include both the visual and
haptic data.

This program includes:
-- use of stochastic gradient descent (SGD)
-- use of tanh activation function
-- use of 10-fold cross-validation
-- compute tSNE on training items only
-- compute both visual and haptic SSE on test items only
-- compute classifications on test items only
-- compute correlations on test items only
-- use of wider range of beta values
"""

############################################

import numpy as np
from keras import backend as K
from keras.layers import Input, Dense, Lambda, Layer, Add, Multiply
from keras.layers import concatenate
from keras.models import Model
import os
from matplotlib import pyplot as plt
from sklearn.manifold import TSNE
from sklearn.decomposition import PCA
from sklearn.svm import LinearSVC
from skmultilearn.problem_transform import BinaryRelevance

# Length of each modality's input/output vector (the PCA-reduced visual and
# haptic data items described in the module docstring).
original_dim = 200
# Number of units in each modality-specific hidden layer.
intermediate_dim = original_dim * 2
# Number of units in the shared latent layer (z_mu / z_log_var).
latent_dim = intermediate_dim * 2
batch_size = 64
epochs = 2000
# Standard deviation of the Gaussian noise sampled for the latent layer.
epsilon_std = 1.0

# Number of cross-validation folds iterated over by the main loop.
numFolds = 10

# Beta values to train with; each one weights the KL term of the loss.
betaValues = np.array([0.01, 1.0, 2.5, 5.0, 10.0, 20.0])

# Codes stored in the orig_flip_* arrays to mark an item's image orientation.
original = 0
flipped = 1

############################################

def read_input_label_dataitems():
    """Read the visual/haptic input and label arrays from the working directory.

    Returns:
        Tuple ``(visual_input, haptic_input, visual_label, haptic_label)``
        holding the arrays loaded, in that order, from the four ``.npy``
        files listed below.
    """
    base_dir = os.getcwd()
    file_names = ('/inputVisualDataItems_numpy.npy',
                  '/inputHapticDataItems_numpy.npy',
                  '/labelVisualDataItems_numpy.npy',
                  '/labelHapticDataItems_numpy.npy')

    visual_input, haptic_input, visual_label, haptic_label = [
        np.load(base_dir + file_name) for file_name in file_names]

    return visual_input, haptic_input, visual_label, haptic_label


def read_cross_validation():
    """Load the cross-validation assignment array as integers.

    The array is read from ``../CrossValidation_VH_Data`` relative to the
    working directory, rounded to the nearest integer and cast to ``int``.
    ``make_CV_train_test`` treats entries equal to 1 as training items and
    entries equal to -1 as test items.

    Returns:
        Integer array with the same shape as the stored array.
    """
    path = (os.getcwd()
            + '/../CrossValidation_VH_Data/CrossValidation_numpy.npy')
    return np.rint(np.load(path)).astype(int)


def _load_int_labels(file_name):
    """Load ``file_name`` from the working directory as rounded integers."""
    return np.rint(np.load(os.getcwd() + file_name)).astype(int)


def _orientation_codes(item_indices):
    """Return ``original`` for even and ``flipped`` for odd item indices."""
    return np.where(item_indices % 2 == 0, original, flipped).astype(float)


def make_CV_train_test(fold):
    """Split the module-level data into the train and test items of one fold.

    Reads the module-level arrays ``CV``, ``visual_input``, ``haptic_input``,
    ``visual_label`` and ``haptic_label``, plus the family, species, identity
    and parts label files in the working directory.

    Args:
        fold: Row of ``CV`` to use. Entries equal to 1 mark training items
            and entries equal to -1 mark test items.

    Returns:
        A 16-tuple, in this order:
        ``v_input_train, h_input_train, v_label_train, h_label_train,
        v_input_test, h_input_test, v_label_test, h_label_test,
        orig_flip_train, orig_flip_test, labelFamily_train,
        labelFamily_test, labelSpecies_test, labelIdentity_test,
        labelParts_test, test_indices``.
        ``orig_flip_*`` are float arrays holding ``original`` for items at
        even positions in the full data set and ``flipped`` for odd ones.
    """
    # np.where yields indices in ascending order, so every array below is
    # ordered by the item's position in the full data set.
    train_indices = np.where(CV[fold] == 1)[0]
    test_indices = np.where(CV[fold] == -1)[0]

    v_input_train = visual_input[train_indices, :]
    h_input_train = haptic_input[train_indices, :]
    v_label_train = visual_label[train_indices, :]
    h_label_train = haptic_label[train_indices, :]

    v_input_test = visual_input[test_indices, :]
    h_input_test = haptic_input[test_indices, :]
    v_label_test = visual_label[test_indices, :]
    h_label_test = haptic_label[test_indices, :]

    # Orientation depends only on the parity of the item's index, so it can
    # be computed directly from the index arrays.
    orig_flip_train = _orientation_codes(train_indices)
    orig_flip_test = _orientation_codes(test_indices)

    labelFamily = _load_int_labels('/labelFamily_numpy.npy')
    labelFamily_train = labelFamily[train_indices]
    labelFamily_test = labelFamily[test_indices]

    labelSpecies_test = _load_int_labels('/labelSpecies_numpy.npy')[test_indices]
    labelIdentity_test = \
        _load_int_labels('/labelIdentity_numpy.npy')[test_indices]
    labelParts_test = _load_int_labels('/labelParts_numpy.npy')[test_indices]

    return v_input_train, h_input_train, v_label_train, h_label_train, \
           v_input_test, h_input_test, v_label_test, h_label_test, \
           orig_flip_train, orig_flip_test, labelFamily_train, \
           labelFamily_test, labelSpecies_test, \
           labelIdentity_test, labelParts_test, test_indices


def nll(y_true, y_pred):
    """Reconstruction loss: squared error summed over the last axis.

    Despite the name, this is not a Bernoulli negative log likelihood. It is
    a sum of squared errors, which matches a fixed-variance Gaussian negative
    log likelihood up to scale and an additive constant.
    """
    # keras.losses.mean_squared_error would average over the last axis;
    # the sum is what is wanted here.
    residual = y_true - y_pred
    return K.sum(K.square(residual), axis=-1)


class KLDivergenceLayer(Layer):
    """Pass-through layer that registers the beta-weighted KL penalty.

    The layer returns its inputs unchanged; its only effect is to add
    ``beta * mean(KL)`` to the model loss. ``beta`` is not a constructor
    argument: it is read from module scope when the layer is called.
    """

    def __init__(self, *args, **kwargs):
        self.is_placeholder = True
        super().__init__(*args, **kwargs)

    def call(self, inputs):
        mean, log_variance = inputs

        # Per-sample KL divergence between N(mean, exp(log_variance)) and a
        # standard normal, summed over the latent units.
        per_unit = 1 + log_variance - K.square(mean) - K.exp(log_variance)
        kl_per_sample = -.5 * K.sum(per_unit, axis=-1)

        self.add_loss(beta * K.mean(kl_per_sample), inputs=inputs)

        return inputs


def _sse_by_orientation(v_inputs, h_inputs):
    """Return test-set SSE split by modality and image orientation.

    Feeds the given inputs through ``vae`` in a single batched ``predict``
    call and compares the outputs with ``v_label_test`` / ``h_label_test``.

    Returns:
        ``(v_orig, v_flip, h_orig, h_flip)``: visual and haptic squared
        error summed over the original-orientation and the flipped test
        items respectively.
    """
    if v_inputs.shape[0] == 0:
        # Nothing to predict; report zero error rather than handing Keras
        # an empty batch.
        return 0.0, 0.0, 0.0, 0.0

    v_output, h_output = vae.predict([v_inputs, h_inputs])
    v_error = np.sum(np.square(v_label_test - v_output), axis=1)
    h_error = np.sum(np.square(h_label_test - h_output), axis=1)

    is_original = orig_flip_test == original
    is_flipped = ~is_original
    return (np.sum(v_error[is_original]), np.sum(v_error[is_flipped]),
            np.sum(h_error[is_original]), np.sum(h_error[is_flipped]))


def compute_SSE(fold, beta):
    """Record visual and haptic sum-of-squared-errors on the test items.

    Fills one row of the module-level tables ``v_sse`` and ``h_sse`` (in
    place) and saves both tables to the working directory. Columns are:
    0 fold, 1 beta, 2/3 SSE on original/flipped items with visual and haptic
    input, 4/5 SSE on original/flipped items with visual input only (haptic
    input set to zero).

    Besides its arguments, the function reads the module-level ``vae``,
    the test arrays of the current fold, and ``betaIndex`` (the position of
    ``beta`` in ``betaValues``, set by the main loop).

    Args:
        fold: Index of the current cross-validation fold.
        beta: Beta value of the model being evaluated.
    """
    indexValue = (fold * betaValues.shape[0]) + betaIndex
    v_sse[indexValue, 0] = fold
    v_sse[indexValue, 1] = beta
    h_sse[indexValue, 0] = fold
    h_sse[indexValue, 1] = beta

    # Condition 1: both visual and haptic inputs are presented.
    v_orig, v_flip, h_orig, h_flip = \
        _sse_by_orientation(v_input_test, h_input_test)
    v_sse[indexValue, 2] = v_orig
    v_sse[indexValue, 3] = v_flip
    h_sse[indexValue, 2] = h_orig
    h_sse[indexValue, 3] = h_flip

    # Condition 2: visual input only (haptic input set to zero).
    no_haptic = np.zeros((v_input_test.shape[0], original_dim))
    v_orig, v_flip, h_orig, h_flip = \
        _sse_by_orientation(v_input_test, no_haptic)
    v_sse[indexValue, 4] = v_orig
    v_sse[indexValue, 5] = v_flip
    h_sse[indexValue, 4] = h_orig
    h_sse[indexValue, 5] = h_flip

    # The full tables are re-saved after every model so partial results
    # survive an interrupted run.
    cwd = os.getcwd()
    outfile = cwd + '/result_v_sse_numpy'
    np.save(outfile, v_sse)
    outfile = cwd + '/result_h_sse_numpy'
    np.save(outfile, h_sse)

    return


def latentValues(fold, beta):
    """Compute and save latent values for several input conditions.

    Uses the module-level ``encoder`` to produce latent values for:
    the test items with visual and haptic input, the test items with visual
    input only, the training items with visual input only, and all items
    with visual input only. "Visual input only" means the haptic input is
    set to zero. Each result is saved to the working directory under a name
    that includes ``fold`` and ``beta``.

    Args:
        fold: Index of the current cross-validation fold (used in file names).
        beta: Beta value of the current model (used in file names).

    Returns:
        Tuple ``(z_VH_test, z_V_test, z_V_train, z_V_all)``; each array has
        one row per item and ``latent_dim`` columns.
    """
    cwd = os.getcwd()

    def no_haptic(visual):
        # All-zero haptic input with one row per visual item.
        return np.zeros((visual.shape[0], original_dim))

    def encode_and_save(visual, haptic, name_template):
        latent = np.zeros((visual.shape[0], latent_dim))
        if visual.shape[0] > 0:
            # One batched call instead of one predict() per item.
            latent[:] = encoder.predict([visual, haptic])
        np.save(cwd + name_template % (fold, beta), latent)
        return latent

    z_VH_test = encode_and_save(
        v_input_test, h_input_test,
        '/result_latentValues_VH_test_%d_%4.2f_numpy')

    # Next, test network with visual input only (set haptic input to zero).
    z_V_test = encode_and_save(
        v_input_test, no_haptic(v_input_test),
        '/result_latentValues_V_test_%d_%4.2f_numpy')

    z_V_train = encode_and_save(
        v_input_train, no_haptic(v_input_train),
        '/result_latentValues_V_train_%d_%4.2f_numpy')

    z_V_all = encode_and_save(
        visual_input, no_haptic(visual_input),
        '/result_latentValues_V_all_%d_%4.2f_numpy')

    return z_VH_test, z_V_test, z_V_train, z_V_all


def _save_tsne_scatter(embedded, color, title, outfile):
    """Draw one fixed-range scatter plot of the embedding, save and show it."""
    plt.figure()
    plt.scatter(embedded[:, 0], embedded[:, 1], c=color)
    plt.xlim(-90.0, 90.0)
    plt.ylim(-90.0, 90.0)
    plt.xticks([-90.0, 90.0])
    plt.yticks([-90.0, 90.0])
    # gca() returns the axes that were just drawn on. Calling plt.axes()
    # with no arguments adds a new, empty axes on current matplotlib, so
    # the aspect ratio would not be applied to the scatter plot.
    plt.gca().set_aspect(1.0)
    plt.title(title)
    plt.savefig(outfile, bbox_inches='tight')
    plt.show()


def tSNE(test, beta):
    """Perform t-SNE on the latent values of the training items.

    Embeds the module-level ``z_V_train`` (latent values obtained with
    visual input only) in two dimensions and writes two PDF scatter plots
    to the working directory: one colored by ``labelFamily_train`` and one
    colored by ``orig_flip_train`` (original vs flipped).

    Args:
        test: Fold index; the caller passes the current fold number. It is
            used only in the output file names.
        beta: Beta value of the current model (used in titles and file
            names).
    """
    print('--> Entering TSNE: Please wait...')
    z_embedded = TSNE(n_components=2).fit_transform(z_V_train)
    cwd = os.getcwd()

    # Family 0 -> 0.0, family 1 -> 0.5, any other family -> 1.0.
    family = np.asarray(labelFamily_train)
    color = np.where(family == 0, 0.0, np.where(family == 1, 0.5, 1.0))
    _save_tsne_scatter(
        z_embedded, color,
        'Family (beta = %4.2f)' % (beta),
        cwd + '/result_TSNE_Family_%d_%4.2f.pdf' % (test, beta))

    # Original orientation -> 0.0, flipped -> 1.0.
    color = np.where(np.asarray(orig_flip_train) == original, 0.0, 1.0)
    _save_tsne_scatter(
        z_embedded, color,
        'Orientation (beta = %4.2f)' % (beta),
        cwd + '/result_TSNE_OrigFlip_%d_%4.2f.pdf' % (test, beta))

    return
    

def latentClassify(fold, beta):
    """Classify test items from their visual-only latent values.

    Uses the module-level ``z_V_test`` (latent values of the test items with
    haptic input set to zero), reduced to 40 principal components, to
    classify object family, species, identity and parts with linear SVMs.
    Each classifier is fit and scored on the same test-item data. Scores are
    written in place to one row of the module-level ``classScores`` table
    (columns: 0 fold, 1 beta, 2 family, 3 species, 4 identity, 5 parts) and
    the table is saved to the working directory.

    Besides its arguments, the function reads the module-level
    ``betaIndex`` (position of ``beta`` in ``betaValues``) and the test
    label arrays of the current fold.

    Args:
        fold: Index of the current cross-validation fold.
        beta: Beta value of the model being evaluated.
    """
    indexValue = (fold * betaValues.shape[0]) + betaIndex
    classScores[indexValue, 0] = fold
    classScores[indexValue, 1] = beta

    # Reduce the latent values to 40 principal components before classifying.
    pca = PCA(n_components=40)
    z_visual_new = pca.fit_transform(z_V_test)

    single_label_tasks = (
        (2, 'Classify object family: %6.3f', labelFamily_test),
        (3, 'Classify object species: %6.3f', labelSpecies_test),
        (4, 'Classify object identity: %6.3f', labelIdentity_test),
    )
    for column, message, labels in single_label_tasks:
        clf = LinearSVC(random_state=0)
        clf.fit(z_visual_new, labels)
        classScores[indexValue, column] = clf.score(z_visual_new, labels)
        print(message % (classScores[indexValue, column]))

    # Classify object parts (multilabel classification).
    # A part may never appear in the test set, which leaves its label column
    # all zero. Keep only the columns of parts that do appear. All absent
    # columns are dropped in a single step: removing them one at a time
    # would shift the remaining column indices and drop the wrong columns
    # when more than one part is absent.
    present = np.sum(labelParts_test, axis=0) != 0
    tempLabel = labelParts_test[:, present]

    clf = BinaryRelevance(LinearSVC())
    clf.fit(z_visual_new, tempLabel)
    classScores[indexValue, 5] = clf.score(z_visual_new, tempLabel)
    print('Classify object parts: %6.3f' % (classScores[indexValue, 5]))

    outfile = os.getcwd() + '/result_classScores_numpy'
    np.save(outfile, classScores)

    return


def crossModalCorrelation(fold, beta):
    """Score how often a test item's latent values match its partner image.

    Correlates the visual-only latent values of every test item
    (``z_V_test``) with those of every item (``z_V_all``). For each test
    item, checks whether the other-orientation image of the same object
    (the neighbouring index: even ``i`` pairs with ``i + 1``, odd ``i`` with
    ``i - 1``) is among the five most correlated items. The proportion of
    test items for which it is gets written in place to column 2 of one row
    of the module-level ``corrScores`` table (columns 0 and 1 hold fold and
    beta), and the table is saved to the working directory.

    Besides its arguments, the function reads the module-level
    ``test_indices`` and ``betaIndex`` (position of ``beta`` in
    ``betaValues``).

    Args:
        fold: Index of the current cross-validation fold.
        beta: Beta value of the model being evaluated.
    """
    indexValue = (fold * betaValues.shape[0]) + betaIndex
    corrScores[indexValue, 0] = fold
    corrScores[indexValue, 1] = beta

    # Correlate latent values for test items with all items.
    # np.corrcoef stacks the rows of both arrays and returns every pairwise
    # row correlation; the upper-right block holds test-vs-all.
    print('Calculating crossmodal correlation matrix. Please wait...')
    num_test = z_V_test.shape[0]
    corrMatrix = np.corrcoef(z_V_test, z_V_all)[:num_test, num_test:]

    # A test item is trivially most correlated with itself, so its own
    # entry is set to zero.
    item_indices = np.asarray(test_indices)
    rows = np.arange(num_test)
    corrMatrix[rows, item_indices] = 0.0

    # Index of the item that depicts the same object from the other
    # orientation.
    same_object_index = np.where(item_indices % 2 == 0,
                                 item_indices + 1, item_indices - 1)

    # Smallest of the five highest correlations in each row; the partner
    # counts as a match when its correlation reaches that value.
    top_five = np.partition(corrMatrix, -5, axis=1)[:, -5:]
    min_value = np.amin(top_five, axis=1)
    correct = int(np.count_nonzero(
        corrMatrix[rows, same_object_index] >= min_value))

    temp = correct / num_test
    corrScores[indexValue, 2] = temp
    print('Proportion where visual test item matches same object: %5.2f'
          % (temp))

    cwd = os.getcwd()
    outfile = cwd + '/result_corrScores_numpy'
    np.save(outfile, corrScores)


############################################

# Load the full data set once. These module-level arrays are read directly
# by make_CV_train_test() and latentValues().
visual_input, haptic_input, visual_label, haptic_label = \
    read_input_label_dataitems()

CV = read_cross_validation()

# Result tables with one row per (fold, beta) pair, at row index
# fold * len(betaValues) + betaIndex. Columns 0 and 1 hold fold and beta.
# v_sse / h_sse: columns 2-5 are filled by compute_SSE().
# classScores:   columns 2-5 are filled by latentClassify().
# corrScores:    column 2 is filled by crossModalCorrelation().
v_sse = np.zeros((numFolds * betaValues.shape[0], 6))
h_sse = np.zeros((numFolds * betaValues.shape[0], 6))
classScores = np.zeros((numFolds * betaValues.shape[0], 6))
corrScores = np.zeros((numFolds * betaValues.shape[0], 3))

# The names assigned in these loops (fold, betaIndex, beta, the train/test
# arrays, vae, encoder, cwd and the z_* arrays) are module-level globals that
# the functions above read directly, so they must be assigned before the
# corresponding function is called.
for fold in range(numFolds):
    v_input_train, h_input_train, v_label_train, h_label_train, \
    v_input_test, h_input_test, v_label_test, h_label_test, \
    orig_flip_train, orig_flip_test, labelFamily_train, \
    labelFamily_test, labelSpecies_test, \
    labelIdentity_test, labelParts_test, test_indices = \
    make_CV_train_test(fold)

    # NOTE(review): a fresh model is built for every (fold, beta) pair and
    # the backend session is never reset; if memory or graph size grows over
    # the run, consider clearing the backend session here -- TODO confirm.
    for betaIndex, beta in enumerate(betaValues):
        # Main code that defines the VAE
    
        # Encoder network: one tanh hidden layer per modality, concatenated
        # and mapped to the mean and log-variance of the latent layer.
        
        encoder_x_visual = Input(shape=(original_dim,), name='x_visual')
        encoder_h_visual = Dense(intermediate_dim, activation='tanh')\
            (encoder_x_visual)
        
        encoder_x_haptic = Input(shape=(original_dim,), name='x_haptic')
        encoder_h_haptic = Dense(intermediate_dim, activation='tanh')\
            (encoder_x_haptic)
        
        concatenated = concatenate([encoder_h_visual, encoder_h_haptic])
        
        z_mu = Dense(latent_dim)(concatenated)
        z_log_var = Dense(latent_dim)(concatenated)
        
        # KLDivergenceLayer passes its inputs through and adds the
        # beta-weighted KL term to the loss; it reads the global `beta`.
        z_mu, z_log_var = KLDivergenceLayer()([z_mu, z_log_var])
        z_sigma = Lambda(lambda t: K.exp(.5*t))(z_log_var)
        
        # Noise is supplied as a tensor-backed input whose batch size follows
        # the visual input, then scaled and shifted: z = z_mu + z_sigma * eps.
        eps = Input(tensor=K.random_normal(stddev=epsilon_std,
                                           shape=(K.shape(encoder_x_visual)[0],
                                                  latent_dim)))
        z_eps = Multiply()([z_sigma, eps])
        z = Add()([z_mu, z_eps])
        
        # Decoder network: a separate tanh hidden layer and linear output
        # layer per modality, both fed by the shared latent sample z.
        # NOTE(review): input_dim presumably has no effect on a layer that is
        # called on an existing tensor -- confirm.
        
        decoder_h_visual = Dense(intermediate_dim, input_dim=latent_dim,
                                 activation='tanh')(z)
        decoder_y_visual = Dense(original_dim, name='y_visual')(decoder_h_visual)
        
        decoder_h_haptic = Dense(intermediate_dim, input_dim=latent_dim,
                                 activation='tanh')(z)
        decoder_y_haptic = Dense(original_dim, name='y_haptic')(decoder_h_haptic)
        
        # Define model, compile, and fit
        
        vae = Model(inputs=[encoder_x_visual, encoder_x_haptic, eps],
                    outputs=[decoder_y_visual, decoder_y_haptic])
    
        # Both reconstruction losses are weighted equally; the KL term is
        # added separately by KLDivergenceLayer.
        vae.compile(optimizer='sgd',
                    loss={'y_visual': nll, 'y_haptic': nll},
                    loss_weights={'y_visual': 1.0, 'y_haptic': 1.0})
    
        vae.fit({'x_visual': v_input_train, 'x_haptic': h_input_train},
                {'y_visual': v_label_train, 'y_haptic': h_label_train},
                shuffle=True,
                epochs=epochs,
                batch_size=batch_size)
        
        # Save VAE to file
        # `cwd` set here is also read as a global by tSNE() and
        # latentClassify().
        cwd = os.getcwd()
        outfile = cwd + '/result_VAE_%d_%4.2f.h5' % (fold, beta)
        vae.save(outfile)
        
        print('          *** Fold = %d, Beta = %4.2f ***' % (fold, beta))
    
        # Compute SSE measures
        compute_SSE(fold, beta)
        
        # Compute means of latent variables
        # The encoder shares the trained layers and outputs z_mu, so the
        # latent values contain no sampling noise.
        encoder = Model(inputs=[encoder_x_visual, encoder_x_haptic],
                        outputs=z_mu)
        z_VH_test, z_V_test, z_V_train, z_V_all = latentValues(fold, beta)
        
        # Perform tSNE on latent values (first fold only).
        if fold == 0:
            tSNE(fold, beta)
    
        # Classification (using linearSVC) based on latent variables.
        latentClassify(fold, beta)
        
        # Compute crossmodal correlation matrix.
        crossModalCorrelation(fold, beta)
