使用K.function()除錯keras操作
Keras的底層庫使用Theano或TensorFlow,這兩個庫也稱為Keras的後端。無論是Theano還是TensorFlow,都需要提前定義好網路的結構,也就是常說的“計算圖”。
在執行前需要對計算圖編譯,然後才能輸出結果。那這裡面主要有兩個問題,第一是這個圖結構在執行中不能任意更改,比如說計算圖中有一個隱含層,神經元的數量是100,你想動態的修改這個隱含層神經元的數量那是不可以的;第二是除錯困難,keras沒有內建的除錯工具,所以計算圖的中間結果是很難看到的,一旦最終輸出跟預想不一致,很難找到問題所在。
這裡談一談本人除錯keras的一些經驗:
分階段構建你的神經網路
不要一口氣把整個網路全部寫完,這樣很難保證中間結果的正確性。加如一個CNN文字分類模型是這樣的(如下程式碼),應該在加了Embedding層後,停止,列印一下中間結果,看看跟embedding向量能不能對上,輸出的shape對不對。對上了再進行下一步操作。
有的人覺得這樣很浪費時間,但是除非你能一遍寫對,否則你將花上5倍的時間發現錯誤。
# model parameters: embedding_dims = 50 cnn_filters = 100 cnn_kernel_size = 5 dense_hidden_dims = 200 model = Sequential() model.add(Embedding(nb_words,embedding_dims,input_length=maxlen)) model.add(Dropout(0.5)) model.add(Conv1D(cnn_filters,cnn_kernel_size,padding='valid',activation='relu')) model.add(GlobalMaxPooling1D()) model.add(Dense(dense_hidden_dims)) model.add(Dropout(0.5)) model.add(Activation('relu')) model.add(Dense(1)) model.add(Activation('sigmoid')) return model
使用K.function()函式列印中間結果
function函式可以接收傳入資料,並返回一個numpy陣列。使用這個函式我們可以方便地看到中間結果,尤其對於變長輸入的Input。
下面是官方關於function的文件。
function
keras.backend.function(inputs,outputs,updates=None)
例項化 Keras 函式。
引數
inputs: 佔位符張量列表。
outputs: 輸出張量列表。
updates: 更新操作列表。
**kwargs: 需要傳遞給 tf.Session.run 的引數。
返回
輸出值為 Numpy 陣列。
異常
ValueError: 如果無效的 kwargs 被傳入。
example
下面這個例子是列印一個LSTM層的中間結果,值得注意的是這個LSTM的sequence是變長的,可以看到輸出的結果sequence長度分別是64和128
import keras.backend as K from keras.layers import LSTM,Input import numpy as np I = Input(shape=(None,200)) lstm = LSTM(20,return_sequences=True) f = K.function(inputs=[I],outputs=[lstm(I)]) data1 = np.random.random(size=(2,64,200)) print(f([data1])[0].shape) data2 = np.random.random(size=(2,128,200)) print(f([data2])[0].shape) K.clear_session() # (2,20) # (2,20)
其他的除錯技巧
有頻繁張量變換操作的,如dot,mat,reshape等等,記得加一行形狀變化的註釋,如(100, 128)--> (100,64)
可以使用tensorboard檢視網路的引數情況
確保你的資料沒有問題,很多時候輸出不對不是神經網路有問題,而是資料有問題
以上這篇使用K.function()除錯keras操作就是小編分享給大家的全部內容了,希望能給大家一個參考,也希望大家多多支援我們。