1. 程式人生 > 其它 >PyTorch之torch.nn.CrossEntropyLoss()

PyTorch之torch.nn.CrossEntropyLoss()

技術標籤:PyTorchpython深度學習機器學習演算法人工智慧

  1. 簡介
    資訊熵: 按照真實分佈p來衡量識別一個樣本所需的編碼長度的期望,即平均編碼長度
    在這裡插入圖片描述
    交叉熵: 使用擬合分佈q來表示來自真實分佈p的編碼長度的期望,即平均編碼長度
    在這裡插入圖片描述
    多分類任務中的交叉熵損失函式
    在這裡插入圖片描述

  2. 程式碼

1)匯入包

import torch
import torch.nn as nn

2)準備資料
在圖片單標籤分類時,輸入m張圖片,輸出一個m x N的Tensor,其中N是分類個數。比如輸入3張圖片,分三類,最後的輸出是一個3 x 3的Tensor,舉個例子:

x_input=torch.randn(3,3
) print('x_input:\n',x_input) y_target=torch.tensor([1,2,0])

在這裡插入圖片描述
3)計算概率分佈
第123行分別是第123張圖片的結果,假設第123列分別是貓、狗和豬的分類得分。
然後對每一行使用Softmax,這樣可以得到每張圖片的概率分佈。

softmax_func=nn.Softmax(dim=1)
soft_output=softmax_func(x_input)
print('soft_output:\n',soft_output)

在這裡插入圖片描述
這裡dim的意思是計算Softmax的維度,這裡設定dim=1,可以看到每一行的加和為1。比如第一行0.1022+0.3831+0.5147=1。

4)對Softmax的結果取自然對數

log_output=torch.log(soft_output)
print('log_output:\n',log_output)

在這裡插入圖片描述
對比softmax與log的結合與nn.LogSoftmaxloss(負對數似然損失)的輸出結果,兩者是一致的。

logsoftmax_func=nn.LogSoftmax(dim=1)
logsoftmax_output=logsoftmax_func(x_input)
print('logsoftmax_output:\n',logsoftmax_output)

在這裡插入圖片描述
5)NLLLoss
NLLLoss的結果就是把上面的輸出與y_label對應的那個值拿出來,再去掉負號,再求均值。

nllloss_func=nn.NLLLoss()
nlloss_output=nllloss_func(logsoftmax_output,y_target)
print('nlloss_output:\n',nlloss_output)

y_target中[1, 2, 0]對應上述第一行的第二個,第二行的第三個,第三行的第1個:
(0.9594+0.4241+0.5265)/3=0.6367
在這裡插入圖片描述
6) CrossEntropyLoss()

crossentropyloss=nn.CrossEntropyLoss()
crossentropyloss_output=crossentropyloss(x_input,y_target)
print('crossentropyloss_output:\n',crossentropyloss_output)

在這裡插入圖片描述

參考連結:
https://blog.csdn.net/qq_22210253/article/details/85229988
https://zhuanlan.zhihu.com/p/98785902
https://zhuanlan.zhihu.com/p/56638625