import os, sys

import numpy as np
from numpy.random import seed, rand

import copy

from keras.preprocessing import image
from keras.applications import resnet50, vgg16, mobilenetv2, inception_v3, nasnet, densenet, inception_resnet_v2
from keras.layers import Reshape, Multiply
from keras.models import Model
from keras import backend as K
from keras.engine.topology import Layer

import tensorflow as tf

def setup_texture(arch='vgg',
        layer='block3_conv3', 
        cbp_size = 1024, 
        cbp_dir = '/data/networks/cbp',
        img_size = 256,
        batch_mode=False,
        batch_size=0,
        verbose=False):
    """Build a texture-descriptor extractor on top of a pretrained CNN.

    The feature map of ``layer`` is reshaped to (positions, channels) and fed
    to a ``CBP`` layer (or ``CBP_batch`` when ``batch_mode`` is true).

    Args:
        arch: Backbone key; one of 'mobilenet', 'vgg', 'resnet50',
            'inceptionv3', 'nasnet', 'densenet', 'inceptionresnetv2'.
        layer: Name of the backbone layer whose output is pooled.
        cbp_size: Output dimension of the pooling layer.
        cbp_dir: Directory where the random projection matrices are
            loaded from / cached to.
        img_size: Side length of the square RGB input.
        batch_mode: Use ``CBP_batch`` (weights replicated ``batch_size``
            times) instead of ``CBP``.
        batch_size: Number of samples per call; only used in batch mode.
        verbose: When true, calls ``display_config`` with the configuration.

    Returns:
        Tuple ``(preprocess, cbp_func)``: the backbone's ``preprocess_input``
        function and a backend function mapping ``[input_batch]`` to
        ``[descriptor]``.

    Raises:
        KeyError: If ``arch`` is not a supported backbone key.
        ValueError: If ``batch_mode`` is true and ``batch_size`` is < 1.
    """
    # One table keeps each constructor next to the module that provides its
    # preprocess_input, so the two can never get out of sync.
    backbones = {
        'mobilenet': (mobilenetv2, mobilenetv2.MobileNetV2),
        'vgg': (vgg16, vgg16.VGG16),
        'resnet50': (resnet50, resnet50.ResNet50),
        'inceptionv3': (inception_v3, inception_v3.InceptionV3),
        'nasnet': (nasnet, nasnet.NASNetLarge),
        'densenet': (densenet, densenet.DenseNet201),
        'inceptionresnetv2': (inception_resnet_v2,
                              inception_resnet_v2.InceptionResNetV2),
    }

    # Fail fast, before the backbone and its weights are instantiated.
    if arch not in backbones:
        raise KeyError("unknown arch %r; expected one of %s"
                       % (arch, sorted(backbones)))
    if batch_mode and batch_size < 1:
        raise ValueError("batch_mode requires batch_size >= 1, got %r"
                         % (batch_size,))

    prep_module, build_cnn = backbones[arch]

    cnn = build_cnn(include_top=False,
                    weights="imagenet",
                    input_shape=(img_size, img_size, 3))

    # preprocess function
    preprocess = prep_module.preprocess_input

    # define model
    output = cnn.get_layer(layer).output

    s1 = int(output.shape[1])
    s2 = int(output.shape[2])
    in_dim = int(output.shape[3])
    out_dim = cbp_size

    # Collapse the two spatial axes: (s1, s2, C) -> (s1*s2, C).
    output = Reshape((s1 * s2, in_dim), name="reshape")(output)
    if batch_mode:
        output = CBP_batch(in_dim=in_dim,
                           out_dim=out_dim,
                           random_matrix_path=cbp_dir,
                           batch_size=batch_size,
                           name='cbp')(output)
    else:
        output = CBP(in_dim=in_dim,
                     out_dim=out_dim,
                     random_matrix_path=cbp_dir,
                     name='cbp')(output)

    if verbose:
        # NOTE(review): display_config is not defined in the visible part of
        # this file -- confirm it exists, otherwise verbose=True raises
        # NameError.
        display_config(arch, layer, in_dim, out_dim)

    model = Model(inputs=cnn.input, outputs=output)

    cbp_func = K.function([model.input],
                          [model.output])

    return (preprocess, cbp_func)

