1. 程式人生 > >tensorlfow常用API解讀——tf.argmax()

tensorlfow常用API解讀——tf.argmax()

其實該函式很簡單——輸出Tensor沿著給定axis的最大值的索引。
官方文件中說axis要為Tensor,但是實際上直接給定實數就可以。

import tensorflow as tf
import numpy as np

A = np.array([[31, 23,  4, 24, 27, 34],
            [18,  3, 25,  0,  6, 35],
            [28, 14, 33, 22, 20,  8],
            [13, 30, 21, 19,  7,  9],
            [16,  1, 26, 32,  2, 29
], [17, 12, 5, 11, 10, 15]] ) B=np.random.randint(low=0, high=100, size=[2,3,4]) A_axis_0=tf.argmax(A,axis=0) # AA=tf.argmax(A,axis=tf.constant(0)) # 該句效果同上,之所以會列出來,是因為tf官網中要求axis的資料型別是Tensor,還不清楚直接寫的實數算不算Tensor A_axis_None=tf.argmax(A) # 不指定axis的話,預設axis=
0 A_axis_1=tf.argmax(A,axis=0) with tf.Session() as sess: print("A_axis_0:\n", sess.run(A_axis_0)) print("A_axis_None:\n", sess.run(A_axis_None)) print("A_axis_1:\n", sess.run(A_axis_1)) np.random.seed(1) # 指定隨機種子,讓結果不變 B=np.random.randint(low=0, high=100, size=
[2,3,4], ) print("B:\n",B) B_axis_0=tf.argmax(B,axis=0) # BB=tf.argmax(B,axis=tf.constant(0)) # 該句效果同上,之所以會列出來,是因為tf官網中要求axis的資料型別是Tensor,還不清楚直接寫的實數算不算Tensor B_axis_None=tf.argmax(B) # 不指定axis的話,預設axis=0 B_axis_1=tf.argmax(B,axis=1) B_axis_2=tf.argmax(B,axis=2) with tf.Session() as sess: print("B_axis_0:\n", sess.run(B_axis_0)) print("B_axis_None:\n", sess.run(B_axis_None)) print("B_axis_1:\n", sess.run(B_axis_1)) print("B_axis_2:\n", sess.run(B_axis_2))
A_axis_0:
 [0 3 2 4 0 1]
A_axis_None:
 [0 3 2 4 0 1]
A_axis_1:
 [0 3 2 4 0 1]
B:
 [[[37 12 72  9]
  [75  5 79 64]
  [16  1 76 71]]

 [[ 6 25 50 20]
  [18 84 11 28]
  [29 14 50 68]]]
B_axis_0:
 [[0 1 0 1]
 [0 1 0 0]
 [1 1 0 0]]
B_axis_None:
 [[0 1 0 1]
 [0 1 0 0]
 [1 1 0 0]]
B_axis_1:
 [[1 0 1 2]
 [2 1 0 2]]
B_axis_2:
 [[2 2 2]
 [2 1 3]]