1. 程式人生 > 程式設計 >TensorFlow固化模型的實現操作

TensorFlow固化模型的實現操作

前言

TensorFlow目前在移動端是無法training的,只能跑已經訓練好的模型,但一般的儲存方式只有單一儲存引數或者graph的,如何將引數、graph同時儲存呢?

生成模型

主要有兩種方法生成模型,一種是通過freeze_graph把tf.train.write_graph()生成的pb檔案與tf.train.saver()生成的chkp檔案固化之後重新生成一個pb檔案,這一種現在不太建議使用。另一種是把變數轉成常量之後寫入PB檔案中。我們簡單的介紹下freeze_graph方法。

freeze_graph

這種方法我們需要先使用tf.train.write_graph()以及tf.train.saver()生成pb檔案和ckpt檔案,程式碼如下:

with tf.Session() as sess:
 saver = tf.train.Saver()
 saver.save(session,"model.ckpt")
 tf.train.write_graph(session.graph_def,'','graph.pb')

然後使用TensorFlow原始碼中的freeze_graph工具進行固化操作:

首先需要build freeze_graph 工具( 需要 bazel ):

bazel build tensorflow/python/tools:freeze_graph

然後使用這個工具進行固化(/path/to/表示檔案路徑):

bazel-bin/tensorflow/python/tools/freeze_graph --input_graph=/path/to/graph.pb --input_checkpoint=/path/to/model.ckpt --output_node_names=output/predict --output_graph=/path/to/frozen.pb
convert_variables_to_constants

其實在TensorFlow中傳統的儲存模型方式是儲存常量以及graph的,而我們的權重主要是變數,如果我們把訓練好的權重變成常量之後再儲存成PB檔案,這樣確實可以儲存權重,就是方法有點繁瑣,需要一個一個呼叫eval方法獲取值之後賦值,再構建一個graph,把W和b賦值給新的graph。

牛逼的Google為了方便大家使用,編寫了一個方法供我們快速的轉換並儲存。

首先我們需要引入這個方法

from tensorflow.python.framework.graph_util import convert_variables_to_constants

在想要儲存的地方加入如下程式碼,把變數轉換成常量

output_graph_def = convert_variables_to_constants(sess,sess.graph_def,output_node_names=['output/predict'])

這裡引數第一個是當前的session,第二個為graph,第三個是輸出節點名(如我的輸出層程式碼是這樣的:)

 with tf.name_scope('output'):
 w_out = tf.Variable(w_alpha * tf.random_normal([1024,MAX_CAPTCHA * CHAR_SET_LEN]))
 tf.summary.histogram('output/weight',w_out)
 b_out = tf.Variable(b_alpha * tf.random_normal([MAX_CAPTCHA * CHAR_SET_LEN]))
 tf.summary.histogram('output/biases',b_out)
 out = tf.add(tf.matmul(dense2,w_out),b_out)
 out = tf.nn.softmax(out)
 predict = tf.argmax(tf.reshape(out,[-1,11,36]),2,name='predict')

由於我們採用了name_scope所以我們在predict之前需要加上output/

生成檔案

with tf.gfile.FastGFile('model/CTNModel.pb',mode='wb') as f:
f.write(output_graph_def.SerializeToString())

第一個引數是檔案路徑,第二個是指檔案操作的模式,這裡指的是以二進位制的方式寫入檔案。

執行程式碼,系統會生成一個PB檔案,接下來我們要測試下這個模型是否能夠正常的讀取、執行。

測試模型

在Python環境下,我們首先需要載入這個模型,程式碼如下:

with open('./model/rounded_graph.pb','rb') as f:
 graph_def = tf.GraphDef()
 graph_def.ParseFromString(f.read())
 output = tf.import_graph_def(graph_def,input_map={'inputs/X:0': newInput_X},return_elements=['output/predict:0'])

由於我們原本的網路輸入值是一個placeholder,這裡為了方便輸入我們也先定義一個新的placeholder:

newInput_X = tf.placeholder(tf.float32,[None,IMAGE_HEIGHT * IMAGE_WIDTH],name="X")

在input_map的引數填入新的placeholder。

在呼叫我們的網路的時候直接用這個新的placeholder接收資料,如:

text_list = sesss.run(output,feed_dict={newInput_X: [captcha_image]})

然後就是執行我們的網路,看是否可以執行吧。

以上這篇TensorFlow固化模型的實現操作就是小編分享給大家的全部內容了,希望能給大家一個參考,也希望大家多多支援我們。