1. 程式人生 > 程式設計 >tensorflow實現在函式中用tf.Print輸出中間值

tensorflow實現在函式中用tf.Print輸出中間值

tensorflow由於其基於靜態圖的模式,導致寫程式碼的時候很難除錯,除了用官方的除錯工具外,最直接的方法就是把中間結果輸出出來檢視,然而,直接用print函式只能輸出tensor變數的形狀,而不是數值,想要輸出tensor的具體數值需要用tf.Print函式。網上有很多關於這個函式使用方法的說明,這裡簡要介紹:

Print(
 input_,data,message=None,first_n=None,summarize=None,name=None
 )

引數:

input_:通過這個操作的張量。 (流入的資料流)

data:計算 op 時要列印的張量列表。(用[ ]引起來的一串需要列印的東西,用逗號隔開)

message:一個字串,錯誤訊息的字首。

first_n:只記錄 first_n 次數。負數日誌,這是預設的。

summarize:只打印每個張量的固定數目的條目。如果沒有,則每個輸入張量最多列印3個元素。

name:操作的名稱(可選)

然而網上大部分資源都是介紹如何在主函式中先建立一個op,再開啟一個Session執行sess.run(op)的方法,但是如果想要輸出函式中的中間值而該值又未傳回主函式呢?這種情況下無法在函式中開啟一個新的Session,但是仍然可以用tf.Print建立op來實現。

import tensorflow as tf
import os

os.environ["CUDA_VISIBLE_DEVICES"] = "0"

def test():
 a=tf.constant(0)
 for i in range(10): 
  a_print = tf.Print(a,['a_value: ',a])
  a=a_print+1
 return a
 
if __name__=='__main__':
 with tf.Session() as sess:
  sess.run(test())

執行結果:

a_print可以理解為在圖中新增了一個節點,在後續程式碼中當有別的變數使用了a_print時(如上例a=a_print+1),就會有資料從a_print節點上流過,就會輸出值,而究竟會輸出幾次值呢?這其實並不是看下文中a_print被使用了幾次,而是看資料流要從該節點上流經幾次,可以理解為a_print這個op被“定義”了幾次。

def test():
 a=tf.constant(0)
 a_print = tf.Print(a,a])
 for i in range(10): 
  a=a_print+1
 return a
 
if __name__=='__main__':
 with tf.Session() as sess:
  sess.run(test())

如果把test()函式改成這樣,則執行結果為:

輸出僅被執行了一次,因為a_print這個op只被定義了一次,雖然後面在迴圈裡不斷被a使用,但是資料只從它身上經過了一次,所以只會print一次,並且a_print的值永遠為0,最終返回的a的值也為1。

再把程式碼改成下例:

def test():
 a=tf.constant(0)
 a_print = tf.Print(a,a])
 for i in range(10): 
  a_print=a_print+1
 return a
 
if __name__=='__main__':
 with tf.Session() as sess:
  sess.run(test())

執行結果是什麼也不會輸出,因為a_print這個op沒有和別的變數發生關係,它沒有被別的變數使用,在圖裡為孤立的一個節點,沒有資料流過,就不會被執行。

而如果改成這樣

def test():
 a=tf.constant(0)
 a_print = tf.Print(a,a])
 for i in range(10): 
  a_print=a_print+1
 return a_print
 
if __name__=='__main__':
 with tf.Session() as sess:
  sess.run(test())

執行結果

返回的a_print值為10也是正確的,因為a_print在下文被返回,所以有資料流流經,會被執行,而因為a_print的定義只執行一次,所以只會輸出一次。

以上這篇tensorflow實現在函式中用tf.Print輸出中間值就是小編分享給大家的全部內容了,希望能給大家一個參考,也希望大家多多支援我們。