%matplotlib inline
import os
import gc
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from matplotlib.pyplot import imshow
from PIL import Image, ImageFile
from urllib.request import urlopen
from django.core.validators import URLValidator
from django.core.exceptions import ValidationError
import keras
from keras import backend as K
from keras.applications.resnet50 import ResNet50
from keras.applications.resnet50 import preprocess_input
ImageFile.LOAD_TRUNCATED_IMAGES = True
def delete_model(model, clear_session=True):
    '''removes model!
    '''
    del model
    gc.collect()
    if clear_session: K.clear_session()
def is_url(url):
    val = URLValidator()
    try:
        val(url)
        return True
    except ValidationError:
        return False
def plot_channels(img):
    '''Plots the red, green, and blue channels of an image side by side.'''
    height, width = img.shape[:2]
    _, ax = plt.subplots(1, 3, sharex='col', sharey='row', figsize=(24, 6))
    plt.suptitle('RGB Channels of an Image', size=20)

    # red channel: zero out green and blue
    np_img_r = img.copy()
    np_img_r[:, :, 1] = np.zeros(shape=[height, width])
    np_img_r[:, :, 2] = np.zeros(shape=[height, width])
    ax[0].imshow(np_img_r)
    ax[0].axis('off')

    # green channel: zero out red and blue
    np_img_g = img.copy()
    np_img_g[:, :, 0] = np.zeros(shape=[height, width])
    np_img_g[:, :, 2] = np.zeros(shape=[height, width])
    ax[1].imshow(np_img_g)
    ax[1].axis('off')

    # blue channel: zero out red and green
    np_img_b = img.copy()
    np_img_b[:, :, 0] = np.zeros(shape=[height, width])
    np_img_b[:, :, 1] = np.zeros(shape=[height, width])
    ax[2].imshow(np_img_b)
    ax[2].axis('off')
def read_img_url(url):
    '''Reads an image from a URL into a PIL Image in RGB mode.'''
    with urlopen(url) as file:
        img = Image.open(file)
        if img.mode != 'RGB':
            img = img.convert('RGB')
        return img

def read_img_file(f):
    '''Reads an image from a local file into a PIL Image in RGB mode.'''
    img = Image.open(f)
    if img.mode != 'RGB':
        img = img.convert('RGB')
    return img

def read_img(f):
    '''Reads an image from either a URL or a local file path.'''
    if is_url(f):
        img = read_img_url(f)
    else:
        img = read_img_file(f)
    return img
def resize_img_to_array(img, img_shape=(244, 244)):
    img_array = np.array(
        img.resize(
            img_shape,
            Image.ANTIALIAS
        )
    )
    return img_array
@LeonYin
Social Media and Political Participation (SMaPP) Lab
Center for Data Science
New York University
PyData NYC ~ 2017-11-28
NN == Matrix Multiplication and Thresholding:
More here
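As a rough illustration of that claim (a minimal NumPy sketch, not part of ResNet itself), a single fully connected layer is just a matrix multiplication followed by a thresholding non-linearity such as ReLU:
rng = np.random.RandomState(0)
W = rng.randn(4, 3)      # learned weights: 4 inputs -> 3 outputs
b = rng.randn(3)         # learned biases
x = rng.randn(4)         # one input vector
z = x.dot(W) + b         # matrix multiplication
a = np.maximum(z, 0)     # thresholding (ReLU)
a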
What is going on inside?
from keras.applications.resnet50 import ResNet50
from keras.applications.resnet50 import preprocess_input
def make_resnet_conv(input_shape):
    '''
    Creates a ResNet50 model pre-trained on ImageNet.
    The final classification layers are excluded (include_top=False),
    so the model returns convolutional features.
    `input_shape` is a tuple of integers.
    '''
    model = ResNet50(input_shape=input_shape,
                     weights='imagenet',
                     include_top=False)
    # freeze all layers; we only use the network for feature extraction
    for layer in model.layers:
        layer.trainable = False
    return model
url = 'http://i.imgur.com/P0eRT3y.jpg'
img_width, img_height = 299, 299
img = read_img(url)  # reads an image from a URL (or file) into a PIL Image
np_img = resize_img_to_array(img, img_shape=(img_width, img_height))
imshow(np_img);
np_img.shape
plot_channels(np_img)
np_img = np.expand_dims(np_img, axis=0)  # add a batch dimension: (1, height, width, 3)
np_img.shape
model = make_resnet_conv(input_shape=[img_width, img_height, 3])
X = preprocess_input(np_img.astype(np.float32))
X_conv = model.predict(X)
delete_model(model)
X[0].shape, X_conv[0].shape
We can iron out the extra dimensions
X_conv_2d = X_conv[0].flatten()
X_conv_2d
train_reshape = (X_conv.shape[0], np.prod(X_conv.shape[1:]))
X_conv_2d = X_conv.reshape(train_reshape)
X_conv_2d[0].shape
We can do this in 5 lines of code!
def get_conv_feat(f, model):
    '''
    For any given image (file or URL):
    convert to a NumPy array, resize to img_width x img_height,
    preprocess the values for ResNet, get the convolutional features
    from ResNet, and flatten the output.
    '''
    img = read_img(f)
    np_img = resize_img_to_array(img, img_shape=(img_width, img_height))
    X = preprocess_input(np.expand_dims(np_img, axis=0).astype(np.float32))
    X_conv = model.predict(X)
    X_conv_2d = X_conv[0].flatten()
    return X_conv_2d
Let's look at the output of the other popular models:
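Below is a rough sketch for doing so (it assumes the standard Keras applications constructors; `weights=None` skips the ImageNet download since we only care about shapes here):
from keras.applications.vgg16 import VGG16
from keras.applications.inception_v3 import InceptionV3
from keras.applications.xception import Xception

for name, arch in [('VGG16', VGG16),
                   ('InceptionV3', InceptionV3),
                   ('Xception', Xception)]:
    m = arch(input_shape=(img_width, img_height, 3),
             weights=None, include_top=False)
    print(name, m.output_shape)   # conv-feature shape before flattening
    delete_model(m)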
Let's use the conv features we calculated for 100k images.
meta_path = '/beegfs/work/smapp/search_feats/google_conv_feats.csv'
X = pd.read_csv(meta_path)
Y = X['filename']
X_conv_2d = X[[c for c in X.columns if c != 'filename']].values.astype(np.float32)
X_conv_2d.shape
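For reference, a feature file like this can be built by looping `get_conv_feat` over a directory of images. The sketch below is illustrative only: the `image_dir` argument, the `.jpg` glob pattern, and the column names are assumptions, not the exact pipeline that produced the CSV above.
import glob

def build_conv_feat_csv(image_dir, out_csv):
    '''Illustrative sketch: write flattened conv features for every
    image in image_dir to out_csv, alongside a filename column.'''
    model = make_resnet_conv(input_shape=[img_width, img_height, 3])
    rows = []
    for f in glob.glob(os.path.join(image_dir, '*.jpg')):
        feats = get_conv_feat(f, model)      # 1-D flattened feature vector
        rows.append([f] + feats.tolist())
    delete_model(model)
    cols = ['filename'] + ['f%d' % i for i in range(len(rows[0]) - 1)]
    pd.DataFrame(rows, columns=cols).to_csv(out_csv, index=False)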
Training KNN on 100K images w/ 2K features takes about 2 minutes!
from sklearn.neighbors import NearestNeighbors
knn = NearestNeighbors(n_neighbors=20, n_jobs=8, algorithm='ball_tree')
knn.fit(X_conv_2d)
from sklearn.externals import joblib
knn_file = '/beegfs/work/smapp/search_feats/knn/google_conv_feats.pkl'
joblib.dump(knn, knn_file)
ls /beegfs/work/smapp/google_images_sample/*
def plot_neighbors(neighbors, Y, top=8, per_row=4):
    '''Plots the nearest-neighbor images, annotated with their distances.'''
    distance = neighbors[0][0]
    files = Y[neighbors[1][0]].tolist()
    for i in range(len(distance[:top])):
        if i % per_row == 0:
            # start a new row of subplots
            _, ax = plt.subplots(1, per_row, sharex='col', sharey='row',
                                 figsize=(24, 6))
        j = i % per_row
        image = read_img(files[i])
        image = resize_img_to_array(image, img_shape=(img_width, img_height))
        ax[j].imshow(image)
        ax[j].axis('off')
        ax[j].annotate(distance[i],
                       (0, 0), (0, -32), xycoords='axes fraction',
                       textcoords='offset points', va='top')
def get_conv_feats(f, model):
    '''
    Converts an image (str of a path or URL) to a 2-D array
    of flattened convolutional features.
    '''
    img = read_img(f)
    np_img = resize_img_to_array(img, img_shape=(img_width, img_height))
    X = preprocess_input(np.expand_dims(np_img, axis=0).astype(np.float32))
    X_conv = model.predict(X)
    new_shape = (X_conv.shape[0], np.prod(X_conv.shape[1:]))
    X_conv_2d = X_conv.reshape(new_shape)
    return X_conv_2d
def get_neighbors(f, knn_file, reference_files, top=8):
    '''
    Loads the KNN model and a pre-trained neural network,
    converts the image (f) to convolutional features,
    sends the conv features to the KNN to find the closest hits,
    and plots the top images with their distances.
    '''
    knn = joblib.load(knn_file)
    model = make_resnet_conv(input_shape=[img_width, img_height, 3])
    X_conv_2d = get_conv_feats(f, model)
    neighbors = knn.kneighbors(X_conv_2d, return_distance=True)
    plot_neighbors(neighbors, reference_files, top=top, per_row=4)
    delete_model(model)
knn_file = '/beegfs/work/smapp/search_feats/knn/google_conv_feats.pkl'
meta_path = '/beegfs/work/smapp/search_feats/google_conv_feats.csv'
Y = pd.read_csv(meta_path, usecols=['filename'], squeeze=True)
url = 'http://i.imgur.com/P0eRT3y.jpg'
read_img(url)
get_neighbors(url, knn_file, Y, top=8)
search = 'https://ak5.picdn.net/shutterstock/videos/9646475/thumb/8.jpg?i10c=img.resize(height:160)'
read_img(search)
get_neighbors(search, knn_file, Y, top=8)
search = 'https://ichef.bbci.co.uk/news/1024/socialembed/https://twitter.com/jack/status/912784057863245824~/news/business-41408798'
read_img(search)
get_neighbors(search, knn_file, Y, top=8)
w - yinleon.github.io
t - @leonyin