1. 程式人生 > 程式設計 >keras 特徵圖視覺化例項(中間層)

keras 特徵圖視覺化例項(中間層)

鑑於最近一段時間一直在折騰的CNN網路效果不太理想,主要目標是為了檢測出影象中的一些關鍵點,可以參考人臉的關鍵點檢測演算法。

但是由於從資料集的製作是自己完成的,所以資料集質量可能有待商榷,訓練效果不好的原因可能也是因為資料集沒有製作好(標點實在是太累了)。

於是想看看自己做的資料集在進入到網路後那些中間的隱藏層到底發生了哪些變化。

今天主要是用已經訓練好的mnist模型來提前測試一下,這裡的mnist模型的準確度已經達到了98%左右。

使用的比較簡單的一個模型:

def simple_cnn():
 input_data = Input(shape=(28,28,1))
 x = Conv2D(64,kernel_size=3,padding='same',activation='relu',name='conv1')(input_data)
 x = MaxPooling2D(pool_size=2,strides=2,name='maxpool1')(x)
 x = Conv2D(32,name='conv2')(x)
 x = MaxPooling2D(pool_size=2,name='maxpool2')(x)
 x = Dropout(0.25)(x)
 # 獲得最後一層卷積層的輸出
 # 新增自己的全連線
 x = Flatten(name='flatten')(x)
 x = Dense(128,name='fc1')(x)
 x = Dropout(0.25)(x)
 x = Dense(10,activation='softmax',name='fc2')(x)
 model = Model(inputs=input_data,outputs=x)

此模型已經訓練好了,跑了10個epoch,驗證集0.33

這裡的效果還是很好的,┓( ´∀` )┏

下面在網上搞了張手寫數字

使用網路進行預測,這裡就先給出如何視覺化第一層的卷積層的輸出吧,哇哈哈

程式碼:

input_data = Input(shape=(28,name='maxpool2')(x)
 x = Dropout(0.25)(x)
 x = Flatten(name='flatten')(x)
 x = Dense(128,outputs=x)
 
 model.load_weights('final_model_mnist_2019_1_28.h5')
 
 raw_img = cv2.imread('test.png')
 test_img = load_img('test.png',color_mode='grayscale',target_size=(28,28))
 test_img = np.array(test_img)
 test_img = np.expand_dims(test_img,axis=0)
 test_img = np.expand_dims(test_img,axis=3)
 
 conv1_layer = Model(inputs=input_data,outputs=model.get_layer(index=1).output)
 
 conv1_output = conv1_layer.predict(test_img)
 
 for i in range(64):
  show_img = conv1_output[:,:,i]
  print(show_img.shape)
  show_img.shape = [28,28]
  cv2.imshow('img',show_img)
  cv2.waitKey(0)

核心方法就是通過載入模型後,新建Model,將輸出部分換為你想要檢視的網路層數即可,當然get_layer()包括了name和index兩個引數。最後通過遍歷當前卷積層的所有特徵對映,將每一個都展示出來。就可以了。

以上這篇keras 特徵圖視覺化例項(中間層)就是小編分享給大家的全部內容了,希望能給大家一個參考,也希望大家多多支援我們。