1. 程式人生 > >Pytorch資料讀取(Dataset, DataLoader, DataLoaderIter)

Pytorch資料讀取(Dataset, DataLoader, DataLoaderIter)

Pytorch的資料讀取主要包含三個類:

  1. Dataset
  2. DataLoader
  3. DataLoaderIter

這三者大致是一個依次封裝的關係: 1.被裝進2., 2.被裝進3.

一. torch.utils.data.Dataset

是一個抽象類, 自定義的Dataset需要繼承它並且實現兩個成員方法:

  1. __getitem__()
  2. __len__()

第一個最為重要, 即每次怎麼讀資料. 以圖片為例:

def __getitem__(self, index):
        img_path, label = self.data[index].img_path, self.data[index].label
        img = Image.open(img_path)

        return img, label

值得一提的是, pytorch還提供了很多常用的transform, 在torchvision.transforms 裡面, 本文中不多介紹, 常用的有Resize , RandomCrop , Normalize , ToTensor (這個極為重要, 可以把一個PIL或numpy圖片轉為torch.Tensor, 但是好像對numpy陣列的轉換比較受限, 所以這裡建議在__getitem__()裡面用PIL來讀圖片, 而不是用skimage.io).

 第二個比較簡單, 就是返回整個資料集的長度:

def __len__(self):
        return len(self.data)

二. torch.utils.data.DataLoader

類定義為:

class torch.utils.data.DataLoader(
    dataset, 
    batch_size=1, 
    shuffle=False, 
    sampler=None,
    batch_sampler=None, 
    num_workers=0, 
    collate_fn=<function default_collate>, 
    pin_memory=False, 
    drop_last=False
)

可以看到, 主要引數有這麼幾個:

  1. dataset : 即上面自定義的dataset.
  2. collate_fn: 這個函式用來打包batch
  3. num_worker: 非常簡單的多執行緒方法, 只要設定為>=1, 就可以多執行緒預讀資料

這個類其實就是下面將要講的DataLoaderIter的一個框架, 一共幹了兩件事:

  1. 定義了一堆成員變數, 到時候賦給DataLoaderIter,
  2. 然後有一個__iter__() 函式, 把自己 "裝進" DataLoaderIter 裡面.
def __iter__(self):
        return DataLoaderIter(self)

三. torch.utils.data.dataloader.DataLoaderIter

上面提到, DataLoader就是DataLoaderIter的一個框架, 用來傳給DataLoaderIter 一堆引數, 並把自己裝進DataLoaderIter 裡。其實到這裡就可以滿足大多數訓練的需求了, 比如

class CustomDataset(Dataset):
   # 自定義自己的dataset

dataset = CustomDataset()
dataloader = Dataloader(dataset, ...)

for data in dataloader:
   # training...

在for 迴圈裡, 總共有三點操作:

  1. 呼叫了dataloader 的__iter__() 方法, 產生了一個DataLoaderIter
  2. 反覆呼叫DataLoaderIter 的__next__()來得到batch, 具體操作就是, 多次呼叫dataset的__getitem__()方法 (如果num_worker>0就多執行緒呼叫), 然後用collate_fn來把它們打包成batch. 中間還會涉及到shuffle , 以及sample 的方法等.
  3. 當資料讀完後, __next__()丟擲一個StopIteration異常, for迴圈結束, dataloader 失效.

四. 又一層封裝

其實上面三個類已經可以搞定了, 僅供參考

class DataProvider:
    def __init__(self, batch_size, is_cuda):
        self.batch_size = batch_size
        self.dataset = Dataset_triple(self.batch_size,
                                      transform_=transforms.Compose(
                                     [transforms.Scale([224, 224]),
                                      transforms.ToTensor(),
                                      transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                                           std=[0.229, 0.224, 0.225])]),
                                      )
        self.is_cuda = is_cuda  # 是否將batch放到gpu上
        self.dataiter = None
        self.iteration = 0  # 當前epoch的batch數
        self.epoch = 0  # 統計訓練了多少個epoch

    def build(self):
        dataloader = DataLoader(self.dataset, batch_size=self.batch_size, shuffle=True, num_workers=0, drop_last=True)
        self.dataiter = DataLoaderIter(dataloader)

    def next(self):
        if self.dataiter is None:
            self.build()
        try:
            batch = self.dataiter.next()
            self.iteration += 1

            if self.is_cuda:
                batch = [batch[0].cuda(), batch[1].cuda(), batch[2].cuda()]
            return batch

        except StopIteration:  # 一個epoch結束後reload
            self.epoch += 1
            self.build()
            self.iteration = 1  # reset and return the 1st batch

            batch = self.dataiter.next()
            if self.is_cuda:
                batch = [batch[0].cuda(), batch[1].cuda(), batch[2].cuda()]
            return batch

感謝以下連結提供的參考: