pytorch讀取資料(Dataset, DataLoader, DataLoaderIter)
阿新 • • 發佈:2021-02-18
pytorch資料讀取
參考資料:
pytorch資料讀取
pytorch對nlp資料的處理部落格(以短文字匹配為例)
dataloader使用教程部落格
pytorch使用DataLoader對資料集進行批處理簡單示例
Pytorch的資料讀取主要包含三個類:
- Dataset
- DataLoader
- DataLoaderIter
這三者是依次封裝的關係,Dataset被裝進DataLoader,DataLoder被裝進DataLoaderIter。
Dataloader的處理邏輯是先通過Dataset類裡面的__getitem__函式獲取單個的資料,然後組合成batch,再使用collate_fn所指定的函式對這個batch做一些操作,比如padding。
torch.utils.data.Dataset
是一個抽象類,自定義的Dataset需要繼承它並實現兩個成員方法
__getitem__()
:從資料集中得到一個數據片段(如:資料,標籤)__len__()
:返回整個資料集的長度
自定義Dataset基本的框架是:
class CustomDataset(data.Dataset):#需要繼承data.Dataset
def __init__(self):
# TODO
# 1. Initialize file path or list of file names.
pass
def __getitem__(self, index):
# TODO
# 1. Read one data from file (e.g. using numpy.fromfile, PIL.Image.open).
# 2. Preprocess the data (e.g. torchvision.Transform).
# 3. Return a data pair (e.g. image and label).
#這裡需要注意的是,第一步:read one data,是一個data
pass
def __len__(self):
# You should change 0 to the total size of your dataset.
return 0
以短文字匹配為例,判斷兩個文字是否相似,資料格式為(句子1,句子2,是否相似標籤0或1,示例如下圖所示:
對於這樣的資料集,我們構建的Dataset類可以是:
class Dataset(torch.utils.data.Dataset):
def __init__(self, texta, textb, label):
self.texta = texta
self.textb = textb
self.label = label
def __getitem__(self, item):
return self.texta[item], self.textb[item], self.label[item]
def __len__(self):
return len(self.texta)
接下來可以向自己實現的Dataset類中傳值:
train_data = Dataset(train_texta, train_textb, train_label)
torch.utils.data.DataLoader
類的定義為:
class torch.utils.data.DataLoader(
dataset,
batch_size=1,
shuffle=False,
sampler=None,
batch_sampler=None,
num_workers=0,
collate_fn=<function default_collate>,
pin_memory=False,
drop_last=False
)
主要的引數有:
- dataset:即我們自定義的Dataset
- collate_fn:這個函式用來打包batch。定義如何把一批dataset的例項轉換為包含mini-batch資料的張量。可以通過自定義collate_fn=myfunction來設計資料打包的方式,通常在myfunction函式中做padding,將同一個batch中不一樣長的句子補全成一樣的長度。
- num_workers:設定>=1時,可多執行緒讀資料。設定=0時,單執行緒讀資料
- shuffle:代表是否要打亂資料,一般對於訓練集資料都是要打亂的,驗證集可以打亂也可以不打亂,但是測試集資料不打亂。
這個類是DataLoaderIter的一個框架,一共做了兩件事情:
- 定義成員變數,到時候賦給DataLoaderIter
- 有一個
__iter__()
函式,把自己裝進DataLoaderIter中
def __iter__(self):
return DataLoaderIter(self)
仍然以短文字匹配任務為例,上面已經得到了train_data,然後就可以使用DataLoader進行處理,DataLoader每次返回一個batch的資料。
train_iter = DataLoader(dataset=train_data, batch_size=args.batch_size, shuffle=True, collate_fn=myfunction)
這裡的myfunction可以定義為:
def myfunction(batch_data, pad=0):
texta,textb,label = list(zip(*batch_data))#batch_data的結構是[([texta_1],[textb_1],[label_1]),([texta_2],[textb_2],[label_2]),...],所以需要使用zip函式對它解壓
max_len_a = max([len(seq_a) for seq_a in texta])
max_len_b = max([len(seq_b) for seq_b in textb])
max_len = max(max_len_a,max_len_b) #這裡我使用的是一個batch中text_a或者是text_b的最大長度作為max_len,也可以自定義長度
texta = [seq+[pad]*(max_len-len(seq)) for seq in texta]
textb = [seq+[pad]*(max_len-len(seq)) for seq in textb]
texta = torch.LongTensor(texta)
textb = torch.LongTensor(textb)
label = torch.FloatTensor(label)
return (texta,textb,label)
torch.utils.data.dataloader.DataLoaderIter
這裡其實是迭代呼叫DataLoader的過程。一般載入資料的整個流程為:
class Dataset(Dataset)
# 自定義的Dataset
# 返回(data, label)
dataset = Dataset()
dataloader = DataLoader(dataset, ...)
for data in dataloader:
# training
在for迴圈中,總共有三個操作:
- 呼叫DataLoder的
__iter__()
方法,產生一個DataLoaderIter - 反覆呼叫DataLoaderIter 的__next__()來得到batch, 具體操作就是, 多次呼叫dataset的__getitem__()方法 (如果num_worker>0就多執行緒呼叫), 然後用collate_fn來把它們打包成batch
- 當資料讀完後, next()丟擲一個StopIteration異常, for迴圈結束, dataloader 失效
以上的短文字匹配的例子中,for迴圈部分可以寫成:
for batch_data in tqdm(train_iter):
texta, textb, tag = map(lambda x: x.to(device), batch_data)
output = model(texta, textb)