class CBP_batch(Layer):
    """Random-projection bilinear pooling layer for a fixed batch size.

    CBP presumably stands for compact bilinear pooling. The layer projects
    every position of the input with two fixed random +/-1 matrices,
    multiplies the two projections element-wise, averages over positions,
    applies a signed square root and L2-normalises each sample.

    The projection matrices are replicated ``batch_size`` times so that they
    can be used as the left operand of ``K.batch_dot``.

    Args:
        in_dim: Channel dimension of the input.
        out_dim: Dimension of the pooled descriptor.
        batch_size: Number of samples per call (must be >= 1).
        random_matrix_path: Directory where '<in_dim>to<out_dim>.npz' is
            loaded from, or written to if it does not exist yet.
        **kwargs: Forwarded to ``Layer.__init__``.

    Raises:
        ValueError: If ``batch_size`` is smaller than 1.
    """

    def __init__(self, 
                 in_dim,
                 out_dim,
                 batch_size,
                 random_matrix_path,
                 **kwargs):
        if batch_size < 1:
            raise ValueError("batch_size must be >= 1, got %r" % (batch_size,))

        # The original called super(CBP, self), which raises TypeError here
        # because CBP_batch is not a subclass of CBP.
        super(CBP_batch, self).__init__(**kwargs)

        self.in_dim = in_dim
        self.out_dim = out_dim
        self.batch_size = batch_size
        self.random_matrix_path = random_matrix_path
        self._create_randmat()

    def _random_signs(self, seed_value):
        """Return a (1, out_dim, in_dim) float64 matrix of random -1/+1."""
        # A private RandomState is seeded the same way as np.random.seed(),
        # so the values should match the original without reseeding the
        # process-wide RNG.
        raw = np.random.RandomState(seed_value).rand(self.out_dim, self.in_dim)
        # raw is in [0, 1): floor(2*raw) is 0 or 1, mapped to -1 or +1.
        signs = (np.floor(raw * 2) * 2) - 1
        return np.expand_dims(signs, axis=0)

    def _create_randmat(self):
        """Load or create the projection matrices and set self.W1 / self.W2."""
        randmat_file = '{}to{}.npz'.format(self.in_dim, self.out_dim)
        randmat_file = os.path.join(self.random_matrix_path, randmat_file)

        if os.path.exists(randmat_file):
            # NpzFile keeps the archive open; close it once both are read.
            with np.load(randmat_file) as cached:
                w1 = cached['W1'].astype(np.float32)
                w2 = cached['W2'].astype(np.float32)
        else:
            raw_1 = self._random_signs(128)
            raw_2 = self._random_signs(1997)

            cache_dir = os.path.dirname(randmat_file)
            if cache_dir and not os.path.isdir(cache_dir):
                os.makedirs(cache_dir)
            # Same on-disk layout as before: float64, shape (1, out, in).
            np.savez(randmat_file, W1=raw_1, W2=raw_2)

            w1 = raw_1.astype(np.float32)
            w2 = raw_2.astype(np.float32)

        # Matrices are stored with a leading singleton axis; drop it before
        # replicating, otherwise stacking gives a 4-D (B, 1, out, in) tensor.
        w1 = w1.reshape(self.out_dim, self.in_dim)
        w2 = w2.reshape(self.out_dim, self.in_dim)

        # (batch_size, out_dim, in_dim)
        self.W1 = tf.convert_to_tensor(np.tile(w1, (self.batch_size, 1, 1)))
        self.W2 = tf.convert_to_tensor(np.tile(w2, (self.batch_size, 1, 1)))

    def call(self, inputs):
        """Pool ``inputs`` into one flattened, L2-normalised descriptor tensor.

        ``inputs`` is expected to be (batch, positions, in_dim), as produced
        by the Reshape layer in ``setup_texture``.
        """
        # (batch, positions, in_dim) -> (batch, in_dim, positions)
        features = K.permute_dimensions(inputs, (0, 2, 1))

        proj_1 = K.batch_dot(self.W1, features)
        proj_2 = K.batch_dot(self.W2, features)

        # Element-wise product of the two projections, averaged over the
        # last axis.
        pooled = K.mean(proj_1 * proj_2, axis=2)

        # Signed square root, then per-sample L2 normalisation.
        pooled = K.sign(pooled) * K.sqrt(K.abs(pooled))
        pooled = K.l2_normalize(pooled, axis=1)

        return K.flatten(pooled)

    def compute_output_shape(self, input_shape):
        # NOTE(review): call() returns a flattened 1-D tensor, which this
        # 3-tuple does not describe. Kept unchanged for compatibility;
        # confirm the intended shape before relying on it.
        return(input_shape[0],input_shape[2],input_shape[2])

    def get_config(self):
        """Return the layer config, including the constructor arguments."""
        config = dict(super(CBP_batch, self).get_config())
        config.update({
            'in_dim': self.in_dim,
            'out_dim': self.out_dim,
            'batch_size': self.batch_size,
            'random_matrix_path': self.random_matrix_path,
        })
        return config
    
