1. 程式人生 > >如何使用tensorflow載入keras訓練好的模型

如何使用tensorflow載入keras訓練好的模型

感謝

How to convert your Keras models to Tensorflow

前言

最近實驗室碰到一個奇怪的需求,大家分別構建不同的NLP模型,最後需要進行整合,可是由於有的同學使用的是keras,有的同學喜歡使用TensorFlow,這樣導致在構建介面時無法統一不同模型load的方式,每一個模型單獨使用一種load的方式的話導致了很多重複開發,效率不高的同時也對專案的可擴充套件性造成了巨大的破壞。於是需要一種能夠統一TensorFlow和keras模型的load過程的方法。

正文

1.構建keras模型
首先假設我們build了一個非常簡單的keras模型,如下所示:

x = np.vstack((np.random.rand(1000,10),-np.random.rand(1000,10)))
y = np.vstack((np.ones((1000,1)),np.zeros((1000,1))))
print(x.shape)
print(y.shape)

model = Sequential()
model.add(Dense(units = 32, input_shape=(10,), activation ='relu'))
model.add(Dense(units = 16, activation ='relu'))
model.add(Dense(units = 1, activation ='sigmoid'))

model.compile(loss='binary_crossentropy', optimizer='Adam', metrics=['binary_accuracy'])
model.fit(x = x, y=y, epochs = 2, validation_split=0.2) 

2.將keras模型儲存為Protocol Buffers的格式
由於TensorFlow是支援將模型儲存為Protocol Buffers(.pb)格式的,如果我們有一種方法能將keras模型儲存為(.pb)格式的話,那我們的問題就解決了。可是天不遂人願,keras沒有直接提供這樣一個將模型儲存為(.pb)格式的方法,所以我們必須自己實現這樣一個方法,如果你看過keras的原始碼的話,你會發現keras backend提供了一個get_session()的函式(只有基於TensorFlow的backend有),該函式會返回一個TensorFlow Session,這樣一來我們就另闢蹊徑,使用這個Session來儲存keras模型,而不使用keras已經提供的儲存模型的函式,方法如下:

def freeze_session(session, keep_var_names=None, output_names=None, clear_devices=True):
    """
    將輸入的Session儲存為靜態的計算圖結構.
    建立一個新的計算圖,其中的節點以及權重和輸入的Session相同. 新的計算圖會將輸入Session中不參與計算的部分刪除。
    @param session 需要被儲存的Session.
    @param keep_var_names 一個記錄了需要被儲存的變數名的list,若為None則預設儲存所有的變數.
    @param output_names 計算圖相關輸出的name list.
    @param clear_devices 若為True的話會刪除不參與計算的部分,這樣更利於移植,否則可能移植失敗
    @return The frozen graph definition.
    """
    from tensorflow.python.framework.graph_util import convert_variables_to_constants
    graph = session.graph
    with graph.as_default():
        freeze_var_names = list(set(v.op.name for v in tf.global_variables()).difference(keep_var_names or []))
        output_names = output_names or []
        output_names += [v.op.name for v in tf.global_variables()]
        input_graph_def = graph.as_graph_def()
        if clear_devices:
            for node in input_graph_def.node:
                node.device = ""
        frozen_graph = convert_variables_to_constants(session, input_graph_def,
                                                      output_names, freeze_var_names)
        return frozen_graph

我們通過如下方法呼叫上述函式來儲存模型:

from keras import backend as K
frozen_graph = freeze_session(K.get_session(),
                              output_names=[out.op.name for out in model.outputs])
tf.train.write_graph(frozen_graph, wkdir, pb_filename, as_text=False)

3.在TensorFlow中載入儲存的模型
載入儲存模型的例子如下:

from tensorflow.python.platform import gfile
with tf.Session() as sess:
    # 從(.pb)檔案中載入模型
    with gfile.FastGFile(wkdir+'/'+pb_filename,'rb') as f:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read())
        sess.graph.as_default()
        g_in = tf.import_graph_def(graph_def)