Pytorch資料讀取(Dataset, DataLoader, DataLoaderIter)
阿新 • • 發佈:2018-12-19
Pytorch的資料讀取主要包含三個類:
- Dataset
- DataLoader
- DataLoaderIter
這三者大致是一個依次封裝的關係: 1.被裝進2., 2.被裝進3.
一. torch.utils.data.Dataset
是一個抽象類, 自定義的Dataset需要繼承它並且實現兩個成員方法:
__getitem__()
__len__()
第一個最為重要, 即每次怎麼讀資料. 以圖片為例:
def __getitem__(self, index): img_path, label = self.data[index].img_path, self.data[index].label img = Image.open(img_path) return img, label
值得一提的是, pytorch還提供了很多常用的transform, 在torchvision.transforms
裡面, 本文中不多介紹, 常用的有Resize
, RandomCrop
, Normalize
, ToTensor
(這個極為重要, 可以把一個PIL或numpy圖片轉為torch.Tensor
, 但是好像對numpy陣列的轉換比較受限, 所以這裡建議在__getitem__()
裡面用PIL來讀圖片, 而不是用skimage.io).
第二個比較簡單, 就是返回整個資料集的長度:
def __len__(self): return len(self.data)
二. torch.utils.data.DataLoader
類定義為:
class torch.utils.data.DataLoader(
dataset,
batch_size=1,
shuffle=False,
sampler=None,
batch_sampler=None,
num_workers=0,
collate_fn=<function default_collate>,
pin_memory=False,
drop_last=False
)
可以看到, 主要引數有這麼幾個:
dataset
: 即上面自定義的dataset.collate_fn
: 這個函式用來打包batchnum_worker
: 非常簡單的多執行緒方法, 只要設定為>=1, 就可以多執行緒預讀資料
這個類其實就是下面將要講的DataLoaderIter
的一個框架, 一共幹了兩件事:
- 定義了一堆成員變數, 到時候賦給
DataLoaderIter
, - 然後有一個
__iter__()
函式, 把自己 "裝進"DataLoaderIter
裡面.
def __iter__(self):
return DataLoaderIter(self)
三. torch.utils.data.dataloader.DataLoaderIter
上面提到, DataLoader
就是DataLoaderIter
的一個框架, 用來傳給DataLoaderIter
一堆引數, 並把自己裝進DataLoaderIter
裡。其實到這裡就可以滿足大多數訓練的需求了, 比如
class CustomDataset(Dataset):
# 自定義自己的dataset
dataset = CustomDataset()
dataloader = Dataloader(dataset, ...)
for data in dataloader:
# training...
在for 迴圈裡, 總共有三點操作:
- 呼叫了
dataloader
的__iter__()
方法, 產生了一個DataLoaderIter
- 反覆呼叫
DataLoaderIter
的__next__()
來得到batch, 具體操作就是, 多次呼叫dataset的__getitem__()
方法 (如果num_worker
>0就多執行緒呼叫), 然後用collate_fn
來把它們打包成batch. 中間還會涉及到shuffle
, 以及sample
的方法等. - 當資料讀完後,
__next__()
丟擲一個StopIteration
異常,for
迴圈結束,dataloader
失效.
四. 又一層封裝
其實上面三個類已經可以搞定了, 僅供參考
class DataProvider:
def __init__(self, batch_size, is_cuda):
self.batch_size = batch_size
self.dataset = Dataset_triple(self.batch_size,
transform_=transforms.Compose(
[transforms.Scale([224, 224]),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225])]),
)
self.is_cuda = is_cuda # 是否將batch放到gpu上
self.dataiter = None
self.iteration = 0 # 當前epoch的batch數
self.epoch = 0 # 統計訓練了多少個epoch
def build(self):
dataloader = DataLoader(self.dataset, batch_size=self.batch_size, shuffle=True, num_workers=0, drop_last=True)
self.dataiter = DataLoaderIter(dataloader)
def next(self):
if self.dataiter is None:
self.build()
try:
batch = self.dataiter.next()
self.iteration += 1
if self.is_cuda:
batch = [batch[0].cuda(), batch[1].cuda(), batch[2].cuda()]
return batch
except StopIteration: # 一個epoch結束後reload
self.epoch += 1
self.build()
self.iteration = 1 # reset and return the 1st batch
batch = self.dataiter.next()
if self.is_cuda:
batch = [batch[0].cuda(), batch[1].cuda(), batch[2].cuda()]
return batch
感謝以下連結提供的參考: