tensorflow之argmax
阿新 • • 發佈:2018-12-31
argmax :返回矩陣中的最大索引
一維矩陣的例子:
input1 = tf.constant([1.0, 2.0, 3.0])
with tf.Session() as sess:
print(sess.run(tf.argmax(input1)))
3最大,索引一般都是從0開始,所以應該返回2
輸出:
二維矩陣的例子:
input1 = tf.constant([[1.0, 2.0, 3.0],[4.0, 5.0, 6.0]])
with tf.Session() as sess:
print(sess.run(tf.argmax(input1)))
直接上輸出吧:
輸出是一個矩陣
再調整一下引數:
input1 = tf.constant([[1.0, 2.0, 3.0],[4.0, 5.0, 6.0]])
with tf.Session() as sess:
print(sess.run(tf.argmax(input1,axis = 1)))
輸出如下:
第一個shape是3 第二個shape是2
為什麼呢?
三維矩陣的例子:
axis =0情況
input1 = tf.constant([ [[1.0, 2.0, 3.0],[6.0, 5.0, 4.0]], [[10.0, 11.0, 12.0],[9.0, 8.0, 7.0]] ]) print(input1) with tf.Session() as sess: print(sess.run(tf.argmax(input1,axis = 0)))
輸出:
axis =1情況
input1 = tf.constant([
[[1.0, 2.0, 3.0],[6.0, 5.0, 4.0]],
[[10.0, 11.0, 12.0],[9.0, 8.0, 7.0]]
])
print(input1)
with tf.Session() as sess:
print(sess.run(tf.argmax(input1,axis = 1)))
輸出:
axis =2情況
input1 = tf.constant([ [[1.0, 2.0, 3.0],[6.0, 5.0, 4.0]], [[10.0, 11.0, 12.0],[9.0, 8.0, 7.0]] ]) print(input1) with tf.Session() as sess: print(sess.run(tf.argmax(input1,axis = 2)))
輸出:
axis =3的時候,程式崩潰,就是說axis最大是矩陣的維數-1