class CBP(Layer):
    """Random-projection bilinear pooling layer.

    CBP presumably stands for compact bilinear pooling. The layer projects
    every position of the input with two fixed random +/-1 matrices,
    multiplies the two projections element-wise, averages over positions,
    applies a signed square root and L2-normalises each sample.

    Args:
        in_dim: Channel dimension of the input.
        out_dim: Dimension of the pooled descriptor.
        random_matrix_path: Directory where '<in_dim>to<out_dim>.npz' is
            loaded from, or written to if it does not exist yet.
        **kwargs: Forwarded to ``Layer.__init__``.
    """

    def __init__(self, 
                 in_dim,
                 out_dim,
                 random_matrix_path,
                 **kwargs):
        super(CBP, self).__init__(**kwargs)

        self.in_dim = in_dim
        self.out_dim = out_dim
        self.random_matrix_path = random_matrix_path
        self._create_randmat()

    def _random_signs(self, seed_value):
        """Return a (1, out_dim, in_dim) float64 matrix of random -1/+1."""
        # A private RandomState is seeded the same way as np.random.seed(),
        # so the values should match the original without reseeding the
        # process-wide RNG.
        raw = np.random.RandomState(seed_value).rand(self.out_dim, self.in_dim)
        # raw is in [0, 1): floor(2*raw) is 0 or 1, mapped to -1 or +1.
        signs = (np.floor(raw * 2) * 2) - 1
        return np.expand_dims(signs, axis=0)

    def _create_randmat(self):
        """Load or create the projection matrices and set self.W1 / self.W2."""
        randmat_file = '{}to{}.npz'.format(self.in_dim, self.out_dim)
        randmat_file = os.path.join(self.random_matrix_path, randmat_file)

        if os.path.exists(randmat_file):
            # NpzFile keeps the archive open; close it once both are read.
            with np.load(randmat_file) as cached:
                w1 = cached['W1'].astype(np.float32)
                w2 = cached['W2'].astype(np.float32)
        else:
            raw_1 = self._random_signs(128)
            raw_2 = self._random_signs(1997)

            cache_dir = os.path.dirname(randmat_file)
            if cache_dir and not os.path.isdir(cache_dir):
                os.makedirs(cache_dir)
            # Same on-disk layout as before: float64, shape (1, out, in).
            np.savez(randmat_file, W1=raw_1, W2=raw_2)

            w1 = raw_1.astype(np.float32)
            w2 = raw_2.astype(np.float32)

        self.W1 = tf.convert_to_tensor(w1)
        self.W2 = tf.convert_to_tensor(w2)

    def call(self, inputs):
        """Pool ``inputs`` into one flattened, L2-normalised descriptor tensor.

        ``inputs`` is expected to be (batch, positions, in_dim), as produced
        by the Reshape layer in ``setup_texture``.
        """
        # (batch, positions, in_dim) -> (batch, in_dim, positions)
        features = K.permute_dimensions(inputs, (0, 2, 1))

        proj_1 = K.batch_dot(self.W1, features)
        proj_2 = K.batch_dot(self.W2, features)

        # Element-wise product of the two projections, averaged over the
        # last axis.
        pooled = K.mean(proj_1 * proj_2, axis=2)

        # Signed square root.
        pooled = K.sign(pooled) * K.sqrt(K.abs(pooled))

        # Normalise each sample on its own. Without an axis the norm is
        # taken over the whole tensor, which only equals this for one sample.
        pooled = K.l2_normalize(pooled, axis=1)

        return K.flatten(pooled)

    def compute_output_shape(self, input_shape):
        # NOTE(review): call() returns a flattened 1-D tensor, which this
        # 3-tuple does not describe. Kept unchanged for compatibility;
        # confirm the intended shape before relying on it.
        return(input_shape[0],input_shape[2],input_shape[2])

    def get_config(self):
        """Return the layer config, including the constructor arguments."""
        config = dict(super(CBP, self).get_config())
        config.update({
            'in_dim': self.in_dim,
            'out_dim': self.out_dim,
            'random_matrix_path': self.random_matrix_path,
        })
        return config

