1. 程式人生 > >關於TensorFlow中的多圖(Multiple Graphs)

關於TensorFlow中的多圖(Multiple Graphs)

一、摘要

   TensorFlow中的圖(Graph)是眾多操作(Ops)的集合,它描述了具體的操作型別與各操作之間的關聯。在實際應用中,我們可以直接把圖理解為神經網路(Neural Network)結構的程式化描述。TensorFlow中的會話(Session)則實現圖中所有操作,使得資料(Tensor型別)在圖中流動(Flow)起來。平常學習或使用TensorFlow中,我們基本是構築一個圖,開啟一個會話,然後run。但最近由於工作需要,我探索了多圖(Multiple Graphs)方式。本文主要任務是簡單記錄學習過程,以備複閱。如果能給看官提供一絲幫助,那也是極好。

二、多圖實現

(交代一下平臺版本:PyCharm Community Edition 2016.3.2、Python3.5.2、tensorflow0.12.1)
片段一:多圖的建立與預設圖

import tensorflow as tf
import numpy as np
g1 = tf.Graph()  #建立圖1
g2 = tf.Graph()  #建立圖2
print('tf.get_default_graph()=',tf.get_default_graph())#獲取預設圖,並顯示基本資訊
print('g1                    =',g1)
print('g2                    =',g2)
print('------------------------------------')
片段一執行結果:
tf.get_default_graph() = <tensorflow.python.framework.ops.Graph object at 0x000001AACC965550>
g1              = <tensorflow.python.framework.ops.Graph object at 0x000001AAC7C5FC50>
g2              = <tensorflow.python.framework.ops.Graph object at 0x000001AAC7C5FD68>
------------------------------------
從結果地址可以看出:預設圖自動存在,手動建立的圖與預設圖完全不同


片段二:建立各圖下的具體操作(Op)
with g1.as_default():#在with模組中,g1作為預設圖
    x_data = np.random.rand(100).astype(np.float32)#定義圖中具體操作
    y_data = x_data * 0.1 + 0.3
    W = tf.Variable(tf.random_uniform([1], -1.0, 1.0))
    b = tf.Variable(tf.zeros([1]))
    y = W * x_data + b
    loss = tf.reduce_mean(tf.square(y - y_data))
    print('num-of-trainable_variables=', len(tf.trainable_variables()), ' num-of-global_variables=',len(tf.global_variables()))#統計變數個數
    print('g1                    =',g1)
    print('tf.get_default_graph()=',tf.get_default_graph())
print('tf.get_default_graph()=',tf.get_default_graph())
W2 = tf.Variable(tf.random_uniform([1], -1.0, 1.0))#with模組外定義操作,比較模組內外變數個數變化情況
print('num-of-trainable_variables=',len(tf.trainable_variables()),' num-of-global_variables=',len(tf.global_variables()))
print('------------------------------------')
片段二執行結果:
num-of-trainable_variables= 2  num-of-global_variables= 2
g1              = <tensorflow.python.framework.ops.Graph object at 0x000001AAC7C5FC50>
tf.get_default_graph() = <tensorflow.python.framework.ops.Graph object at 0x000001AAC7C5FC50>
tf.get_default_graph() = <tensorflow.python.framework.ops.Graph object at 0x000001AACC965550>
num-of-trainable_variables= 1  num-of-global_variables= 1
------------------------------------
由執行結果可以看出:
a、如上所述,所建操作依附於預設圖 。使用with模組,在模組中讓具體的圖作為預設圖。
b、退出with g.as_default()模組,原始預設圖立馬恢復(系統用棧來進行管理)
c、類似於tf.trainable_variables()、tf.global_variables()等都只針對此刻預設圖裡的變數,編寫時要小心。

片段三:在各圖下建立會話進行計算
with g1.as_default():
    sess1 = tf.Session(graph=g1)
    print('sess1',sess1)
    init = tf.global_variables_initializer()
    sess1.run(init)
    train = tf.train.GradientDescentOptimizer(0.5).minimize(loss)
    for step in range(201):
        sess1.run(train)
        if step % 100 == 0:
            print(step, sess1.run(W), sess1.run(b))


with g2.as_default():
    w = tf.Variable(1.0)
    b = tf.Variable(1.5)
    wb=w+b
    sess2 = tf.Session(graph=g2)
    sess2.run(tf.global_variables_initializer())
print(sess2.run(wb))#定義並初始化後,可以在模組外執行
print('sess2',sess2)
片段三的執行結果如下:
sess1 <tensorflow.python.client.session.Session object at 0x000001AACCA168D0>
0 [-0.41363323] [ 0.84654742]
100 [ 0.09930082] [ 0.30039138]
200 [ 0.09999922] [ 0.30000046]
2.5
sess2 <tensorflow.python.client.session.Session object at 0x000001AACC9FBE80>

由執行結果可以看出:在各自的圖下建立各自的會話進行計算各不干擾

博文就到此結束,看官若有疑問,歡迎留言!