1. 程式人生 > 程式設計 >Pytorch技巧:DataLoader的collate_fn引數使用詳解

Pytorch技巧:DataLoader的collate_fn引數使用詳解

DataLoader完整的引數表如下:

class torch.utils.data.DataLoader(
 dataset,batch_size=1,shuffle=False,sampler=None,batch_sampler=None,num_workers=0,collate_fn=<function default_collate>,pin_memory=False,drop_last=False,timeout=0,worker_init_fn=None)

DataLoader在資料集上提供單程序或多程序的迭代器

幾個關鍵的引數意思:

- shuffle:設定為True的時候,每個世代都會打亂資料集

- collate_fn:如何取樣本的,我們可以定義自己的函式來準確地實現想要的功能

- drop_last:告訴如何處理資料集長度除於batch_size餘下的資料。True就拋棄,否則保留

一個測試的例子

import torch
import torch.utils.data as Data
import numpy as np

test = np.array([0,1,2,3,4,5,6,7,8,9,10,11])

inputing = torch.tensor(np.array([test[i:i + 3] for i in range(10)]))
target = torch.tensor(np.array([test[i:i + 1] for i in range(10)]))

torch_dataset = Data.TensorDataset(inputing,target)
batch = 3

loader = Data.DataLoader(
 dataset=torch_dataset,batch_size=batch,# 批大小
 # 若dataset中的樣本數不能被batch_size整除的話,最後剩餘多少就使用多少
 collate_fn=lambda x:(
  torch.cat(
   [x[i][j].unsqueeze(0) for i in range(len(x))],0
   ).unsqueeze(0) for j in range(len(x[0]))
  )
 )

for (i,j) in loader:
 print(i)
 print(j)

輸出結果:

tensor([[[ 0,2],[ 1,3],[ 2,4]]],dtype=torch.int32)
tensor([[[ 0],[ 1],[ 2]]],dtype=torch.int32)
tensor([[[ 3,5],[ 4,6],[ 5,7]]],dtype=torch.int32)
tensor([[[ 3],[ 4],[ 5]]],dtype=torch.int32)
tensor([[[ 6,8],[ 7,9],[ 8,10]]],dtype=torch.int32)
tensor([[[ 6],[ 7],[ 8]]],dtype=torch.int32)
tensor([[[ 9,11]]],dtype=torch.int32)
tensor([[[ 9]]],dtype=torch.int32)

如果不要collate_fn的值,輸出變成

tensor([[ 0,4]],dtype=torch.int32)
tensor([[ 0],[ 2]],dtype=torch.int32)
tensor([[ 3,7]],dtype=torch.int32)
tensor([[ 3],[ 5]],dtype=torch.int32)
tensor([[ 6,10]],dtype=torch.int32)
tensor([[ 6],[ 8]],dtype=torch.int32)
tensor([[ 9,11]],dtype=torch.int32)
tensor([[ 9]],dtype=torch.int32)

所以collate_fn就是使結果多一維。

看看collate_fn的值是什麼意思。我們把它改為如下

collate_fn=lambda x:x

並輸出

for i in loader:
 print(i)

得到結果

[(tensor([ 0,dtype=torch.int32),tensor([ 0],dtype=torch.int32)),(tensor([ 1,tensor([ 1],(tensor([ 2,4],tensor([ 2],dtype=torch.int32))]
[(tensor([ 3,tensor([ 3],(tensor([ 4,tensor([ 4],(tensor([ 5,7],tensor([ 5],dtype=torch.int32))]
[(tensor([ 6,tensor([ 6],(tensor([ 7,tensor([ 7],(tensor([ 8,10],tensor([ 8],dtype=torch.int32))]
[(tensor([ 9,11],tensor([ 9],dtype=torch.int32))]

每個i都是一個列表,每個列表包含batch_size個元組,每個元組包含TensorDataset的單獨資料。所以要將重新組合成每個batch包含1*3*3的input和1*3*1的target,就要重新解包並打包。 看看我們的collate_fn:

collate_fn=lambda x:(
 torch.cat(
  [x[i][j].unsqueeze(0) for i in range(len(x))],0
  ).unsqueeze(0) for j in range(len(x[0]))
 )

j取的是兩個變數:input和target。i取的是batch_size。然後通過unsqueeze(0)方法在前面加一維。torch.cat(,0)將其打包起來。然後再通過unsqueeze(0)方法在前面加一維。 完成。

以上這篇Pytorch技巧:DataLoader的collate_fn引數使用詳解就是小編分享給大家的全部內容了,希望能給大家一個參考,也希望大家多多支援我們。