如何使用tensorflow載入keras訓練好的模型
阿新 • • 發佈:2018-12-03
感謝
前言
最近實驗室碰到一個奇怪的需求,大家分別構建不同的NLP模型,最後需要進行整合,可是由於有的同學使用的是keras,有的同學喜歡使用TensorFlow,這樣導致在構建介面時無法統一不同模型load的方式,每一個模型單獨使用一種load的方式的話導致了很多重複開發,效率不高的同時也對專案的可擴充套件性造成了巨大的破壞。於是需要一種能夠統一TensorFlow和keras模型的load過程的方法。
正文
1.構建keras模型
首先假設我們build了一個非常簡單的keras模型,如下所示:
x = np.vstack((np.random.rand(1000,10),-np.random.rand(1000,10))) y = np.vstack((np.ones((1000,1)),np.zeros((1000,1)))) print(x.shape) print(y.shape) model = Sequential() model.add(Dense(units = 32, input_shape=(10,), activation ='relu')) model.add(Dense(units = 16, activation ='relu')) model.add(Dense(units = 1, activation ='sigmoid')) model.compile(loss='binary_crossentropy', optimizer='Adam', metrics=['binary_accuracy']) model.fit(x = x, y=y, epochs = 2, validation_split=0.2)
2.將keras模型儲存為Protocol Buffers的格式
由於TensorFlow是支援將模型儲存為Protocol Buffers(.pb)格式的,如果我們有一種方法能將keras模型儲存為(.pb)格式的話,那我們的問題就解決了。可是天不遂人願,keras沒有直接提供這樣一個將模型儲存為(.pb)格式的方法,所以我們必須自己實現這樣一個方法,如果你看過keras的原始碼的話,你會發現keras backend提供了一個get_session()的函式(只有基於TensorFlow的backend有),該函式會返回一個TensorFlow Session,這樣一來我們就另闢蹊徑,使用這個Session來儲存keras模型,而不使用keras已經提供的儲存模型的函式,方法如下:
def freeze_session(session, keep_var_names=None, output_names=None, clear_devices=True):
"""
將輸入的Session儲存為靜態的計算圖結構.
建立一個新的計算圖,其中的節點以及權重和輸入的Session相同. 新的計算圖會將輸入Session中不參與計算的部分刪除。
@param session 需要被儲存的Session.
@param keep_var_names 一個記錄了需要被儲存的變數名的list,若為None則預設儲存所有的變數.
@param output_names 計算圖相關輸出的name list.
@param clear_devices 若為True的話會刪除不參與計算的部分,這樣更利於移植,否則可能移植失敗
@return The frozen graph definition.
"""
from tensorflow.python.framework.graph_util import convert_variables_to_constants
graph = session.graph
with graph.as_default():
freeze_var_names = list(set(v.op.name for v in tf.global_variables()).difference(keep_var_names or []))
output_names = output_names or []
output_names += [v.op.name for v in tf.global_variables()]
input_graph_def = graph.as_graph_def()
if clear_devices:
for node in input_graph_def.node:
node.device = ""
frozen_graph = convert_variables_to_constants(session, input_graph_def,
output_names, freeze_var_names)
return frozen_graph
我們通過如下方法呼叫上述函式來儲存模型:
from keras import backend as K
frozen_graph = freeze_session(K.get_session(),
output_names=[out.op.name for out in model.outputs])
tf.train.write_graph(frozen_graph, wkdir, pb_filename, as_text=False)
3.在TensorFlow中載入儲存的模型
載入儲存模型的例子如下:
from tensorflow.python.platform import gfile
with tf.Session() as sess:
# 從(.pb)檔案中載入模型
with gfile.FastGFile(wkdir+'/'+pb_filename,'rb') as f:
graph_def = tf.GraphDef()
graph_def.ParseFromString(f.read())
sess.graph.as_default()
g_in = tf.import_graph_def(graph_def)