tensorflow非標準模型的匯出
阿新 • • 發佈:2018-11-27
摘要
本文記錄部署一個非標準模型(未定義name、未定義placeholder、未定義batchnorm中的train引數)的過程
引言
之前訓練的一個比較好的模型需要部署到實際應用場景中,但從之前訓練時到現在,tensorflow版本已經更新了7、8個,一些藉口已經改變。給部署帶來一定的難度
主題
本文在舊的tensorflow版本上先進行模式匯出和試驗,成功後再部署到新的tensorflow版本上,先採用最基礎的meta方式進行匯入匯出
meta方式
檢視模型所有name程式碼 print('graph names:{}'.format(graph._names_in_use))
查詢最終輸出的op
原模型最終程式碼:
# 全連線層 + Softmax
with tf.variable_scope('logit'):
logits = self._fully_connected(x, self.hps.num_classes)
self._fully_connected
最後一步呼叫了tf.nn.xw_plus_b(x, w, b)
寫一個例子程式,用來檢視預設的輸出op名稱
graph = tf.get_default_graph() w = tf.Variable(np.random.randn(5, 5).astype('float32'),name="w") x = tf.Variable(np.random.randn(5, 5).astype('float32'), name="w2") b = tf.Variable(np.random.randn(5).astype('float32'), name="b") tf.nn.xw_plus_b(w, x, b) print('graph names:{}'.format(graph._names_in_use))
檢視輸出為:
graph names:{'w/read': 1, 'b/assign': 1, 'b': 1, 'w2/read': 1, 'xw_plus_b/matmul': 1, 'w2': 1, 'w2/assign': 1, 'b/initial_value': 1, 'w/assign': 1, 'w/initial_value': 1, 'w': 1, 'b/read': 1, 'xw_plus_b': 1, 'w2/initial_value': 1}
可以看到xw_plus_b
為操作的原始OP,那麼原模型最終輸出op為"logit/xw_plus_b:1"
,模型部署時的輸出程式碼為
op_logit = graph.get_tensor_by_name("logit/xw_plus_b:0")
logits=sess.run(op_logit, feed_dict)
predictions = np.argmax(logits, axis=1)
查詢輸入tensor
寫一個例子程式,用來檢視預設的輸入op名稱
graph = tf.get_default_graph()
image = tf.Variable(np.random.randn(5, 60000).astype('float32'),name="image")
label = tf.Variable(np.random.randn(5).astype('int'), name="label")
data_num = tf.Variable(np.random.randn(5).astype('int'), name="data_num")
data_num, images, sparse_labels = tf.train.shuffle_batch(
[data_num, image, label], batch_size=5, num_threads=2,
capacity=20, min_after_dequeue=10)
print('graph names:{}'.format(graph._names_in_use))
檢視輸出為:
graph names:{'shuffle_batch/tofloat': 1, 'image/read': 1, 'label': 1, 'image': 1, 'shuffle_batch/sub': 1, 'shuffle_batch/const': 1, 'label/assign': 1, 'label/read': 1, 'shuffle_batch/random_shuffle_queue_close': 2, 'shuffle_batch/maximum': 1, 'shuffle_batch/random_shuffle_queue_enqueue': 1, 'data_num': 1, 'shuffle_batch/mul/y': 1, 'image/initial_value': 1, 'shuffle_batch/random_shuffle_queue': 1, 'shuffle_batch/sub/y': 1, 'data_num/assign': 1, 'shuffle_batch/random_shuffle_queue_close_1': 1, 'label/initial_value': 1, 'data_num/read': 1, 'image/assign': 1, 'shuffle_batch/random_shuffle_queue_size': 1, 'shuffle_batch/n': 1, 'shuffle_batch/fraction_over_10_of_10_full/tags': 1, 'shuffle_batch': 1, 'shuffle_batch/maximum/x': 1, 'data_num/initial_value': 1, 'shuffle_batch/mul': 1, 'shuffle_batch/fraction_over_10_of_10_full': 1}
可以看到shuffle_batch
為操作的原始OP,那麼原模型輸入tensor為"input/shuffle_batch:1"
,模型部署時的輸出程式碼為
op_logit = graph.get_tensor_by_name("logit/xw_plus_b:0")
logits=sess.run(op_logit, feed_dict)
predictions = np.argmax(logits, axis=1)
驗證能否強行使用feed_dict改變變數的值
寫一個例子程式,用來檢視強行填充feed_dict的效果
graph = tf.get_default_graph()
w = tf.placeholder("float", name="w")
w1 = tf.Variable(5.0, name="w2")
x = tf.Variable(2.0, name="w2")
feed_dict={w:5.0,w1:4.0}
y = tf.multiply(w,x)
y1 = tf.multiply(w1,x)
sess = tf.Session()
sess.run(tf.global_variables_initializer())
ty,ty1=sess.run([y,y1],feed_dict=feed_dict)
print(ty)
print(ty1)
檢視輸出為:
10.0
8.0
可以看到y1
的值為強制填充後計算結果4*2=8
,feeddict有效
在原模型中加入強行填充feed_dict程式碼
images_placeholder = np.random.randn(100, 60000).astype('float32')
feed_dict={test_data: images_placeholder}
batchnorm問題
tensorflow中batch normalization的用法
batchnorm引數需要在訓練時加入程式碼和重新訓練,暫不考慮加入,通過每次預測輸入足量樣本來解決