Pytorch技巧:DataLoader的collate_fn引數使用詳解
阿新 • • 發佈:2020-01-09
DataLoader完整的引數表如下:
class torch.utils.data.DataLoader( dataset,batch_size=1,shuffle=False,sampler=None,batch_sampler=None,num_workers=0,collate_fn=<function default_collate>,pin_memory=False,drop_last=False,timeout=0,worker_init_fn=None)
DataLoader在資料集上提供單程序或多程序的迭代器
幾個關鍵的引數意思:
- shuffle:設定為True的時候,每個世代都會打亂資料集
- collate_fn:如何取樣本的,我們可以定義自己的函式來準確地實現想要的功能
- drop_last:告訴如何處理資料集長度除於batch_size餘下的資料。True就拋棄,否則保留
一個測試的例子
import torch import torch.utils.data as Data import numpy as np test = np.array([0,1,2,3,4,5,6,7,8,9,10,11]) inputing = torch.tensor(np.array([test[i:i + 3] for i in range(10)])) target = torch.tensor(np.array([test[i:i + 1] for i in range(10)])) torch_dataset = Data.TensorDataset(inputing,target) batch = 3 loader = Data.DataLoader( dataset=torch_dataset,batch_size=batch,# 批大小 # 若dataset中的樣本數不能被batch_size整除的話,最後剩餘多少就使用多少 collate_fn=lambda x:( torch.cat( [x[i][j].unsqueeze(0) for i in range(len(x))],0 ).unsqueeze(0) for j in range(len(x[0])) ) ) for (i,j) in loader: print(i) print(j)
輸出結果:
tensor([[[ 0,2],[ 1,3],[ 2,4]]],dtype=torch.int32) tensor([[[ 0],[ 1],[ 2]]],dtype=torch.int32) tensor([[[ 3,5],[ 4,6],[ 5,7]]],dtype=torch.int32) tensor([[[ 3],[ 4],[ 5]]],dtype=torch.int32) tensor([[[ 6,8],[ 7,9],[ 8,10]]],dtype=torch.int32) tensor([[[ 6],[ 7],[ 8]]],dtype=torch.int32) tensor([[[ 9,11]]],dtype=torch.int32) tensor([[[ 9]]],dtype=torch.int32)
如果不要collate_fn的值,輸出變成
tensor([[ 0,4]],dtype=torch.int32) tensor([[ 0],[ 2]],dtype=torch.int32) tensor([[ 3,7]],dtype=torch.int32) tensor([[ 3],[ 5]],dtype=torch.int32) tensor([[ 6,10]],dtype=torch.int32) tensor([[ 6],[ 8]],dtype=torch.int32) tensor([[ 9,11]],dtype=torch.int32) tensor([[ 9]],dtype=torch.int32)
所以collate_fn就是使結果多一維。
看看collate_fn的值是什麼意思。我們把它改為如下
collate_fn=lambda x:x
並輸出
for i in loader: print(i)
得到結果
[(tensor([ 0,dtype=torch.int32),tensor([ 0],dtype=torch.int32)),(tensor([ 1,tensor([ 1],(tensor([ 2,4],tensor([ 2],dtype=torch.int32))] [(tensor([ 3,tensor([ 3],(tensor([ 4,tensor([ 4],(tensor([ 5,7],tensor([ 5],dtype=torch.int32))] [(tensor([ 6,tensor([ 6],(tensor([ 7,tensor([ 7],(tensor([ 8,10],tensor([ 8],dtype=torch.int32))] [(tensor([ 9,11],tensor([ 9],dtype=torch.int32))]
每個i都是一個列表,每個列表包含batch_size個元組,每個元組包含TensorDataset的單獨資料。所以要將重新組合成每個batch包含1*3*3的input和1*3*1的target,就要重新解包並打包。 看看我們的collate_fn:
collate_fn=lambda x:( torch.cat( [x[i][j].unsqueeze(0) for i in range(len(x))],0 ).unsqueeze(0) for j in range(len(x[0])) )
j取的是兩個變數:input和target。i取的是batch_size。然後通過unsqueeze(0)方法在前面加一維。torch.cat(,0)將其打包起來。然後再通過unsqueeze(0)方法在前面加一維。 完成。
以上這篇Pytorch技巧:DataLoader的collate_fn引數使用詳解就是小編分享給大家的全部內容了,希望能給大家一個參考,也希望大家多多支援我們。