tensorflow 分類損失函式問題(有點坑)
tf.nn.softmax_cross_entropy_with_logits(記為f1) 和
tf.nn.sparse_softmax_cross_entropy_with_logits(記為f3),以及
tf.nn.softmax_cross_entropy_with_logits_v2(記為f2)
之間的區別。
f1和f3對於引數logits的要求都是一樣的,即未經處理的,直接由神經網路輸出的數值, 比如 [3.5,2.1,7.89,4.4]。兩個函式不一樣的地方在於labels格式的要求,f1的要求labels的格式和logits類似,比如[0,0,1,0]。而f3的要求labels是一個數值,這個數值記錄著ground truth所在的索引。以[0,0,1,0]為例,這裡真值1的索引為2。所以f3要求labels的輸入為數字2(tensor)。一般可以用tf.argmax()來從[0,0,1,0]中取得真值的索引。
f1和f2之間很像,實際上官方文件已經標記出f1已經是deprecated 狀態,推薦使用f2。兩者唯一的區別在於f1在進行反向傳播的時候,只對logits進行反向傳播,labels保持不變。而f2在進行反向傳播的時候,同時對logits和labels都進行反向傳播,如果將labels傳入的tensor設定為stop_gradients,就和f1一樣了。
那麼問題來了,一般我們在進行監督學習的時候,labels都是標記好的真值,什麼時候會需要改變label?f2存在的意義是什麼?實際上在應用中labels並不一定都是人工手動標註的,有的時候還可能是神經網路生成的,一個實際的例子就是對抗生成網路(GAN)。
測試用程式碼:
import tensorflow as tf
import numpy as np
Truth = np.array([0,0,1,0])
Pred_logits = np.array([3.5,2.1,7.89,4.4])
loss = tf.nn.softmax_cross_entropy_with_logits(labels=Truth,logits=Pred_logits)
loss2 = tf.nn.softmax_cross_entropy_with_logits_v2(labels=Truth,logits=Pred_logits)
loss3 = tf.nn.sparse_softmax_cross_entropy_with_logits(labels=tf.argmax(Truth),logits=Pred_logits)
with tf.Session() as sess:
print(sess.run(loss))
print(sess.run(loss2))
print(sess.run(loss3))
參考:
https://www.tensorflow.org/api_docs/
https://stats.stackexchange.com/questions/327348/how-is-softmax-cross-entropy-with-logits-different-from-softmax-cross-entropy-wi
---------------------
作者:史丹利複合田
來源:CSDN
原文:https://blog.csdn.net/tsyccnh/article/details/81069308
版權宣告:本文為博主原創文章,轉載請附上博文連結!