def calc_features_file(img_path,  prep, cnn, img_size=256, is_rot = False, angle = 90):
    """Load an image from disk and return its descriptor.

    Args:
        img_path: Path of the image file.
        prep: Preprocessing function applied to the (1, H, W, C) batch.
        cnn: Callable taking a list with one batch and returning a list of
            outputs; the first output is returned.
        img_size: The image is resized to (img_size, img_size) on load.
        is_rot: When true, the image array is transformed with
            ``apply_affine_transform(theta=angle)`` before preprocessing.
        angle: Value passed as ``theta`` when ``is_rot`` is true.

    Returns:
        The first element of ``cnn([batch])``.
    """
    target = (img_size, img_size)
    pixels = image.img_to_array(image.load_img(img_path, target_size=target))

    if is_rot:
        pixels = image.apply_affine_transform(pixels, theta=angle)

    # Add the leading batch axis expected by the network.
    batch = prep(np.expand_dims(pixels, axis=0))
    outputs = cnn([batch])
    return outputs[0]

def calc_features(img, prep, cnn, img_size = 256):
    """Return the descriptor of an in-memory image.

    A deep copy of ``img`` is converted and preprocessed, so the object
    passed by the caller is not the one handed to ``prep``.

    Args:
        img: Image object accepted by ``image.img_to_array``.
        prep: Preprocessing function applied to the (1, H, W, C) batch.
        cnn: Callable taking a list with one batch and returning a list of
            outputs; the first output is returned.
        img_size: Unused; kept for signature compatibility.

    Returns:
        The first element of ``cnn([batch])``.
    """
    pixels = image.img_to_array(copy.deepcopy(img))
    # Add the leading batch axis expected by the network.
    batch = prep(np.expand_dims(pixels, axis=0))
    outputs = cnn([batch])
    return outputs[0]

def calc_similarity(feat1, feat2):
    """Return 1 minus half the squared Euclidean distance of two features.

    For unit-norm vectors this equals their dot product (cosine similarity).
    """
    delta = feat1 - feat2
    squared_distance = np.sum(delta ** 2)
    return 1 - 0.5 * squared_distance

