1. 程式人生 > 程式設計 >Pytorch DataLoader 變長資料處理方式

Pytorch DataLoader 變長資料處理方式

關於Pytorch中怎麼自定義Dataset資料集類、怎樣使用DataLoader迭代載入資料,這篇官方文件已經說得很清楚了,這裡就不在贅述。

現在的問題:有的時候,特別對於NLP任務來說,輸入的資料可能不是定長的,比如多個句子的長度一般不會一致,這時候使用DataLoader載入資料時,不定長的句子會被胡亂切分,這肯定是不行的。

解決方法是重寫DataLoader的collate_fn,具體方法如下:

# 假如每一個樣本為:
sample = {
	# 一個句子中各個詞的id
	'token_list' : [5,2,4,1,9,8],# 結果y
	'label' : 5,}


# 重寫collate_fn函式,其輸入為一個batch的sample資料
def collate_fn(batch):
	# 因為token_list是一個變長的資料,所以需要用一個list來裝這個batch的token_list
  token_lists = [item['token_list'] for item in batch]
  
  # 每個label是一個int,我們把這個batch中的label也全取出來,重新組裝
  labels = [item['label'] for item in batch]
  # 把labels轉換成Tensor
  labels = torch.Tensor(labels)
  return {
    'token_list': token_lists,'label': labels,}


# 在使用DataLoader載入資料時,注意collate_fn引數傳入的是重寫的函式
DataLoader(trainset,batch_size=4,shuffle=True,num_workers=4,collate_fn=collate_fn)

使用以上方法,可以保證DataLoader能Load出一個batch的資料,load出來的東西就是重寫的collate_fn函式最後return出來的字典。

以上這篇Pytorch DataLoader 變長資料處理方式就是小編分享給大家的全部內容了,希望能給大家一個參考,也希望大家多多支援我們。