1. 程式人生 > >Keras訓練輔助工具及優化工具

Keras訓練輔助工具及優化工具

原文:http://blog.csdn.net/zzulp/article/details/76591341

1 Callbacks

Callbacks提供了一系列的類,用於在訓練過程中被回撥,從而實現對訓練過程進行觀察和干涉。除了庫提供的一些類,使用者也可以自定義類。下面列舉比較有用的回撥類。

類名 作用 建構函式
ModelCheckpoint 用於在epoch間儲存要模型 ModelCheckpoint(filepath, monitor=’val_loss’, save_best_only=False, save_weights_only=False, mode=’auto’, period=1)
EarlyStopping 當early stop被啟用(如發現loss相比上一個epoch訓練沒有下降),則經過patience個epoch後停止訓練。 EarlyStopping(monitor=’val_loss’, patience=0, mode=’auto’)
TensorBoard 生成tb需要的日誌 TensorBoard(log_dir=’./logs’, histogram_freq=0, write_graph=True, write_images=False, embeddings_freq=0, embeddings_layer_names=None, embeddings_metadata=None)
ReduceLROnPlateau 當指標變化小時,減少學習率 ReduceLROnPlateau(monitor=’val_loss’, factor=0.1, patience=10, mode=’auto’, epsilon=0.0001, cooldown=0, min_lr=0)

示例:

from keras.callbacks import ModelCheckpoint

model = Sequential()
model.add(Dense(10, input_dim=784, kernel_initializer='uniform'))
model.add(Activation('softmax'
)) model.compile(loss='categorical_crossentropy', optimizer='rmsprop') checkpointer = ModelCheckpoint(filepath="/tmp/weights.h5", save_best_only=True) tensbrd = TensorBoard(logdir='path/of/log') model.fit(X_train, Y_train, batch_size=128, callbacks=[checkpointer,tensbrd])
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10

PS:加入tensorboard回撥類後,就可以使用tensorflow的tensorboard命令列來開啟視覺化web服務了。

2 Application

本模組提供了基於image-net預訓練好的影象模型,方便我們進行遷移學習使用。初次使用時,模型權重資料會下載到~/.keras/models目錄下。

影象模型 說明 建構函式
InceptionV3 InceptionV3(include_top=True, weights=’imagenet’,input_tensor=None,input_shape=None,pooling=None,classes=1000)
ResNet50 ResNet50(include_top=True, weights=’imagenet’,input_tensor=None,input_shape=None,pooling=None,classes=1000)
VGG19 VGG19(include_top=True, weights=’imagenet’,input_tensor=None,input_shape=None,pooling=None,classes=1000)
VGG16 VGG16(include_top=True, weights=’imagenet’,input_tensor=None,input_shape=None,pooling=None,classes=1000)
Xception Xception(include_top=True, weights=’imagenet’,input_tensor=None,input_shape=None,pooling=None, classes=1000)

引數說明

引數 說明
include_top 是否保留頂層的全連線網路, False為只要bottleneck
weights ‘imagenet’代表載入預訓練權重, None代表隨機初始化
input_tensor 可填入Keras tensor作為模型的影象輸出tensor
input_shape 長為3的tuple,指明輸入圖片的shape,圖片的寬高必須大於197
pooling 特徵提取網路的池化方式。None代表不池化,最後一個卷積層的輸出為4D張量。‘avg’代表全域性平均池化,‘max’代表全域性最大值池化
classes 圖片分類的類別數,當include_top=True weight=None時可用

關於遷移學習,可以參考這篇文章:如何在極小資料集上實現影象分類。裡面介紹了通過影象變換以及使用已有模型並fine-tune新分類器的過程。

3 模型視覺化

utils包中提供了plot_model函式,用來將一個model以影象的形式展現出來。此功能依賴pydot-ng與graphviz。 
pip install pydot-ng graphviz

from keras.utils import plot_model
model = keras.applications.InceptionV3()
plot_model(model, to_file='model.png')
  • 1
  • 2
  • 3
  • 1
  • 2
  • 3