Keras: fixing errors when loading an LSTM+CRF model
阿新 • Published: 2020-06-11
The error
new_model = load_model("model.h5")
raises:
1. ValueError: Unknown layer: CRF
2. ValueError: Unknown loss function:crf_loss
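The fix
1. The original author also edited the Keras source of load_model, changing its signature from def load_model(filepath, custom_objects=None, compile=True): to def load_model(filepath, custom_objects, compile=True):. This edit is not strictly required, because the stock signature already accepts custom_objects as a keyword argument.
2. Pass the keras-contrib CRF layer, loss and metric to load_model through custom_objects:
from keras.models import load_model
from keras_contrib.layers import CRF
from keras_contrib.losses import crf_loss
from keras_contrib.metrics import crf_viterbi_accuracy
new_model = load_model("model.h5", custom_objects={'CRF': CRF, 'crf_loss': crf_loss, 'crf_viterbi_accuracy': crf_viterbi_accuracy})
After these changes, the model loads and runs.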
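Why does custom_objects help? When load_model rebuilds a model from an .h5 file, it looks up each saved layer class name (and the loss and metric names) among the objects Keras knows about; names that come from a third-party package such as keras-contrib are only found if you supply them. The toy resolver below is a simplified, hypothetical model of that lookup, not Keras's actual code: the BUILTIN_LAYERS registry, the resolve_layer helper and the stand-in CRF class are illustrative assumptions.

```python
# Simplified, hypothetical model of how a saved layer name is resolved
# during deserialization. This is NOT Keras's actual source code.

# Illustrative registry of layer names the framework knows by itself
BUILTIN_LAYERS = {"Embedding", "LSTM", "Bidirectional", "Dropout",
                  "TimeDistributed", "Dense"}


def resolve_layer(class_name, custom_objects=None):
    """Return the object registered for class_name, or raise an 'Unknown layer' error."""
    # User-supplied custom_objects are checked before the built-in registry
    if custom_objects and class_name in custom_objects:
        return custom_objects[class_name]
    if class_name in BUILTIN_LAYERS:
        return class_name
    raise ValueError("Unknown layer: " + class_name)


class CRF:  # hypothetical stand-in for keras_contrib.layers.CRF
    pass


# Layer class names as they might appear in a saved BiLSTM-CRF config
saved_config = ["Embedding", "Bidirectional", "Dropout", "Bidirectional",
                "Dropout", "TimeDistributed", "CRF"]

try:
    [resolve_layer(name) for name in saved_config]
except ValueError as err:
    print(err)  # Unknown layer: CRF

resolved = [resolve_layer(name, custom_objects={"CRF": CRF}) for name in saved_config]
print(resolved[-1] is CRF)  # True
```

The loss and metric names go through the same kind of lookup, which is why crf_loss and crf_viterbi_accuracy also have to appear in custom_objects.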
Supplementary: building a BiLSTM-CRF with Keras
The CRF layer used here is the one implemented in https://github.com/keras-team/keras-contrib.
Install keras-contrib:
pip install git+https://www.github.com/keras-team/keras-contrib.git
Code Example:
# coding: utf-8
from keras.models import Sequential
from keras.layers import Embedding
from keras.layers import LSTM
from keras.layers import Bidirectional
from keras.layers import Dense
from keras.layers import TimeDistributed
from keras.layers import Dropout
from keras_contrib.layers.crf import CRF
from keras_contrib.utils import save_load_utils

VOCAB_SIZE = 2500
EMBEDDING_OUT_DIM = 128
TIME_STAMPS = 100
HIDDEN_UNITS = 200
DROPOUT_RATE = 0.3
NUM_CLASS = 5


def build_embedding_bilstm2_crf_model():
    """
    Bidirectional LSTM + CRF with an embedding layer
    """
    model = Sequential()
    model.add(Embedding(VOCAB_SIZE, output_dim=EMBEDDING_OUT_DIM, input_length=TIME_STAMPS))
    model.add(Bidirectional(LSTM(HIDDEN_UNITS, return_sequences=True)))
    model.add(Dropout(DROPOUT_RATE))
    model.add(Bidirectional(LSTM(HIDDEN_UNITS, return_sequences=True)))
    model.add(Dropout(DROPOUT_RATE))
    # Per-time-step tag scores fed into the CRF
    model.add(TimeDistributed(Dense(NUM_CLASS)))
    crf_layer = CRF(NUM_CLASS)
    model.add(crf_layer)
    model.compile('rmsprop', loss=crf_layer.loss_function, metrics=[crf_layer.accuracy])
    return model


def save_embedding_bilstm2_crf_model(model, filename):
    # Save weights only, not the model architecture
    save_load_utils.save_all_weights(model, filename)


def load_embedding_bilstm2_crf_model(filename):
    # Rebuild the architecture in code, so the CRF layer never has to be
    # deserialized from the file, then restore the saved weights
    model = build_embedding_bilstm2_crf_model()
    save_load_utils.load_all_weights(model, filename)
    return model


if __name__ == '__main__':
    model = build_embedding_bilstm2_crf_model()
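The CRF layer at the end of the model scores whole tag sequences: it combines per-step tag scores computed from its input with learned tag-to-tag transition scores, and Viterbi decoding (the "viterbi" in crf_viterbi_accuracy) picks the best-scoring sequence. The pure-Python sketch below illustrates only that decoding step under stated assumptions; it is independent of keras-contrib's implementation, and the emissions/transitions layout and the viterbi_decode name are hypothetical.

```python
# Minimal Viterbi decoder for a linear-chain CRF (illustrative sketch).
# Assumed layout: emissions[t][j] = score of tag j at step t,
# transitions[i][j] = score of moving from tag i to tag j.

def viterbi_decode(emissions, transitions):
    num_tags = len(transitions)
    # Best score of any tag path ending in each tag at step 0
    scores = list(emissions[0])
    backpointers = []
    for t in range(1, len(emissions)):
        new_scores = []
        pointers = []
        for j in range(num_tags):
            # Best previous tag i for arriving at tag j at step t
            best_i = max(range(num_tags), key=lambda i: scores[i] + transitions[i][j])
            new_scores.append(scores[best_i] + transitions[best_i][j] + emissions[t][j])
            pointers.append(best_i)
        scores = new_scores
        backpointers.append(pointers)
    # Start from the best final tag and follow the back-pointers
    best_last = max(range(num_tags), key=lambda j: scores[j])
    path = [best_last]
    for pointers in reversed(backpointers):
        path.append(pointers[path[-1]])
    path.reverse()
    return path, scores[best_last]


# Toy example: 3 time steps, 2 tags, switching tags costs -2.0
emissions = [[1.0, 0.0], [0.0, 1.0], [1.0, 0.0]]
transitions = [[0.0, -2.0], [-2.0, 0.0]]
path, score = viterbi_decode(emissions, transitions)
print(path, score)  # [0, 0, 0] 2.0
```

With these toy scores, taking the best tag at each step independently would give [0, 1, 0], but the transition penalty makes [0, 0, 0] the best whole sequence.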
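Note:
If calling build_embedding_bilstm2_crf_model() raises an error, the cause is most likely a Keras version mismatch. The model-building code runs without errors with keras-contrib==2.0.8 and keras==2.0.8.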
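As a sanity check on the constants in build_embedding_bilstm2_crf_model(), the arithmetic below counts the trainable parameters of each layer before the CRF. It assumes the standard Keras LSTM layout (four gates, each with an input kernel, a recurrent kernel and a bias) and Bidirectional's default concatenation of the two directions. The CRF layer is left out because its weight layout is internal to keras-contrib. The figures should match the corresponding rows of model.summary().

```python
# Worked parameter counts for the layers before the CRF; constants copied from the example.
VOCAB_SIZE = 2500
EMBEDDING_OUT_DIM = 128
HIDDEN_UNITS = 200
NUM_CLASS = 5


def lstm_params(input_dim, units):
    # 4 gates, each with an input kernel, a recurrent kernel and a bias vector
    return 4 * (input_dim * units + units * units + units)


embedding = VOCAB_SIZE * EMBEDDING_OUT_DIM                   # 320000
bilstm_1 = 2 * lstm_params(EMBEDDING_OUT_DIM, HIDDEN_UNITS)  # 526400, forward + backward
bilstm_2 = 2 * lstm_params(2 * HIDDEN_UNITS, HIDDEN_UNITS)   # 961600, input is the 400-dim concat
dense = 2 * HIDDEN_UNITS * NUM_CLASS + NUM_CLASS             # 2005, shared across all time steps

print(embedding, bilstm_1, bilstm_2, dense)
```

Dropout layers add no parameters, and TimeDistributed reuses the same Dense weights at every time step, so those counts do not grow with TIME_STAMPS.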