PyTorch基礎——torch.nn.CrossEntropyLoss交叉熵損失
本文只考慮基本情況,未考慮加權。
torch.nnCrossEntropyLosss使用的公式
l
o
s
s
(
x
,
c
l
a
s
s
)
=
−
l
o
g
(
e
x
p
(
x
[
c
l
a
s
s
]
∑
j
e
x
p
(
x
[
j
]
)
)
loss(x,class)=-log(\frac {exp(x[class]} {\sum_jexp(x[j])})
loss(x,class)=−log(∑jexp(x[j])exp(x[class])
=
−
x
[
c
l
a
s
s
]
+
l
o
g
(
∑
j
e
x
p
(
x
[
j
]
)
)
(1)
=-x[class]+log(\sum_jexp(x[j])) \tag {1}
其中,class表示當前樣本類別在one-hot編碼中對應的索引(從0開始),
x[j]表示預測函式的第j個輸出
公式(1)表示先對預測函式使用softmax計算每個類別的概率,再使用log(以e為底)計算後的相反數表示當前類別的損失,只表示其中一個樣本的損失計算方式,非全部樣本。
每個樣本使用one-hot編碼表示所屬類別時,只有一項為1,因此與基本的交叉熵損失函式相比,省略了其它值為0的項,只剩(1)所表示的項。
【sample】
已知條件:共3種類別,輸入兩個樣本,第一個樣本為類別class=0,第二個樣本為類別class=2
預測函式輸出: [ [ 0.0541 , 0.1762 , 0.9489 ] , [ − 0.0288 , − 0.8072 , 0.4909 ] ] [[\;\;\;0.0541, \;\;\;0.1762, 0.9489], \\ \quad \quad \quad\quad\quad\quad[-0.0288, -0.8072, 0.4909]] [[0.0541,0.1762,0.9489],[−0.0288,−0.8072,0.4909]],shape為2行3列
基於此,計算損失:
首先softmax計算兩個樣本對應類別的概率:
e 0.0541 e 0.0541 + e 0.1762 + e 0.9489 = 0.2185 \frac {e^{0.0541}} {e^{0.0541} + e^{0.1762} + e^{0.9489}} =0.2185
e 0.4909 e − 0.0288 + e − 0.8072 + e 0.4909 = 0.5354 \frac {e^{0.4909}} {e^{-0.0288} + e^{-0.8072} + e^{0.4909}} =0.5354 e−0.0288+e−0.8072+e0.4909e0.4909=0.5354
然後計算log之後的相反數:
− l o g ( 0.2185 ) = 1.5210 -log(0.2185) = 1.5210 −log(0.2185)=1.5210
− l o g ( 0.5354 ) = 0.6247 -log(0.5354) = 0.6247 −log(0.5354)=0.6247
取均值:
1.5210 + 0.6247 2 = 1.073 \frac {1.5210+0.6247}{2}=1.073 21.5210+0.6247=1.073
【torch.nn.CrossEntropyLoss使用流程】
torch.nn.CrossEntropyLoss為一個類,並非單獨一個函式,使用到的相關簡單引數會在使用中說明,並非對所有引數進行說明。
首先建立類物件
In [1]: import torch
In [2]: import torch.nn as nn
In [3]: loss_function = nn.CrossEntropyLoss(reduction="none")
引數reduction預設為"mean",表示對所有樣本的loss取均值,最終返回只有一個值
引數reduction取"none",表示保留每一個樣本的loss
計算損失
In [4]: pred = torch.tensor([[0.0541,0.1762,0.9489],[-0.0288,-0.8072,0.4909]], dtype=torch.float32)
In [5]: class_index = torch.tensor([0, 2], dtype=torch.int64)
In [6]: loss_value = loss_function(pred, class_index)
In [7]: loss_value
Out[7]: tensor([1.5210, 0.6247]) # 與上述【sample】計算一致
實際計算損失值呼叫函式時,傳入pred預測值與class_index類別索引
在傳入每個類別時,class_index應為一維,長度為樣本個數,每個元素表示對應樣本的類別索引,非one-hot編碼方式傳入
【測試torch.nn.CrossEntropyLoss的reduction引數為預設值"mean"】
In [1]: import torch
In [2]: import torch.nn as nn
In [3]: loss_function = nn.CrossEntropyLoss(reduction="mean")
In [4]: pred = torch.tensor([[0.0541,0.1762,0.9489],[-0.0288,-0.8072,0.4909]], dtype=torch.float32)
In [5]: class_index = torch.tensor([0, 2], dtype=torch.int64)
In [6]: loss_value = loss_function(pred, class_index)
In [7]: loss_value
Out[7]: 1.073 # 與上述【sample】計算一致