1. 程式人生 > 其它 >CrossEntropyLoss()分析和在pytorch上的使用與改寫

CrossEntropyLoss()分析和在pytorch上的使用與改寫

技術標籤:深度學習深度學習python演算法pytorch資料探勘

目錄

公式

C r o s s E n t r o p y = − ∑ k = 1 N ( p k ⋅ log ⁡ q k ) CrossEntropy=-\sum_{k=1}^{N}(p_k \cdot \log{q_k}) CrossEntropy=k=1N(pklogqk)
其中 q i q_i qi 為模型預測值,而且是經過了softmax層後的結果。

對於分類任務

有唯一 p k p_k

pk 為1,其它均為0,則上式可簡寫為
C r o s s E n t r o p y = − log ⁡ q k CrossEntropy=-\log{q_k} CrossEntropy=logqk

pytorch中使用torch.nn.CrossEntropyLoss()
  • 見官網 torch.nn.CrossEntropyLoss()
  • input:模型的一組輸出,大小為(batch_size, C),C為分類任務的類別數;即,input[] 有 batch_size
    條,每一條大小為 ( C ),C個值分別代表該條資料屬於C個類的概率。
  • target:這組資料對應的一組真實label,大小(batch_size) ;target只能為一維資料,每個值取 [0,C)中的整數,整數代表類別不表示大小,整數之間關係獨立(類似 one-hot 編碼);如 target[i]=3 代表 input 中的第 i 條資料屬於第 3 類。

實現:

torch.nn.CrossEntropyLoss() 是 nn.LogSoftmax() + nn.NLLLoss()

解釋:

LogSoftmax():就是 log(softmax())
NLLLoss():函式全程是 n e g a t i v e l o g l i k e l i h o o d l o s s negative\ log\ likelihood\ loss negativeloglikelihoodloss
函式表示式為 f ( x , c l a s s ) = − x [ c l a s s ] f(x,class)=-x[class]

f(x,class)=x[class] ,如 f ( [ 10 , 20 , 30 ] , 2 ) = − 30 f([10,20,30],2)=-30 f([10,20,30],2)=30

對於公式 C r o s s E n t r o p y = − ∑ k = 1 N ( p k ⋅ log ⁡ q k ) = − log ⁡ q k CrossEntropy=-\sum_{k=1}^{N}(p_k \cdot \log{q_k})=-\log{q_k} CrossEntropy=k=1N(pklogqk)=logqk
q k {q_k} qk 是 softmax 後的結果, log ⁡ q k \log{q_k} logqk 就代表 log(softmax()),而第二個等號就類似 NLLLoss() 選擇。
所以,如果輸入的target使用了多維資料,控制檯很可能報 NLLLoss() 出錯。

對於分佈任務

建議先閱讀上面的分類任務解釋
torch.nn.CrossEntropyLoss() = nn.LogSoftmax() + nn.NLLLoss()

公式展開為
C r o s s E n t r o p y = − p 1 log ⁡ q 1 − p 2 log ⁡ q 2 − … − p n log ⁡ q n CrossEntropy=-p_1\log{q_1}-p_2\log{q_2}-\ldots -p_n\log{q_n} CrossEntropy=p1logq1p2logq2pnlogqn

由公式可知,我們仍然需要 LogSoftmax() 的結果,但不需要 nn.NLLLoss(),因此新寫一個 loss 函式;

def myCrossEntropyLoss(self, outputs, labels):
    batch_size = outputs.size()[0]       
    
    # 對每一行資料進行 logsoftmax 並取負,即公式中的 -log(q)
    outputs = -F.log_softmax(outputs, dim=1) 
    
    # logsoftmax 後的資料和 label 的概率點乘,即公式中的 -p*log(q)
    res = outputs * labels 
    
    # 求平均loss,這裡7是我訓練的類別數
    return torch.sum(res)/(batch_size*7)

參考

Cross-entropy for classification 【 !推薦閱讀 !】
Pytorch裡的CrossEntropyLoss詳解【 !推薦閱讀 !】
手寫CrossEntropyLoss【 !推薦閱讀 !】
cross-entropy-loss-explanation
pytorch官網log_softmax()
pytorch官網torch.nn.CrossEntropyLoss()
logsoftmax 和 softmax 的區別
Pytorch中Softmax和LogSoftmax的使用
tensor乘法
tensor求平均