1. 程式人生 > >tensorflow之argmax

tensorflow之argmax

argmax :返回矩陣中的最大索引

一維矩陣的例子:

input1 = tf.constant([1.0, 2.0, 3.0])
with tf.Session() as sess:
	print(sess.run(tf.argmax(input1)))

3最大,索引一般都是從0開始,所以應該返回2

輸出:

 二維矩陣的例子:

input1 = tf.constant([[1.0, 2.0, 3.0],[4.0, 5.0, 6.0]])
with tf.Session() as sess:
	print(sess.run(tf.argmax(input1)))

直接上輸出吧:

 輸出是一個矩陣

再調整一下引數:

input1 = tf.constant([[1.0, 2.0, 3.0],[4.0, 5.0, 6.0]])
with tf.Session() as sess:
	print(sess.run(tf.argmax(input1,axis = 1)))

輸出如下:

 第一個shape是3 第二個shape是2

為什麼呢?

三維矩陣的例子:

axis =0情況

input1 = tf.constant([
[[1.0, 2.0, 3.0],[6.0, 5.0, 4.0]],
[[10.0, 11.0, 12.0],[9.0, 8.0, 7.0]]
])
print(input1)
with tf.Session() as sess:
	print(sess.run(tf.argmax(input1,axis = 0)))

 輸出:

axis =1情況

input1 = tf.constant([
[[1.0, 2.0, 3.0],[6.0, 5.0, 4.0]],
[[10.0, 11.0, 12.0],[9.0, 8.0, 7.0]]
])
print(input1)
with tf.Session() as sess:
	print(sess.run(tf.argmax(input1,axis = 1)))

 輸出:

axis =2情況

input1 = tf.constant([
[[1.0, 2.0, 3.0],[6.0, 5.0, 4.0]],
[[10.0, 11.0, 12.0],[9.0, 8.0, 7.0]]
])
print(input1)
with tf.Session() as sess:
	print(sess.run(tf.argmax(input1,axis = 2)))

 輸出:

 axis =3的時候,程式崩潰,就是說axis最大是矩陣的維數-1