淺談keras2 predict和fit_generator的坑
1、使用predict時,必須設定batch_size,否則效率奇低。
檢視keras文件中,predict函式原型:
predict(self,x,batch_size=32,verbose=0)
說明:
只使用batch_size=32,也就是說每次將batch_size=32的資料通過PCI匯流排傳到GPU,然後進行預測。在一些問題中,batch_size=32明顯是非常小的。而通過PCI傳資料是非常耗時的。
所以,使用的時候會發現預測資料時效率奇低,其原因就是batch_size太小了。
經驗:
使用predict時,必須人為設定好batch_size,否則PCI匯流排之間的資料傳輸次數過多,效能會非常低下。
2、fit_generator
說明:keras 中 fit_generator引數steps_per_epoch已經改變含義了,目前的含義是一個epoch分成多少個batch_size。舊版的含義是一個epoch的樣本數目。
如果說訓練樣本樹N=1000,steps_per_epoch = 10,那麼相當於一個batch_size=100,如果還是按照舊版來設定,那麼相當於
batch_size = 1,會效能非常低。
經驗:
必須明確fit_generator引數steps_per_epoch
補充知識:Keras:建立自己的generator(適用於model.fit_generator),解決記憶體問題
為什麼要使用model.fit_generator?
在現實的機器學習中,訓練一個model往往需要數量巨大的資料,如果使用fit進行資料訓練,很有可能導致記憶體不夠,無法進行訓練。
fit_generator的定義如下:
fit_generator(generator,steps_per_epoch=None,epochs=1,verbose=1,callbacks=None,validation_data=None,validation_steps=None,class_weight=None,max_queue_size=10,workers=1,use_multiprocessing=False,shuffle=True,initial_epoch=0)
其中各項的具體解釋,請參考Keras中文文件
我們重點關注的是generator引數:
generator: 一個生成器,或者一個 Sequence (keras.utils.Sequence) 物件的例項, 以在使用多程序時避免資料的重複。 生成器的輸出應該為以下之一:
一個 (inputs,targets) 元組
一個 (inputs,targets,sample_weights) 元組。
那麼,問題來了,如何構建這個generator呢?有以下幾種辦法:
自己建立一個generator生成器
自己定義一個 Sequence (keras.utils.Sequence) 物件
使用Keras自帶的ImageDataGenerator和.flow/.flow_from_dataframe/.flow_from_directory來生成一個generator
1.自己建立一個generator生成器
使用Keras自帶的ImageDataGenerator和.flow/.flow_from_dataframe/.flow_from_directory 靈活度不高,只有當資料集滿足一定格式(例如,按照分類資料夾存放)或者具備一定條件時,使用才使用才較為方便。
此時,自己建立一個generator就很重要了,關於python的generator是什麼原理,怎麼使用,就不加贅述,可以檢視python的基本語法。
此處,我們用yield來返回資料組,標籤組,從而使fit_generator可以呼叫我們的generator來成批處理資料。
具體實現如下:
def myGenerator(batch_size): # loading data X_train,Y_train=load_data(...) # data processing # ................ total_size=X_train.size #batch_size means how many data you want to train one step while 1: for i in range(total_size//batch_size): yield x_train[i*batch_size:(i+1)*batch_size],y[i*batch_size:(i+1)*batch_size] return myGenerator
接著你可以呼叫該生成器:
self._model.fit_generator(myGenerator(batch_size),steps_per_epoch=total_size//batch_size,epochs=epoch_num)
以上這篇淺談keras2 predict和fit_generator的坑就是小編分享給大家的全部內容了,希望能給大家一個參考,也希望大家多多支援我們。