PyTorch筆記--交叉熵損失函式實現
交叉熵(cross entropy):用於度量兩個概率分佈間的差異資訊。交叉熵越小,代表這兩個分佈越接近。
函式表示(這是使用softmax作為啟用函式的損失函式表示):
(是真實值,是預測值。)
命名說明:
pred=F.softmax(logits),logits是softmax函式的輸入,pred代表預測值,是softmax函式的輸出。
pred_log=F.log_softmax(logits),pred_log代表對預測值再取對數後的結果。也就是將logits作為log_softmax()函式的輸入。
方法一,使用log_softmax()+nll_loss()實現
torch.nn.functional.log_softmax
對輸入使用softmax函式計算,再取對數。
torch.nn.functional.nll_loss(input, target)
input是經log_softmax()函式處理後的結果,pred_log
target代表的是真實值。
有了這兩個輸入後,該函式對其實現交叉熵損失函式的計算,即上面公式中的L。
>>> import torch >>> import torch.nn.functional as F >>> x = torch.randn(1, 28) >>> w = torch.randn(10,28) >>> logits = x @ w.t() >>> pred_log = F.log_softmax(logits, dim=1) >>> pred_log tensor([[ -0.8779, -6.7271, -9.1801, -6.8515, -9.6900, -6.3061, -3.7304, -8.1933, -11.5704, -0.5873]]) >>> F.nll_loss(pred_log, torch.tensor([3])) tensor(6.8515)
logits的維度是(1, 10)這裡可以理解成是1個輸入,最終可能得到10個分類的結果中的一個。pred_log就是。
這裡的引數target=torch.tensor([3]),我的理解是,他代表真正的分類的值是在第3類(從0編號)。
使用獨熱編碼代表真實值是[0, 0, 0, 1, 0, 0, 0, 0, 0, 0],即這個輸入它是屬於第三類的。
根據上述公式進行計算,現在我們和都已經知道了。
對其進行點乘操作
方法二,使用cross_entropy()實現
torch.nn.functional.cross_entropy(input, target)
這裡的input是沒有經過處理的logits,這個函式會自動根據logits計算出pred_log
target是真實值
>>> import torch >>> import torch.nn.functional as F >>> x = torch.randn(1, 28) >>> w = torch.randn(10,28) >>> logits = x @ w.t() >>> F.cross_entropy(logits, torch.tensor([3])) tensor(6.8515)
這裡我刪除了上面使用方法一的程式碼部分,x和w沒有重新隨機生成,所以計算結果是一樣的。
還在學習過程,做此紀錄,如有不對,請指正。