Keras訓練輔助工具及優化工具
阿新 • • 發佈:2019-02-04
原文:http://blog.csdn.net/zzulp/article/details/76591341
1 Callbacks
Callbacks提供了一系列的類,用於在訓練過程中被回撥,從而實現對訓練過程進行觀察和干涉。除了庫提供的一些類,使用者也可以自定義類。下面列舉比較有用的回撥類。
類名 | 作用 | 建構函式 |
---|---|---|
ModelCheckpoint | 用於在epoch間儲存要模型 | ModelCheckpoint(filepath, monitor=’val_loss’, save_best_only=False, save_weights_only=False, mode=’auto’, period=1) |
EarlyStopping | 當early stop被啟用(如發現loss相比上一個epoch訓練沒有下降),則經過patience個epoch後停止訓練。 | EarlyStopping(monitor=’val_loss’, patience=0, mode=’auto’) |
TensorBoard | 生成tb需要的日誌 | TensorBoard(log_dir=’./logs’, histogram_freq=0, write_graph=True, write_images=False, embeddings_freq=0, embeddings_layer_names=None, embeddings_metadata=None) |
ReduceLROnPlateau | 當指標變化小時,減少學習率 | ReduceLROnPlateau(monitor=’val_loss’, factor=0.1, patience=10, mode=’auto’, epsilon=0.0001, cooldown=0, min_lr=0) |
示例:
from keras.callbacks import ModelCheckpoint
model = Sequential()
model.add(Dense(10, input_dim=784, kernel_initializer='uniform'))
model.add(Activation('softmax' ))
model.compile(loss='categorical_crossentropy', optimizer='rmsprop')
checkpointer = ModelCheckpoint(filepath="/tmp/weights.h5", save_best_only=True)
tensbrd = TensorBoard(logdir='path/of/log')
model.fit(X_train, Y_train, batch_size=128, callbacks=[checkpointer,tensbrd])
- 1
- 2
- 3
- 4
- 5
- 6
- 7
- 8
- 9
- 10
- 1
- 2
- 3
- 4
- 5
- 6
- 7
- 8
- 9
- 10
PS:加入tensorboard回撥類後,就可以使用tensorflow的tensorboard命令列來開啟視覺化web服務了。
2 Application
本模組提供了基於image-net預訓練好的影象模型,方便我們進行遷移學習使用。初次使用時,模型權重資料會下載到~/.keras/models目錄下。
影象模型 | 說明 | 建構函式 |
---|---|---|
InceptionV3 | InceptionV3(include_top=True, weights=’imagenet’,input_tensor=None,input_shape=None,pooling=None,classes=1000) | |
ResNet50 | ResNet50(include_top=True, weights=’imagenet’,input_tensor=None,input_shape=None,pooling=None,classes=1000) | |
VGG19 | VGG19(include_top=True, weights=’imagenet’,input_tensor=None,input_shape=None,pooling=None,classes=1000) | |
VGG16 | VGG16(include_top=True, weights=’imagenet’,input_tensor=None,input_shape=None,pooling=None,classes=1000) | |
Xception | Xception(include_top=True, weights=’imagenet’,input_tensor=None,input_shape=None,pooling=None, classes=1000) |
引數說明
引數 | 說明 |
---|---|
include_top | 是否保留頂層的全連線網路, False為只要bottleneck |
weights | ‘imagenet’代表載入預訓練權重, None代表隨機初始化 |
input_tensor | 可填入Keras tensor作為模型的影象輸出tensor |
input_shape | 長為3的tuple,指明輸入圖片的shape,圖片的寬高必須大於197 |
pooling | 特徵提取網路的池化方式。None代表不池化,最後一個卷積層的輸出為4D張量。‘avg’代表全域性平均池化,‘max’代表全域性最大值池化 |
classes | 圖片分類的類別數,當include_top=True weight=None時可用 |
關於遷移學習,可以參考這篇文章:如何在極小資料集上實現影象分類。裡面介紹了通過影象變換以及使用已有模型並fine-tune新分類器的過程。
3 模型視覺化
utils包中提供了plot_model函式,用來將一個model以影象的形式展現出來。此功能依賴pydot-ng與graphviz。
pip install pydot-ng graphviz
from keras.utils import plot_model
model = keras.applications.InceptionV3()
plot_model(model, to_file='model.png')
- 1
- 2
- 3
- 1
- 2
- 3