
Keras: fixing errors when loading an LSTM+CRF model

The error

new_model = load_model("model.h5")

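This fails with one of the following errors. Keras only knows its built-in layers, losses and metrics. The CRF layer and its loss and metric come from keras-contrib, so load_model cannot resolve their saved names on its own:

1. keras load_model ValueError: Unknown layer: CRF

2. keras load_model ValueError: Unknown loss function:crf_loss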
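The following is a simplified, hypothetical model of the name lookup that produces these errors. It is not Keras's real code, and its error messages only approximate Keras's wording. Each name saved in the model file is looked up first in the user-supplied custom_objects mapping and then among the built-ins. A name found in neither place raises ValueError. The dictionaries BUILTIN_LAYERS and BUILTIN_LOSSES and the function resolve are illustrative names.

```python
# Hypothetical, simplified model of name resolution during load_model.
# Not Keras's real implementation; error messages only approximate Keras's.

BUILTIN_LAYERS = {"Dense": "builtin Dense", "LSTM": "builtin LSTM", "Embedding": "builtin Embedding"}
BUILTIN_LOSSES = {"categorical_crossentropy": "builtin categorical_crossentropy"}


def resolve(name, kind, builtins, custom_objects=None):
    """Look up `name`, checking user-supplied custom_objects before built-ins."""
    custom_objects = custom_objects or {}
    if name in custom_objects:
        return custom_objects[name]
    if name in builtins:
        return builtins[name]
    # A saved name that neither mapping knows about: this is the situation
    # behind "Unknown layer: CRF" and "Unknown loss function:crf_loss".
    raise ValueError(f"Unknown {kind}: {name}")


if __name__ == "__main__":
    try:
        resolve("CRF", "layer", BUILTIN_LAYERS)
    except ValueError as err:
        print(err)  # Unknown layer: CRF
    # Supplying the class under its saved name makes the lookup succeed.
    print(resolve("CRF", "layer", BUILTIN_LAYERS, {"CRF": "keras_contrib CRF"}))
```

This is why the fix below works. Passing custom_objects gives the lookup a mapping that contains "CRF", "crf_loss" and "crf_viterbi_accuracy".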

The fix

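1. The original fix edits the Keras source of load_model. It changes the signature from def load_model(filepath, custom_objects=None, compile=True): to def load_model(filepath, custom_objects, compile=True):. This edit is not strictly necessary. custom_objects is already a keyword parameter, so passing it as in step 2 is enough on its own.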

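2. Pass the CRF layer, loss and metric explicitly when loading. CRF is imported from keras_contrib.layers, crf_loss from keras_contrib.losses, and crf_viterbi_accuracy from keras_contrib.metrics:

new_model = load_model("model.h5", custom_objects={'CRF': CRF, 'crf_loss': crf_loss, 'crf_viterbi_accuracy': crf_viterbi_accuracy})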

After these changes, the model loads and runs.
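If you are unsure which names custom_objects needs, you can inspect the architecture JSON that Keras 2.x stores with the model. In HDF5 files it is written to the model_config attribute. Below is a minimal sketch that collects every class_name in that JSON and reports the ones missing from a set you consider built in. Several assumptions apply. You have already read the JSON into a string, for example with h5py. The helper names find_class_names and find_unknown_class_names are hypothetical. Nested class_name keys also appear for initializers and constraints, so your known set must include those as well. Loss and metric names such as crf_loss live separately, in the training configuration.

```python
import json


def find_class_names(node, found=None):
    """Recursively collect every 'class_name' string in a parsed Keras config."""
    if found is None:
        found = set()
    if isinstance(node, dict):
        if isinstance(node.get("class_name"), str):
            found.add(node["class_name"])
        for value in node.values():
            find_class_names(value, found)
    elif isinstance(node, list):
        for item in node:
            find_class_names(item, found)
    return found


def find_unknown_class_names(config_json, known):
    """Return, sorted, the class names in config_json that are not in `known`."""
    return sorted(find_class_names(json.loads(config_json)) - set(known))
```

Any name this returns, such as CRF, is a candidate key for the custom_objects dictionary.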

Supplementary: building a BiLSTM-CRF with Keras

This uses the CRF layer implemented in https://github.com/keras-team/keras-contrib.
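At prediction time, a linear-chain CRF such as keras-contrib's layer in its default Viterbi test mode picks the single highest-scoring tag sequence. That is what the "viterbi" in crf_viterbi_accuracy refers to. The minimal pure-Python sketch below illustrates Viterbi decoding itself. It is not keras-contrib's implementation. It omits the boundary (start/end) scores, masking and batching a full layer may add, and the name viterbi_decode is illustrative. emissions plays the role of the per-step TimeDistributed(Dense) outputs, and transitions is the learned tag-to-tag score matrix.

```python
def viterbi_decode(emissions, transitions):
    """Return (best tag path, its score) for a linear-chain CRF.

    emissions[t][j]: score of tag j at time step t.
    transitions[i][j]: score of moving from tag i to tag j.
    """
    num_tags = len(transitions)
    # best[j]: highest score of any partial path ending in tag j
    best = list(emissions[0])
    backpointers = []
    for t in range(1, len(emissions)):
        new_best, pointers = [], []
        for j in range(num_tags):
            # Previous tag i that maximises best[i] + transitions[i][j]
            i_best = max(range(num_tags), key=lambda i: best[i] + transitions[i][j])
            new_best.append(best[i_best] + transitions[i_best][j] + emissions[t][j])
            pointers.append(i_best)
        best = new_best
        backpointers.append(pointers)
    # Trace back from the best final tag
    last = max(range(num_tags), key=lambda j: best[j])
    path = [last]
    for pointers in reversed(backpointers):
        path.append(pointers[path[-1]])
    path.reverse()
    return path, best[last]
```

Dynamic programming keeps decoding at O(T * K^2) for T steps and K tags, instead of scoring all K^T sequences.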

Install keras-contrib:

pip install git+https://www.github.com/keras-team/keras-contrib.git

Code Example:

# coding: utf-8
from keras.models import Sequential
from keras.layers import Embedding
from keras.layers import LSTM
from keras.layers import Bidirectional
from keras.layers import Dense
from keras.layers import TimeDistributed
from keras.layers import Dropout
from keras_contrib.layers.crf import CRF
from keras_contrib.utils import save_load_utils

VOCAB_SIZE = 2500
EMBEDDING_OUT_DIM = 128
TIME_STAMPS = 100
HIDDEN_UNITS = 200
DROPOUT_RATE = 0.3
NUM_CLASS = 5

def build_embedding_bilstm2_crf_model():
    """
    Bidirectional LSTM + CRF with an embedding layer
    """
    model = Sequential()
    model.add(Embedding(VOCAB_SIZE, output_dim=EMBEDDING_OUT_DIM, input_length=TIME_STAMPS))
    model.add(Bidirectional(LSTM(HIDDEN_UNITS, return_sequences=True)))
    model.add(Dropout(DROPOUT_RATE))
    model.add(Bidirectional(LSTM(HIDDEN_UNITS, return_sequences=True)))
    model.add(Dropout(DROPOUT_RATE))
    # Per-timestep tag scores fed into the CRF layer
    model.add(TimeDistributed(Dense(NUM_CLASS)))
    crf_layer = CRF(NUM_CLASS)
    model.add(crf_layer)
    model.compile('rmsprop', loss=crf_layer.loss_function, metrics=[crf_layer.accuracy])
    return model

def save_embedding_bilstm2_crf_model(model, filename):
    # Saves weights (and optimizer state) only, not the architecture
    save_load_utils.save_all_weights(model, filename)

def load_embedding_bilstm2_crf_model(filename):
    # Rebuild the architecture in code, then load the weights into it;
    # this sidesteps the custom-object lookup done by load_model
    model = build_embedding_bilstm2_crf_model()
    save_load_utils.load_all_weights(model, filename)
    return model

if __name__ == '__main__':
    model = build_embedding_bilstm2_crf_model()
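The model above is compiled with crf_layer.loss_function. In keras-contrib's default joint ("join") learning mode, this corresponds to the CRF's negative log-likelihood. As a rough illustration of that quantity (not keras-contrib's code), here is a minimal pure-Python sketch for a single sequence. It computes the partition function log Z with the forward algorithm and omits boundary scores, masking and batching. The function names are illustrative.

```python
import math


def logsumexp(values):
    """Numerically stable log(sum(exp(v) for v in values))."""
    m = max(values)
    return m + math.log(sum(math.exp(v - m) for v in values))


def path_score(tags, emissions, transitions):
    """Unnormalised score of one tag sequence."""
    score = emissions[0][tags[0]]
    for t in range(1, len(tags)):
        score += transitions[tags[t - 1]][tags[t]] + emissions[t][tags[t]]
    return score


def crf_nll(tags, emissions, transitions):
    """Negative log-likelihood of `tags` under a linear-chain CRF.

    log Z is computed with the forward algorithm in O(T * K^2) time.
    """
    num_tags = len(transitions)
    # alpha[j]: log of the summed exp-scores of all partial paths ending in tag j
    alpha = list(emissions[0])
    for t in range(1, len(emissions)):
        alpha = [
            logsumexp([alpha[i] + transitions[i][j] for i in range(num_tags)]) + emissions[t][j]
            for j in range(num_tags)
        ]
    log_z = logsumexp(alpha)
    return log_z - path_score(tags, emissions, transitions)
```

Minimising this loss raises the score of the gold tag sequence relative to all other sequences. This is why the loss needs the CRF's transition matrix and is not a plain per-timestep cross-entropy.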

Note:

If building the model raises an error, the cause is most likely the Keras version. With keras-contrib==2.0.8 and keras==2.0.8, the code above runs without errors.
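To confirm which versions are actually installed before building the model, you could use a small check like the sketch below. It assumes Python 3.8+ for importlib.metadata. It also assumes keras-contrib is registered under the distribution name keras_contrib, which may differ in your environment. The helper check_versions is hypothetical.

```python
from importlib import metadata  # Python 3.8+


def check_versions(expected):
    """Map each package to (installed version or None, whether it equals the expected one)."""
    report = {}
    for package, wanted in expected.items():
        try:
            installed = metadata.version(package)
        except metadata.PackageNotFoundError:
            installed = None  # not installed under this distribution name
        report[package] = (installed, installed == wanted)
    return report


if __name__ == "__main__":
    # Versions from the note above; "keras_contrib" as the distribution
    # name is an assumption about how keras-contrib registers itself.
    for name, (found, ok) in check_versions({"keras": "2.0.8", "keras_contrib": "2.0.8"}).items():
        print(f"{name}: installed={found}, matches 2.0.8: {ok}")
```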
