keras和tensorflow使用fit_generator 批次訓練操作
fit_generator 是 keras 提供的用來進行批次訓練的函式,使用方法如下:
model.fit_generator(generator,steps_per_epoch=None,epochs=1,verbose=1,callbacks=None,validation_data=None,validation_steps=None,class_weight=None,max_queue_size=10,workers=1,use_multiprocessing=False,shuffle=True,initial_epoch=0)
引數說明:
generator: 一個生成器,或者一個 Sequence (keras.utils.Sequence) 物件的例項, 以在使用多程序時避免資料的重複。 生成器的輸出應該為以下之一:
一個(inputs,targets) 元組
一個 (inputs,targets,sample_weights) 元組。
這個元組(生成器的單個輸出)組成了單個的 batch。 因此,這個元組中的所有陣列長度必須相同(與這一個 batch 的大小相等)。 不同的 batch 可能大小不同。 例如,一個 epoch 的最後一個 batch 往往比其他 batch 要小, 如果資料集的尺寸不能被 batch size 整除。 生成器將無限地在資料集上迴圈。當執行到第steps_per_epoch 時,記一個 epoch 結束。
steps_per_epoch: 在宣告一個 epoch 完成並開始下一個 epoch 之前從 generator產生的總步數(批次樣本)。 它通常應該等於你的資料集的樣本數量除以批量大小。 對於Sequence,它是可選的:如果未指定,將使用len(generator)作為步數。
epochs: 整數。訓練模型的迭代總輪數。一個 epoch 是對所提供的整個資料的一輪迭代,如 steps_per_epoch 所定義。注意,與 initial_epoch 一起使用,epoch 應被理解為「最後一輪」。模型沒有經歷由 epochs 給出的多次迭代的訓練,而僅僅是直到達到索引 epoch 的輪次。
verbose: 0,1 或 2。日誌顯示模式。 0 = 安靜模式,1 = 進度條,2 = 每輪一行。
callbacks: keras.callbacks.Callback 例項的列表。在訓練時呼叫的一系列回撥函式。
validation_data: 它可以是以下之一:
驗證資料的生成器或Sequence例項
一個(inputs,targets) 元組
一個(inputs,sample_weights) 元組。
在每個 epoch 結束時評估損失和任何模型指標。該模型不會對此資料進行訓練。
validation_steps: 僅當 validation_data 是一個生成器時才可用。 在停止前 generator 生成的總步數(樣本批數)。 對於 Sequence,它是可選的:如果未指定,將使用 len(generator) 作為步數。
class_weight: 可選的將類索引(整數)對映到權重(浮點)值的字典,用於加權損失函式(僅在訓練期間)。 這可以用來告訴模型「更多地關注」來自代表性不足的類的樣本。
max_queue_size: 整數。生成器佇列的最大尺寸。 如未指定,max_queue_size 將預設為 10。
workers: 整數。使用的最大程序數量,如果使用基於程序的多執行緒。 如未指定,workers 將預設為 1。如果為 0,將在主執行緒上執行生成器。
use_multiprocessing: 布林值。如果 True,則使用基於程序的多執行緒。 如未指定, use_multiprocessing 將預設為 False。 請注意,由於此實現依賴於多程序,所以不應將不可傳遞的引數傳遞給生成器,因為它們不能被輕易地傳遞給子程序。
shuffle: 是否在每輪迭代之前打亂 batch 的順序。 只能與 Sequence (keras.utils.Sequence) 例項同用。
initial_epoch: 開始訓練的輪次(有助於恢復之前的訓練)。
補充知識:Keras中fit_generator 的多個分支輸入時,需注意generator的格式 以及 輸入序列的順序
需要注意迭代器 yeild返回不能是[x1,x2],y 這樣,而是要完整的字典格式的:
yield ({'input_1': x1,'input_2': x2},{'output': y})
這也不算坑 追進去 fit_generator也能看到示例
def generate_batch(x_train,y_train,batch_size,x_train2,randomFlag=True): ylen = len(y_train) loopcount = ylen // batch_size i=-1 while True: if randomFlag: i = random.randint(0,loopcount-1) else: i=i+1 i=i%loopcount yield ({'lstmInput': x_train[i*batch_size:(i+1)*batch_size],'bgInput': x_train2[i*batch_size:(i+1)*batch_size]},{'prediction': y_train[i*batch_size:(i+1)*batch_size]})
ps: 因為要是tuple yield後的括號不能省
需注意的坑1是,validation data中如果用【】組成陣列進行輸入,是要按順序的,按編譯model前的設定model = Model(inputs=[simInput,lstmInput,bgInput],outputs=predictions),中陣列的順序來編譯
需注意的坑2是,多輸入input時,以後都用 inputs1=Input(batch_shape=(batchSize,TPeriod,dimIn,),name='input1LSTM')指定batchSize,不然跟stateful lstm結合時,會提示不匹配。
history=model.fit_generator(generate_batch(trainX,trainY,batchSize,trainX2),steps_per_epoch=len(trainX)//batchSize,validation_data=([testX,testX2],testY),epochs=epochs,callbacks=[tensorboard,checkpoint],initial_epoch=0,verbose=1) # Fit the LSTM network/擬合LSTM網路
以上這篇keras和tensorflow使用fit_generator 批次訓練操作就是小編分享給大家的全部內容了,希望能給大家一個參考,也希望大家多多支援我們。