1. 程式人生 > 其它 >Pytorch nn.BCEWithLogitsLoss() 的簡單理解與用法

Pytorch nn.BCEWithLogitsLoss() 的簡單理解與用法

這個東西,本質上和nn.BCELoss()沒有區別,只是在BCELoss上加了個logits函式(也就是sigmoid函式),例子如下:

import torch
import torch.nn as nn

label = torch.Tensor([1, 1, 0])
pred = torch.Tensor([3, 2, 1])
pred_sig = torch.sigmoid(pred)
loss = nn.BCELoss()
print(loss(pred_sig, label))

loss = nn.BCEWithLogitsLoss()
print(loss(pred, label))

loss 
= nn.BCEWithLogitsLoss() print(loss(pred_sig, label))

輸出結果分別為:

tensor(0.4963)
tensor(0.4963)
tensor(0.5990)

可以看到,nn.BCEWithLogitsLoss()相當於是在nn.BCELoss()中預測結果pred的基礎上先做了個sigmoid,然後繼續正常算loss。所以這就涉及到一個比較奇葩的bug,如果網路本身在輸出結果的時候已經用sigmoid去處理了,算loss的時候用nn.BCEWithLogitsLoss()…那麼就會相當於預測結果算了兩次sigmoid,可能會出現各種奇奇怪怪的問題——



因上求緣,果上努力~~~~ 作者:Learner-,轉載請註明原文連結:https://www.cnblogs.com/BlairGrowing/p/15970528.html