1. 程式人生 > >tf.argmax是一個耗時逐漸累加的操作

tf.argmax是一個耗時逐漸累加的操作

在呼叫訓練好的模型進行圖片分類時,其中的一個操作是分析logits中數值最大的元素的下標,也就是argmax方法。但在實際執行時發現,隨著分類程序的推進,影象分類耗時越來越長,對各部分操作耗時進行統計,發現問題出在tf.argmax()操作。

為了確保分類效率,可以採用numpy庫的argmax函式代替之,具體如下:

# 原始的argmax操作
indicies = tf.argmax(logits, 1).eval()

# 修改之後的argmax操作
indicies = np.argmax(logits, axis = 1)