卷機神經網路的視覺化(視覺化中間啟用)
阿新 • • 發佈:2019-02-20
對於中間啟用的視覺化,我們使用之前在貓狗分類中從頭開始訓練的小型卷積神經網路。
from keras.models import load_model
model = load_model('cats_and_dogs_small_2.h5')
model.summary()
Layer (type) Output Shape Param # ================================================================= conv2d_5 (Conv2D) (None, 148, 148, 32) 896 _________________________________________________________________ max_pooling2d_5 (MaxPooling2 (None, 74, 74, 32) 0 _________________________________________________________________ conv2d_6 (Conv2D) (None, 72, 72, 64) 18496 _________________________________________________________________ max_pooling2d_6 (MaxPooling2 (None, 36, 36, 64) 0 _________________________________________________________________ conv2d_7 (Conv2D) (None, 34, 34, 128) 73856 _________________________________________________________________ max_pooling2d_7 (MaxPooling2 (None, 17, 17, 128) 0 _________________________________________________________________ conv2d_8 (Conv2D) (None, 15, 15, 128) 147584 _________________________________________________________________ max_pooling2d_8 (MaxPooling2 (None, 7, 7, 128) 0 _________________________________________________________________ flatten_2 (Flatten) (None, 6272) 0 _________________________________________________________________ dropout_1 (Dropout) (None, 6272) 0 _________________________________________________________________ dense_3 (Dense) (None, 512) 3211776 _________________________________________________________________ dense_4 (Dense) (None, 1) 513 ================================================================= Total params: 3,453,121 Trainable params: 3,453,121 Non-trainable params: 0
接下面,輸入一張不屬於網路的貓的影象
img_path = '/Users/fchollet/Downloads/cats_and_dogs_small/test/cats/cat.1700.jpg'
from keras.preprocessing import image # 將影象處理成為一個4D張量
import numpy as np
img = image.load_img(img_path, target_size=(150, 150))
img_tensor = image.img_to_array(img)
img_tensor = np.expand_dims(img_tensor, axis= 0)
img_tensor /= 255.
print(img_tensor.shape)
(1, 150, 150, 3)
顯示測試影象
import matplotlib.pyplot as plt
plt.imshow(img_tensor[0])
plt.show()
為了提取想要檢視的特徵圖,我們需要建立一個Keras模型,以影象批量作為輸入,並輸出所有卷積層和池化層的啟用。為此,我們需要使用Keras的Model類。模型例項化需要兩個引數:一個輸入張量(或輸入張量的列表)和一個輸出張量(或輸出張量的列表)。
from keras import models
layer_outputs = [layer.output for layer in model.layers[:8]] #提取前8層的輸出
activation_model = models.Model(inputs=model.input, outputs=layer_outputs) #建立一個模型,給定模型的輸入,可以返回這些輸出
這段語句是輸入一張影象,這個模型將返回原始模型的前8層啟用值。這個模型有一個輸入和8個輸出,即每層啟用對應一個輸出。
activations = activation_model.predict(img_tensor) # 返回8個Numpy陣列組成的列表,每個層啟用對應一個Numpy陣列
first_layer_activation = activations[0]
print(first_layer_activation.shape)
(1, 148, 148, 32)
它是大小為148*148的特徵圖,有32個通道。我們來繪製原始模型第3個通道:
import matplotlib.pyplot as plt
plt.matshow(first_layer_activation[0, :, :, 3], cmap='viridis')
plt.show()
再看看第30個通道:
plt.matshow(first_layer_activation[0, :, :, 30], cmap='viridis')
plt.show()
我們可以看到,似乎不同通道對於影象檢測有不同側重,比如第3個通道更側重於邊緣檢測,第30個通道更側於”綠色圓點“檢測。
下面我們來繪製網路中所有啟用的完整視覺化圖。我們需要在8個特徵圖裡的每一個都提取並繪製一個通道,然後將結果疊加在一個大的影象張量中,按通道並排。
import keras
layer_names = []
for layer in model.layers[:8]:
layer_names.append(layer.name) # 用來儲存層的名稱,這樣你就可以把層的名稱畫到圖中
images_per_row = 16
for layer_name, layer_activation in zip(layer_names, activations): # 顯示特徵圖
n_features = layer_activation.shape[-1] # 特徵圖中的特徵個數
size = layer_activation.shape[1] # 特徵圖的形狀為(1, size, size, n_features)
n_cols = n_features // images_per_row # 在這個矩陣中將啟用通道平鋪
display_grid = np.zeros((size * n_cols, images_per_row * size))
for col in range(n_cols): #將每個過濾器平鋪到一個大的水平網格中
for row in range(images_per_row):
channel_image = layer_activation[0,
:, :,
col * images_per_row + row]
channel_image -= channel_image.mean() #對特徵進行後處理,使其看起來更加美觀
channel_image /= channel_image.std()
channel_image *= 64
channel_image += 128
channel_image = np.clip(channel_image, 0, 255).astype('uint8')
display_grid[col * size : (col + 1) * size,
row * size : (row + 1) * size] = channel_image # 顯示網格
scale = 1. / size
plt.figure(figsize=(scale * display_grid.shape[1],
scale * display_grid.shape[0]))
plt.title(layer_name)
plt.grid(False)
plt.imshow(display_grid, aspect='auto', cmap='viridis')
plt.show()
更多精彩內容,歡迎關注我的微信公眾號:資料瞎分析