1. 程式人生 > 實用技巧 >torch 中的損失函式

torch 中的損失函式

  1. NLLLoss 和 CrossEntropyLoss
    在圖片單標籤分類時,輸入m張圖片,輸出一個mN的Tensor,其中N是分類個數。比如輸入3張圖片,分三類,最後的輸出是一個33的Tensor
input = torch.tensor([[-0.1123, -0.6028, -0.0450],
              [ 0.1596,  0.2215, -1.0176],
              [-0.2359, -0.7898,  0.7097]])

第123行分別是第123張圖片的結果,假設第123列分別是貓、狗和豬的分類得分。
first step: 對每一行使用Softmax,這樣可以得到每張圖片的概率分佈。概率最大的為:1:豬;2:狗;3:豬。

sm = torch.nn.Softmax(dim=1)
sm(input)
tensor([[0.3729, 0.2283, 0.3988],
        [0.4216, 0.4485, 0.1299],
        [0.2410, 0.1385, 0.6205]])

second step: 對softmax結果取對數

torch.log(sm(input))
tensor([[-0.9865, -1.4770, -0.9192],
        [-0.8637, -0.8019, -2.0409],
        [-1.4229, -1.9767, -0.4773]])

Softmax後的數值都在0~1之間,所以log之後值域是負無窮到0。
NLLLoss的結果就是把上面的輸出與Label對應的那個值拿出來,再去掉負號,再求均值。
假設我們現在Target是[0,2,1](第一張圖片是貓,第二張是豬,第三張是狗)。第一行取第0個元素,第二行取第2個,第三行取第1個,去掉負號,結果是:[0.9865,2.0409,1.9767]。再求個均值,結果是:1.66
對比NLLLoss的結果

loss(torch.log(sm(input)),target)
# 1.6681

CrossEntropyLoss 相當於上述步驟的組合,Softmax–Log–NLLLoss合併成一步

loss2 = torch.nn.CrossEntropyLoss()
loss2(input,target)
# 1.6681