關於TensorFlow中的多圖(Multiple Graphs)
阿新 • • 發佈:2019-01-26
一、摘要
TensorFlow中的圖(Graph)是眾多操作(Ops)的集合,它描述了具體的操作型別與各操作之間的關聯。在實際應用中,我們可以直接把圖理解為神經網路(Neural Network)結構的程式化描述。TensorFlow中的會話(Session)則實現圖中所有操作,使得資料(Tensor型別)在圖中流動(Flow)起來。平常學習或使用TensorFlow中,我們基本是構築一個圖,開啟一個會話,然後run。但最近由於工作需要,我探索了多圖(Multiple Graphs)方式。本文主要任務是簡單記錄學習過程,以備複閱。如果能給看官提供一絲幫助,那也是極好。二、多圖實現
(交代一下平臺版本:PyCharm Community Edition 2016.3.2、Python3.5.2、tensorflow0.12.1)片段一:多圖的建立與預設圖
import tensorflow as tf
import numpy as np
g1 = tf.Graph() #建立圖1
g2 = tf.Graph() #建立圖2
print('tf.get_default_graph()=',tf.get_default_graph())#獲取預設圖,並顯示基本資訊
print('g1 =',g1)
print('g2 =',g2)
print('------------------------------------')
片段一執行結果:tf.get_default_graph() = <tensorflow.python.framework.ops.Graph object at 0x000001AACC965550>
g1 = <tensorflow.python.framework.ops.Graph object at 0x000001AAC7C5FC50>
g2 = <tensorflow.python.framework.ops.Graph object at 0x000001AAC7C5FD68>
------------------------------------
從結果地址可以看出:預設圖自動存在,手動建立的圖與預設圖完全不同
片段二:建立各圖下的具體操作(Op)
片段二執行結果:with g1.as_default():#在with模組中,g1作為預設圖 x_data = np.random.rand(100).astype(np.float32)#定義圖中具體操作 y_data = x_data * 0.1 + 0.3 W = tf.Variable(tf.random_uniform([1], -1.0, 1.0)) b = tf.Variable(tf.zeros([1])) y = W * x_data + b loss = tf.reduce_mean(tf.square(y - y_data)) print('num-of-trainable_variables=', len(tf.trainable_variables()), ' num-of-global_variables=',len(tf.global_variables()))#統計變數個數 print('g1 =',g1) print('tf.get_default_graph()=',tf.get_default_graph()) print('tf.get_default_graph()=',tf.get_default_graph()) W2 = tf.Variable(tf.random_uniform([1], -1.0, 1.0))#with模組外定義操作,比較模組內外變數個數變化情況 print('num-of-trainable_variables=',len(tf.trainable_variables()),' num-of-global_variables=',len(tf.global_variables())) print('------------------------------------')
num-of-trainable_variables= 2 num-of-global_variables= 2
g1 = <tensorflow.python.framework.ops.Graph object at 0x000001AAC7C5FC50>
tf.get_default_graph() = <tensorflow.python.framework.ops.Graph object at 0x000001AAC7C5FC50>
tf.get_default_graph() = <tensorflow.python.framework.ops.Graph object at 0x000001AACC965550>
num-of-trainable_variables= 1 num-of-global_variables= 1
------------------------------------
由執行結果可以看出:
a、如上所述,所建操作依附於預設圖 。使用with模組,在模組中讓具體的圖作為預設圖。
b、退出with g.as_default()模組,原始預設圖立馬恢復(系統用棧來進行管理)
c、類似於tf.trainable_variables()、tf.global_variables()等都只針對此刻預設圖裡的變數,編寫時要小心。
片段三:在各圖下建立會話進行計算
with g1.as_default():
sess1 = tf.Session(graph=g1)
print('sess1',sess1)
init = tf.global_variables_initializer()
sess1.run(init)
train = tf.train.GradientDescentOptimizer(0.5).minimize(loss)
for step in range(201):
sess1.run(train)
if step % 100 == 0:
print(step, sess1.run(W), sess1.run(b))
with g2.as_default():
w = tf.Variable(1.0)
b = tf.Variable(1.5)
wb=w+b
sess2 = tf.Session(graph=g2)
sess2.run(tf.global_variables_initializer())
print(sess2.run(wb))#定義並初始化後,可以在模組外執行
print('sess2',sess2)
片段三的執行結果如下:sess1 <tensorflow.python.client.session.Session object at 0x000001AACCA168D0>
0 [-0.41363323] [ 0.84654742]
100 [ 0.09930082] [ 0.30039138]
200 [ 0.09999922] [ 0.30000046]
2.5
sess2 <tensorflow.python.client.session.Session object at 0x000001AACC9FBE80>
由執行結果可以看出:在各自的圖下建立各自的會話進行計算各不干擾
博文就到此結束,看官若有疑問,歡迎留言!