tensorlfow常用API解讀——tf.argmax()
阿新 • • 發佈:2019-01-09
其實該函式很簡單——輸出Tensor沿著給定axis的最大值的索引。
官方文件中說axis要為Tensor,但是實際上直接給定實數就可以。
import tensorflow as tf
import numpy as np
A = np.array([[31, 23, 4, 24, 27, 34],
[18, 3, 25, 0, 6, 35],
[28, 14, 33, 22, 20, 8],
[13, 30, 21, 19, 7, 9],
[16, 1, 26, 32, 2, 29 ],
[17, 12, 5, 11, 10, 15]]
)
B=np.random.randint(low=0,
high=100,
size=[2,3,4])
A_axis_0=tf.argmax(A,axis=0)
# AA=tf.argmax(A,axis=tf.constant(0)) # 該句效果同上,之所以會列出來,是因為tf官網中要求axis的資料型別是Tensor,還不清楚直接寫的實數算不算Tensor
A_axis_None=tf.argmax(A) # 不指定axis的話,預設axis= 0
A_axis_1=tf.argmax(A,axis=0)
with tf.Session() as sess:
print("A_axis_0:\n", sess.run(A_axis_0))
print("A_axis_None:\n", sess.run(A_axis_None))
print("A_axis_1:\n", sess.run(A_axis_1))
np.random.seed(1) # 指定隨機種子,讓結果不變
B=np.random.randint(low=0,
high=100,
size= [2,3,4],
)
print("B:\n",B)
B_axis_0=tf.argmax(B,axis=0)
# BB=tf.argmax(B,axis=tf.constant(0)) # 該句效果同上,之所以會列出來,是因為tf官網中要求axis的資料型別是Tensor,還不清楚直接寫的實數算不算Tensor
B_axis_None=tf.argmax(B) # 不指定axis的話,預設axis=0
B_axis_1=tf.argmax(B,axis=1)
B_axis_2=tf.argmax(B,axis=2)
with tf.Session() as sess:
print("B_axis_0:\n", sess.run(B_axis_0))
print("B_axis_None:\n", sess.run(B_axis_None))
print("B_axis_1:\n", sess.run(B_axis_1))
print("B_axis_2:\n", sess.run(B_axis_2))
A_axis_0:
[0 3 2 4 0 1]
A_axis_None:
[0 3 2 4 0 1]
A_axis_1:
[0 3 2 4 0 1]
B:
[[[37 12 72 9]
[75 5 79 64]
[16 1 76 71]]
[[ 6 25 50 20]
[18 84 11 28]
[29 14 50 68]]]
B_axis_0:
[[0 1 0 1]
[0 1 0 0]
[1 1 0 0]]
B_axis_None:
[[0 1 0 1]
[0 1 0 0]
[1 1 0 0]]
B_axis_1:
[[1 0 1 2]
[2 1 0 2]]
B_axis_2:
[[2 2 2]
[2 1 3]]