tensorflow實現在函式中用tf.Print輸出中間值
tensorflow由於其基於靜態圖的模式,導致寫程式碼的時候很難除錯,除了用官方的除錯工具外,最直接的方法就是把中間結果輸出出來檢視,然而,直接用print函式只能輸出tensor變數的形狀,而不是數值,想要輸出tensor的具體數值需要用tf.Print函式。網上有很多關於這個函式使用方法的說明,這裡簡要介紹:
Print( input_,data,message=None,first_n=None,summarize=None,name=None )
引數:
input_:通過這個操作的張量。 (流入的資料流)
data:計算 op 時要列印的張量列表。(用[ ]引起來的一串需要列印的東西,用逗號隔開)
message:一個字串,錯誤訊息的字首。
first_n:只記錄 first_n 次數。負數日誌,這是預設的。
summarize:只打印每個張量的固定數目的條目。如果沒有,則每個輸入張量最多列印3個元素。
name:操作的名稱(可選)
然而網上大部分資源都是介紹如何在主函式中先建立一個op,再開啟一個Session執行sess.run(op)的方法,但是如果想要輸出函式中的中間值而該值又未傳回主函式呢?這種情況下無法在函式中開啟一個新的Session,但是仍然可以用tf.Print建立op來實現。
import tensorflow as tf import os os.environ["CUDA_VISIBLE_DEVICES"] = "0" def test(): a=tf.constant(0) for i in range(10): a_print = tf.Print(a,['a_value: ',a]) a=a_print+1 return a if __name__=='__main__': with tf.Session() as sess: sess.run(test())
執行結果:
a_print可以理解為在圖中新增了一個節點,在後續程式碼中當有別的變數使用了a_print時(如上例a=a_print+1),就會有資料從a_print節點上流過,就會輸出值,而究竟會輸出幾次值呢?這其實並不是看下文中a_print被使用了幾次,而是看資料流要從該節點上流經幾次,可以理解為a_print這個op被“定義”了幾次。
def test(): a=tf.constant(0) a_print = tf.Print(a,a]) for i in range(10): a=a_print+1 return a if __name__=='__main__': with tf.Session() as sess: sess.run(test())
如果把test()函式改成這樣,則執行結果為:
輸出僅被執行了一次,因為a_print這個op只被定義了一次,雖然後面在迴圈裡不斷被a使用,但是資料只從它身上經過了一次,所以只會print一次,並且a_print的值永遠為0,最終返回的a的值也為1。
再把程式碼改成下例:
def test(): a=tf.constant(0) a_print = tf.Print(a,a]) for i in range(10): a_print=a_print+1 return a if __name__=='__main__': with tf.Session() as sess: sess.run(test())
執行結果是什麼也不會輸出,因為a_print這個op沒有和別的變數發生關係,它沒有被別的變數使用,在圖裡為孤立的一個節點,沒有資料流過,就不會被執行。
而如果改成這樣
def test(): a=tf.constant(0) a_print = tf.Print(a,a]) for i in range(10): a_print=a_print+1 return a_print if __name__=='__main__': with tf.Session() as sess: sess.run(test())
執行結果
返回的a_print值為10也是正確的,因為a_print在下文被返回,所以有資料流流經,會被執行,而因為a_print的定義只執行一次,所以只會輸出一次。
以上這篇tensorflow實現在函式中用tf.Print輸出中間值就是小編分享給大家的全部內容了,希望能給大家一個參考,也希望大家多多支援我們。