1. 程式人生 > >Tensorflow:sess.run():引數 feed_dict等作用

Tensorflow:sess.run():引數 feed_dict等作用

feed_dict引數的作用是替換圖中的某個tensor的值。例如:

a = tf.add(2, 5)
b = tf.multiply(a, 3)
with tf.Session() as sess: 
    sess.run(b)

21

replace_dict = {a: 15}
sess.run(b, feed_dict = replace_dict)

45
這樣做的好處是在某些情況下可以避免一些不必要的計算。

除此之外,feed_dict還可以用來設定graph的輸入值,這就引入了

x = tf.placeholder(tf.float32, shape=(1
, 2)) w1 = tf.Variable(tf.random_normal([2, 3],stddev=1,seed=1)) w2 = tf.Variable(tf.random_normal([3, 1],stddev=1,seed=1)) a = tf.matmul(x,w1) y = tf.matmul(a,w2) with tf.Session() as sess: # 變數執行前必須做初始化操作 init_op = tf.global_variables_initializer() sess.run(init_op) print(sess.run
(y, feed_dict={x:[[0.7, 0.5]]}))

[[3.0904665]]

或者 多輸入

x = tf.placeholder(tf.float32, shape=(None, 2))
w1 = tf.Variable(tf.random_normal([2,3],stddev=1,seed=1))
w2 = tf.Variable(tf.random_normal([3,1],stddev=1,seed=1))

a = tf.matmul(x,w1)
y = tf.matmul(a,w2)

with tf.Session() as sess:
    init_op = tf.global_variables_initializer()
    sess.run(init_op)
    print(sess.run(y, feed_dict={x:[[0.7
,0.5],[0.2,0.3],[0.3,0.4],[0.4,0.5]]})) print(sess.run(w1)) print(sess.run(w2))

[[3.0904665]
[1.2236414]
[1.7270732]
[2.2305048]]
[[-0.8113182 1.4845988 0.06532937]
[-2.4427042 0.0992484 0.5912243 ]]
[[-0.8113182 ]
[ 1.4845988 ]
[ 0.06532937]]

注意:此時的a不是一個tensor,而是一個placeholder。我們定義了它的type和shape,但是並沒有具體的值。在後面定義graph的程式碼中,placeholder看上去和普通的tensor物件一樣。在執行程式的時候我們用feed_dict的方式把具體的值提供給placeholder,達到了給graph提供input的目的。
placeholder有點像在定義函式的時候用到的引數。我們在寫函式內部程式碼的時候,雖然用到了引數,但並不知道引數所代表的值。只有在呼叫函式的時候,我們才把具體的值傳遞給引數。