1. 程式人生 > 其它 >PyTorch筆記--交叉熵損失函式實現

PyTorch筆記--交叉熵損失函式實現

交叉熵(cross entropy):用於度量兩個概率分佈間的差異資訊。交叉熵越小,代表這兩個分佈越接近。

函式表示(這是使用softmax作為啟用函式的損失函式表示):

(是真實值,是預測值。)

命名說明:

pred=F.softmax(logits),logits是softmax函式的輸入,pred代表預測值,是softmax函式的輸出。

pred_log=F.log_softmax(logits),pred_log代表對預測值再取對數後的結果。也就是將logits作為log_softmax()函式的輸入。

方法一,使用log_softmax()+nll_loss()實現

torch.nn.functional.log_softmax

(input)

  對輸入使用softmax函式計算,再取對數。

torch.nn.functional.nll_loss(input, target)

  input是經log_softmax()函式處理後的結果,pred_log

  target代表的是真實值。

  有了這兩個輸入後,該函式對其實現交叉熵損失函式的計算,即上面公式中的L。

>>> import torch
>>> import torch.nn.functional as F
>>> x = torch.randn(1, 28)
>>> w = torch.randn(10
,28) >>> logits = x @ w.t() >>> pred_log = F.log_softmax(logits, dim=1) >>> pred_log tensor([[ -0.8779, -6.7271, -9.1801, -6.8515, -9.6900, -6.3061, -3.7304, -8.1933, -11.5704, -0.5873]]) >>> F.nll_loss(pred_log, torch.tensor([3])) tensor(6.8515)

logits的維度是(1, 10)這裡可以理解成是1個輸入,最終可能得到10個分類的結果中的一個。pred_log就是。

這裡的引數target=torch.tensor([3]),我的理解是,他代表真正的分類的值是在第3類(從0編號)。

使用獨熱編碼代表真實值是[0, 0, 0, 1, 0, 0, 0, 0, 0, 0],即這個輸入它是屬於第三類的。

根據上述公式進行計算,現在我們和都已經知道了。

對其進行點乘操作

方法二,使用cross_entropy()實現

torch.nn.functional.cross_entropy(input, target)

  這裡的input是沒有經過處理的logits,這個函式會自動根據logits計算出pred_log

  target是真實值

>>> import torch
>>> import torch.nn.functional as F
>>> x = torch.randn(1, 28)
>>> w = torch.randn(10,28)
>>> logits = x @ w.t()
>>> F.cross_entropy(logits, torch.tensor([3]))
tensor(6.8515)

這裡我刪除了上面使用方法一的程式碼部分,x和w沒有重新隨機生成,所以計算結果是一樣的。

還在學習過程,做此紀錄,如有不對,請指正。