pytorch標籤轉onehot形式例項
阿新 • • 發佈:2020-01-09
程式碼:
import torch class_num = 10 batch_size = 4 label = torch.LongTensor(batch_size,1).random_() % class_num print(label.size()) one_hot = torch.zeros(batch_size,class_num).scatter_(1,label,1) print(one_hot)
輸出:
torch.Size([4,1]) tensor([[0.,0.,1.,0.],[0.,1.]])
注意:
label的形狀必須是[n,1]的,也就是必須是二維的,且第二個維度長度為1,如果是一維度的,則需要升維度,程式碼如下:
import torch class_num = 10 batch_size = 4 label = torch.LongTensor(batch_size).random_() % class_num print(label.size()) label = torch.unsqueeze(label,dim=1) print(label.size())
以上這篇pytorch標籤轉onehot形式例項就是小編分享給大家的全部內容了,希望能給大家一個參考,也希望大家多多支援我們。