1. 程式人生 > >Keras+Django多次load model報錯

Keras+Django多次load model報錯

最近在做一個文字分類工具,功能包括上傳樣本,使用樣本訓練model,save訓練好的model並且使用model對文字進行分類。

用到框架有Keras和Django。

訓練階段將訓練好的模型儲存到指定目錄。預測階段載入訓練好的模型進行預測(每一次預測都需要載入模型)。

第一次預測的時候是沒有問題的,可以正常預測,第二次預測的時候報出瞭如下錯誤:

TypeError: Cannot interpret feed_dict key as Tensor: Tensor Tensor(“Placeholde r:0”, shape=(40001,50), dtype=float32) is not an element of this graph.

在網上找了一下,發現很多人都遇到過這個問題,解決方案大體分為兩種:

1. 在載入模型前加上keras.backend.clear_session() 

clear_session()的作用是結束當前的TF計算圖並新建一個。

經過嘗試這種的方法並不能解決我的問題。

2.在初始化載入模型之後,就隨便生成一個向量讓 model 執行一次 predict 函式 

經過嘗試這種方法是可行的。但是,我的專案邏輯是在訓練出模型後才能進行載入,並不能在初始化的時候載入模型,因此,這種方法並不解決的我的問題。

最後是我的解決方法,在訓練階段儲存模型後和預測階段預測結束後手動清空記憶體。

如果還是不行,建議加上keras.backend.clear_session() (上述第一種方法)

新增如下程式碼:

import gc
# training
model = "自己的模型"
model.save('model/model.h5')
del model  #刪除model
gc.collect() #手動清理記憶體


# prediction
model = load_model('model/model.h5')
result = model.predict(test)
del model
gc.collect()