TensorFlow tf.argmax(): usage and examples
Reposted from: https://blog.csdn.net/Jiaach/article/details/78874704
The official documentation for argmax() reads as follows:
tf.argmax(input, dimension, name=None)
Returns the index with the largest value across dimensions of a tensor.
Args:
input: A Tensor. Must be one of the following types: float32, float64, int64, int32, uint8, int16, int8, complex64, qint8, quint8, qint32.
dimension: A Tensor of type int32. int32, 0 <= dimension < rank(input). Describes which dimension of the input Tensor to reduce across. For vectors, use dimension = 0.
name: A name for the operation (optional).
Returns:
A Tensor of type int64.
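dimension=0 searches by column: for a 2-D tensor it returns, for each column, the row index of the largest value.
dimension=1 searches by row: for a 2-D tensor it returns, for each row, the column index of the largest value.
tf.argmax() returns the index of the largest value, not the value itself.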
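To make the two directions concrete, here is a minimal pure-Python sketch of the same semantics on a list of rows. The helper names argmax_1d and argmax_2d are made up for illustration and are not part of TensorFlow, and this is not how TensorFlow implements the op. The sketch returns the first index on ties, which is an assumption of the sketch; the TensorFlow documentation does not guarantee which index is returned in that case.

```python
# Pure-Python sketch of argmax over a 2-D list of rows.
# Illustrative only: argmax_1d / argmax_2d are hypothetical helpers,
# not TensorFlow functions.

def argmax_1d(values):
    """Return the index of the first largest element of a flat list."""
    best = 0
    for i, v in enumerate(values):
        if v > values[best]:
            best = i
    return best

def argmax_2d(matrix, dimension):
    """Mimic tf.argmax(matrix, dimension) for a list of equal-length rows."""
    if dimension == 0:
        # Reduce across rows: one result per column.
        return [argmax_1d(list(col)) for col in zip(*matrix)]
    if dimension == 1:
        # Reduce across columns: one result per row.
        return [argmax_1d(row) for row in matrix]
    raise ValueError("dimension must be 0 or 1 for a 2-D input")

b = [[1, 2, 3], [3, 2, 1], [4, 5, 6], [6, 5, 4]]
print(argmax_2d(b, 0))  # [3, 2, 2]
print(argmax_2d(b, 1))  # [2, 0, 2, 0]
```

With dimension=0 the result has one entry per column (3 entries for a 4x3 input); with dimension=1 it has one entry per row (4 entries).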
It is usually used together with tf.equal() to compute model accuracy:
correct_pred = tf.equal(tf.argmax(pred,1), tf.argmax(y,1))
accuracy = tf.reduce_mean(tf.cast(correct_pred, tf.float32))
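The two lines above can be traced by hand with a small pure-Python sketch. The pred and y values below are made-up data (four samples, three classes, one-hot labels), and argmax_1d is a hypothetical helper standing in for tf.argmax along dimension 1; none of this comes from the original post.

```python
# Pure-Python sketch of the accuracy computation; all data is made up.

def argmax_1d(values):
    """Index of the first largest element (hypothetical helper)."""
    return max(range(len(values)), key=lambda i: values[i])

# Hypothetical model outputs: one row per sample, one column per class.
pred = [[0.1, 0.7, 0.2],
        [0.8, 0.1, 0.1],
        [0.3, 0.3, 0.4],
        [0.2, 0.5, 0.3]]
# Hypothetical one-hot labels for the same four samples.
y = [[0, 1, 0],
     [1, 0, 0],
     [0, 1, 0],
     [0, 1, 0]]

# Counterpart of tf.equal(tf.argmax(pred, 1), tf.argmax(y, 1)):
# one True/False per sample.
correct_pred = [argmax_1d(p) == argmax_1d(t) for p, t in zip(pred, y)]

# Counterpart of tf.reduce_mean(tf.cast(correct_pred, tf.float32)):
# cast booleans to 1.0 / 0.0 and average them.
accuracy = sum(1.0 if c else 0.0 for c in correct_pred) / len(correct_pred)

print(correct_pred)  # [True, True, False, True]
print(accuracy)      # 0.75
```

Comparing the predicted class index with the label's class index gives a boolean per sample, and the mean of those booleans cast to floats is the fraction of correct predictions.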
Example
>>> import tensorflow as tf
>>> a = tf.constant([1.,2.,3.,0.,9.,])
>>> b = tf.constant([[1,2,3],[3,2,1],[4,5,6],[6,5,4]])
>>> with tf.Session() as sess:
...     sess.run(tf.argmax(a, 0))
Output: 4
>>> with tf.Session() as sess:
...     sess.run(tf.argmax(b, 0))
Output: array([3, 2, 2])
>>> with tf.Session() as sess:
...     sess.run(tf.argmax(b, 1))
Output: array([2, 0, 2, 0])
Ref:
API documentation
This article comes from Jaichg's CSDN blog. Full text: https://blog.csdn.net/Jiaach/article/details/78874704?utm_source=copy