pytorch中二分類的lossf函式使用
阿新 • • 發佈:2021-01-02
多分類不需要滿足y_pred和y_gt的維度相同
二分類中,torch.nn.BECLoss
和torch,nn.BCEWithLogitsLoss
需要預測輸出y_pred和y_gt需要一樣的維度。
故需要使用一下兩種轉變y_gt的方法。
假設y_pred如下:
[[1.2 2.3 ],
[2.1 2.2 ]]
其中,y_gt如下:
[[0],[1]]
目標是是轉化成下面的維度:
[[1,0],
[0,1]]
程式碼如下:
- 直接使用list的index屬性。
gt_y_temp = torch.zeros(gt_y.shape[0], 2)
gt_y_temp[ range(gt_y.shape[0]), list(gt_y.squeeze(1).int())] = 1
gt_y_temp=gt_y_temp.cuda()
- 使用scatter來使用
gt_y_temp = torch.zeros(gt_y.shape[0], 2, device='cuda').scatter_(1, gt_y.long(), torch.tensor(1, dtype=torch.float)).cuda()
注意這個地方gt_y的特徵維度最後一個維度是1.比如:[[0],[1],[0]]
更多案例
>>>class_num = 10
>>> batch_size = 4
>>>label = torch.LongTensor(batch_size, 1).random_() % class_num
3
0
0
8
>>>one_hot = torch.zeros(batch_size, class_num).scatter_(1, label, 1)
0 0 0 1 0 0 0 0 0 0
1 0 0 0 0 0 0 0 0 0
1 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0 1 0
參考部落格:
pytorch中scatter的使用原理
三種one_hot的辦法
scatter使用方法案例
How is Pytorch’s binary_cross_entropy_with_logits function related to sigmoid and binary_cross_entropy