1. 程式人生 > 其它 >pytorch中nn.CrossEntropyLoss使用注意事項

pytorch中nn.CrossEntropyLoss使用注意事項

技術標籤:程式碼試錯

Loss的數學表達公式:

使用程式碼樣例:

# 這樣展開就相當於每個詞正確的類別和預測的整個詞表概率分佈進行對應
# ignore_index是指忽略真實標籤中的類別
criterion = nn.CrossEntropyLoss(ignore_index=2).to(device) 
vocab_size = pre.shape[-1]
trg = trg[:,1:]
trg_tag = trg.reshape(-1).to(device) # view函式要求在同一個連續地址裡,而reshape不用
pre_tag = pre[1:].view(-1,vocab_size).to(device)
loss = criterion(pre_tag,trg_tag)

注意事項:

  1. CrossEntropyLoss例項化之後,其兩個輸入分別是預測標籤和真實標籤,順序不要搞錯。預測標籤的大小為[N,classnum],真實樣本的大小為[N],因為該函式會把真實標籤進行one-hot表示。N不一定是batchsize大小,可以對向量進行展開,從而可以逐個樣本進行計算loss。

  2. 從公式可以看出,pytorch中的交叉熵loss其本身已使用的一個softmax約束了預測標籤輸入控制在了0-1之間,所以loss的輸入即用模型的輸出即可不需要通過softmax後再輸入loss中,否則兩個softmax可能會導致模型在訓練的過程中loss保持不變。