1. 程式人生 > 程式設計 >pytorch標籤轉onehot形式例項

pytorch標籤轉onehot形式例項

程式碼:

import torch

class_num = 10
batch_size = 4
label = torch.LongTensor(batch_size,1).random_() % class_num
print(label.size())

one_hot = torch.zeros(batch_size,class_num).scatter_(1,label,1)
print(one_hot)

輸出:

torch.Size([4,1])
tensor([[0.,0.,1.,0.],[0.,1.]])

注意:

label的形狀必須是[n,1]的,也就是必須是二維的,且第二個維度長度為1,如果是一維度的,則需要升維度,程式碼如下:

import torch

class_num = 10
batch_size = 4
label = torch.LongTensor(batch_size).random_() % class_num
print(label.size())
label = torch.unsqueeze(label,dim=1)
print(label.size())

以上這篇pytorch標籤轉onehot形式例項就是小編分享給大家的全部內容了,希望能給大家一個參考,也希望大家多多支援我們。