1. 程式人生 > 其它 >PyTorch基礎——torch.nn.CrossEntropyLoss交叉熵損失

PyTorch基礎——torch.nn.CrossEntropyLoss交叉熵損失

技術標籤:PyTorch交叉熵損失

本文只考慮基本情況,未考慮加權。

torch.nnCrossEntropyLosss使用的公式

l o s s ( x , c l a s s ) = − l o g ( e x p ( x [ c l a s s ] ∑ j e x p ( x [ j ] ) ) loss(x,class)=-log(\frac {exp(x[class]} {\sum_jexp(x[j])}) loss(x,class)=log(jexp(x[j])exp(x[class]) = − x [ c l a s s ] + l o g ( ∑ j e x p ( x [ j ] ) ) (1) =-x[class]+log(\sum_jexp(x[j])) \tag {1}

=x[class]+log(jexp(x[j]))(1)目標類別採用one-hot編碼
其中,class表示當前樣本類別在one-hot編碼中對應的索引(從0開始),
x[j]表示預測函式的第j個輸出

公式(1)表示先對預測函式使用softmax計算每個類別的概率,再使用log(以e為底)計算後的相反數表示當前類別的損失,只表示其中一個樣本的損失計算方式,非全部樣本。

每個樣本使用one-hot編碼表示所屬類別時,只有一項為1,因此與基本的交叉熵損失函式相比,省略了其它值為0的項,只剩(1)所表示的項。

【sample】

已知條件:共3種類別,輸入兩個樣本,第一個樣本為類別class=0,第二個樣本為類別class=2


預測函式輸出: [ [        0.0541 ,        0.1762 , 0.9489 ] , [ − 0.0288 , − 0.8072 , 0.4909 ] ] [[\;\;\;0.0541, \;\;\;0.1762, 0.9489], \\ \quad \quad \quad\quad\quad\quad[-0.0288, -0.8072, 0.4909]] [[0.0541,0.1762,0.9489],[0.0288,0.8072,0.4909]],shape為2行3列
基於此,計算損失:
首先softmax計算兩個樣本對應類別的概率:
e 0.0541 e 0.0541 + e 0.1762 + e 0.9489 = 0.2185 \frac {e^{0.0541}} {e^{0.0541} + e^{0.1762} + e^{0.9489}} =0.2185
e0.0541+e0.1762+e0.9489e0.0541=0.2185

e 0.4909 e − 0.0288 + e − 0.8072 + e 0.4909 = 0.5354 \frac {e^{0.4909}} {e^{-0.0288} + e^{-0.8072} + e^{0.4909}} =0.5354 e0.0288+e0.8072+e0.4909e0.4909=0.5354
然後計算log之後的相反數:
− l o g ( 0.2185 ) = 1.5210 -log(0.2185) = 1.5210 log(0.2185)=1.5210
− l o g ( 0.5354 ) = 0.6247 -log(0.5354) = 0.6247 log(0.5354)=0.6247
取均值:
1.5210 + 0.6247 2 = 1.073 \frac {1.5210+0.6247}{2}=1.073 21.5210+0.6247=1.073

【torch.nn.CrossEntropyLoss使用流程】

torch.nn.CrossEntropyLoss為一個類,並非單獨一個函式,使用到的相關簡單引數會在使用中說明,並非對所有引數進行說明。

首先建立類物件

In [1]: import torch
In [2]: import torch.nn as nn
In [3]: loss_function = nn.CrossEntropyLoss(reduction="none")

引數reduction預設為"mean",表示對所有樣本的loss取均值,最終返回只有一個值
引數reduction取"none",表示保留每一個樣本的loss

計算損失

In [4]: pred = torch.tensor([[0.0541,0.1762,0.9489],[-0.0288,-0.8072,0.4909]], dtype=torch.float32)
In [5]: class_index = torch.tensor([0, 2], dtype=torch.int64)
In [6]: loss_value = loss_function(pred, class_index)
In [7]: loss_value
Out[7]: tensor([1.5210, 0.6247]) # 與上述【sample】計算一致

實際計算損失值呼叫函式時,傳入pred預測值與class_index類別索引
在傳入每個類別時,class_index應為一維,長度為樣本個數,每個元素表示對應樣本的類別索引,非one-hot編碼方式傳入

【測試torch.nn.CrossEntropyLoss的reduction引數為預設值"mean"】
In [1]: import torch
In [2]: import torch.nn as nn
In [3]: loss_function = nn.CrossEntropyLoss(reduction="mean")
In [4]: pred = torch.tensor([[0.0541,0.1762,0.9489],[-0.0288,-0.8072,0.4909]], dtype=torch.float32)
In [5]: class_index = torch.tensor([0, 2], dtype=torch.int64)
In [6]: loss_value = loss_function(pred, class_index)
In [7]: loss_value
Out[7]: 1.073 # 與上述【sample】計算一致