CrossEntropyLoss()分析和在pytorch上的使用與改寫
技術標籤:深度學習深度學習python演算法pytorch資料探勘
目錄
公式
C
r
o
s
s
E
n
t
r
o
p
y
=
−
∑
k
=
1
N
(
p
k
⋅
log
q
k
)
CrossEntropy=-\sum_{k=1}^{N}(p_k \cdot \log{q_k})
CrossEntropy=−∑k=1N(pk⋅logqk)
其中
q
i
q_i
qi 為模型預測值,而且是經過了softmax層後的結果。
對於分類任務
有唯一
p
k
p_k
C
r
o
s
s
E
n
t
r
o
p
y
=
−
log
q
k
CrossEntropy=-\log{q_k}
CrossEntropy=−logqk
pytorch中使用torch.nn.CrossEntropyLoss()
- 見官網 torch.nn.CrossEntropyLoss()
- input:模型的一組輸出,大小為(batch_size, C),C為分類任務的類別數;即,input[] 有 batch_size
條,每一條大小為 ( C ),C個值分別代表該條資料屬於C個類的概率。 - target:這組資料對應的一組真實label,大小(batch_size) ;target只能為一維資料,每個值取 [0,C)中的整數,整數代表類別不表示大小,整數之間關係獨立(類似 one-hot 編碼);如 target[i]=3 代表 input 中的第 i 條資料屬於第 3 類。
實現:
torch.nn.CrossEntropyLoss() 是 nn.LogSoftmax() + nn.NLLLoss()
解釋:
LogSoftmax()
:就是log(softmax())
;
NLLLoss()
:函式全程是 n e g a t i v e l o g l i k e l i h o o d l o s s negative\ log\ likelihood\ loss negativeloglikelihoodloss,
函式表示式為 f ( x , c l a s s ) = − x [ c l a s s ] f(x,class)=-x[class]f(x,class)=−x[class] ,如 f ( [ 10 , 20 , 30 ] , 2 ) = − 30 f([10,20,30],2)=-30 f([10,20,30],2)=−30;
對於公式
C
r
o
s
s
E
n
t
r
o
p
y
=
−
∑
k
=
1
N
(
p
k
⋅
log
q
k
)
=
−
log
q
k
CrossEntropy=-\sum_{k=1}^{N}(p_k \cdot \log{q_k})=-\log{q_k}
CrossEntropy=−∑k=1N(pk⋅logqk)=−logqk,
q
k
{q_k}
qk 是 softmax 後的結果,
log
q
k
\log{q_k}
logqk 就代表 log(softmax())
,而第二個等號就類似 NLLLoss() 選擇。
所以,如果輸入的target使用了多維資料,控制檯很可能報 NLLLoss() 出錯。
對於分佈任務
建議先閱讀上面的分類任務解釋
torch.nn.CrossEntropyLoss() = nn.LogSoftmax() + nn.NLLLoss()
公式展開為
C
r
o
s
s
E
n
t
r
o
p
y
=
−
p
1
log
q
1
−
p
2
log
q
2
−
…
−
p
n
log
q
n
CrossEntropy=-p_1\log{q_1}-p_2\log{q_2}-\ldots -p_n\log{q_n}
CrossEntropy=−p1logq1−p2logq2−…−pnlogqn
由公式可知,我們仍然需要 LogSoftmax() 的結果,但不需要 nn.NLLLoss(),因此新寫一個 loss 函式;
def myCrossEntropyLoss(self, outputs, labels):
batch_size = outputs.size()[0]
# 對每一行資料進行 logsoftmax 並取負,即公式中的 -log(q)
outputs = -F.log_softmax(outputs, dim=1)
# logsoftmax 後的資料和 label 的概率點乘,即公式中的 -p*log(q)
res = outputs * labels
# 求平均loss,這裡7是我訓練的類別數
return torch.sum(res)/(batch_size*7)
參考
Cross-entropy for classification 【 !推薦閱讀 !】
Pytorch裡的CrossEntropyLoss詳解【 !推薦閱讀 !】
手寫CrossEntropyLoss【 !推薦閱讀 !】
cross-entropy-loss-explanation
pytorch官網log_softmax()
pytorch官網torch.nn.CrossEntropyLoss()
logsoftmax 和 softmax 的區別
Pytorch中Softmax和LogSoftmax的使用
tensor乘法
tensor求平均