1. 程式人生 > 其它 >pytorch中二分類的lossf函式使用

pytorch中二分類的lossf函式使用

技術標籤:深度學習技術棧

多分類不需要滿足y_pred和y_gt的維度相同
在這裡插入圖片描述
二分類中,torch.nn.BECLosstorch,nn.BCEWithLogitsLoss需要預測輸出y_pred和y_gt需要一樣的維度
故需要使用一下兩種轉變y_gt的方法。
假設y_pred如下:

[[1.2 2.3 ],
[2.1 2.2 ]]

其中,y_gt如下:

[[0],[1]]

目標是是轉化成下面的維度:

[[1,0],
[0,1]]

程式碼如下:

  1. 直接使用list的index屬性。
gt_y_temp = torch.zeros(gt_y.shape[0], 2)
gt_y_temp[
range(gt_y.shape[0]), list(gt_y.squeeze(1).int())] = 1 gt_y_temp=gt_y_temp.cuda()
  1. 使用scatter來使用
gt_y_temp = torch.zeros(gt_y.shape[0], 2, device='cuda').scatter_(1, gt_y.long(), torch.tensor(1, dtype=torch.float)).cuda()

注意這個地方gt_y的特徵維度最後一個維度是1.比如:[[0],[1],[0]]

更多案例

>>>class_num = 10
>>>
batch_size = 4 >>>label = torch.LongTensor(batch_size, 1).random_() % class_num 3 0 0 8 >>>one_hot = torch.zeros(batch_size, class_num).scatter_(1, label, 1) 0 0 0 1 0 0 0 0 0 0 1 0 0 0 0 0 0 0 0 0 1 0 0 0
0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 0

參考部落格:
pytorch中scatter的使用原理
三種one_hot的辦法
scatter使用方法案例
How is Pytorch’s binary_cross_entropy_with_logits function related to sigmoid and binary_cross_entropy