1. 程式人生 > >TensorFlow argmax函數

TensorFlow argmax函數

網絡 理解 答案 name 一行 數據 tro 之前 計算

之前對這個函數理解一直有誤。

以為是獲取一個張量中,行/列的最大值。

其實他獲取的是行/列中最大值的索引號

tf.argmax(input, axis=None, name=None, dimension=None)

axis:0按列,1按行。

舉個例子

[[1,2,3,4]]

A=[[1,2,3,4]]
tf.argmax(A, 0)
tf.argmax(A, 1)

會得到

[0 0 0 0]

[3]

因為按列算,每一列的第唯一一個數組就是最大的數字,其索引號都是0。所以所有的列返回的都是0,總共有4列所以返回:[0,0,0,0]

按行計算得到的就是這一行4個數字中最大的數字4的索引號:3所以返回[3]

TensorFlow MNIST最佳實踐中計算交叉熵的時候就使用了這個函數:

cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=y, labels=tf.argmax(y_, 1))

在計算單一分類問題的交叉熵時,通常是這樣綁定使用的。

這裏的y_是測試數據的真實值。y是我們神經網絡預測的值。

這裏通過tf.argmax(y_, 1)方法,就獲取到正確答案的序號了。然後再進行交叉熵的計算。

參考鏈接:

http://blog.csdn.net/zj360202/article/details/70259999

http://blog.csdn.net/UESTC_C2_403/article/details/72232807

TensorFlow argmax函數