關於tensorflow中的softmax_cross_entropy_with_logits_v2函式的區別
阿新 • • 發佈:2019-01-03
tf.nn.softmax_cross_entropy_with_logits(記為f1) 和
tf.nn.sparse_softmax_cross_entropy_with_logits(記為f3),以及
tf.nn.softmax_cross_entropy_with_logits_v2(記為f2)
之間的區別。
f1和f3對於引數logits的要求都是一樣的,即未經處理的,直接由神經網路輸出的數值, 比如 [3.5,2.1,7.89,4.4]
。兩個函式不一樣的地方在於labels格式的要求,f1的要求labels的格式和logits類似,比如[0,0,1,0]
。而f3的要求labels是一個數值,這個數值記錄著ground truth所在的索引。以[0,0,1,0]
tf.argmax()
來從[0,0,1,0]
中取得真值的索引。
f1和f2之間很像,實際上官方文件已經標記出f1已經是deprecated 狀態,推薦使用f2。兩者唯一的區別在於f1在進行反向傳播的時候,只對logits進行反向傳播,labels保持不變。而f2在進行反向傳播的時候,同時對logits和labels都進行反向傳播,如果將labels傳入的tensor設定為stop_gradients,就和f1一樣了。
那麼問題來了,一般我們在進行監督學習的時候,labels都是標記好的真值,什麼時候會需要改變label?f2存在的意義是什麼?實際上在應用中labels並不一定都是人工手動標註的,有的時候還可能是神經網路生成的,一個實際的例子就是對抗生成網路(GAN)。
測試用程式碼:
import tensorflow as tf
import numpy as np
Truth = np.array([0,0,1,0])
Pred_logits = np.array([3.5,2.1,7.89,4.4])
loss = tf.nn.softmax_cross_entropy_with_logits(labels=Truth,logits=Pred_logits)
loss2 = tf.nn.softmax_cross_entropy_with_logits_v2(labels=Truth,logits=Pred_logits)
loss3 = tf.nn.sparse_softmax_cross_entropy_with_logits(labels=tf.argmax(Truth),logits=Pred_logits)
with tf.Session() as sess:
print(sess.run(loss))
print(sess.run(loss2))
print(sess.run(loss3))
參考: