1. 程式人生 > 其它 >pytorch的dataset與dataloader解析

pytorch的dataset與dataloader解析

整理一下pytorch獲取的流程:

  1. 建立Dataset物件
  2. 建立DataLoader物件,裝載有dataset物件
  3. 迴圈DataLoader物件,DataLoader.__iter__返回的是DataLoaderIter物件
dataset = MyDataset()
dataloader = DataLoader(dataset)
num_epoches = 100
for epoch in range(num_epoches):
    for data in dataloader:
        ....

根據原始碼分析:torch.utils.data

1 - Dataset:

class Dataset(object):
    """An abstract class representing a Dataset.
    All other datasets should subclass it. All subclasses should override
    ``__len__``, that provides the size of the dataset, and ``__getitem__``,
    supporting integer indexing in range from 0 to len(self) exclusive.
    """

    def __getitem__(self, index):
        raise NotImplementedError

    def __len__(self):
        raise NotImplementedError

    def __add__(self, other):
        return ConcatDataset([self, other])

Dataset這是一個抽象類,不能例項化,需要重寫類方法,關鍵點有兩個:

  • __getitem__ 這個很重要,規定了如何讀資料,比如常用的transform
  • __len__ 這個就是返回資料集的長度,比如:return len(self.data)

2 - DataLoader:

DataLoader(dataset, batch_size=1, shuffle=False, sampler=None,
           batch_sampler=None, num_workers=0, collate_fn=None,
           pin_memory=False, drop_last=False, timeout=0,
           worker_init_fn=None)

先看一下主要引數:

  • dataset:就是 torch.utils.data.Dataset 類的例項。也就是說為了使用 DataLoader 類,需要先定義一個 torch.utils.data.Dataset 類的例項。
  • batch_size:每一個批次需要載入的訓練樣本個數。
  • shuffle:如果設定為 True 表示訓練樣本資料會被隨機打亂,預設值為 False。一般會設定為 True 。
  • sampler:自定義從資料集中取樣本的策略,如果指定這個引數,那麼 shuffle 必須為 False 。從原始碼中可以看到,如果指定了該引數,同時 shuffle 設定為 True,DataLoader 的 __init__ 函式就會丟擲一個異常 。
  • batch_sampler:與 sampler 類似,但是一次只返回一個 batch 的 indices(索引),需要注意的是,一旦指定了這個引數,那麼 batch_size,shuffle,sampler,drop_last 就不能再指定了。原始碼中同樣做了限制。
  • num_workers:表示會使用多少個執行緒來載入訓練資料;預設值為 0,表示資料載入直接在主執行緒中進行。
  • collate_fn:對每一個 batch 的資料做一些你想要的操作。一個例子,https://zhuanlan.zhihu.com/p/346332974
  • pin_memory:把資料轉移到和 GPU 相關聯的 CPU 記憶體,加速 GPU 載入資料的速度。
  • drop_last:比如你的batch_size設定為 32,而一個 epoch 只有 100 個樣本;如果設定為 True,那麼訓練的時候後面的 4 個就被扔掉了。如果為 False(預設),那麼會繼續正常執行,只是最後的 batch_size 會小一點。
  • timeout:載入一個 batch 資料的超時時間。
  • worker_init_fn:指定每個資料載入執行緒的入口函式。

原始碼分析:

class DataLoader(object):
    def __init__(self, dataset, batch_size=1, shuffle=False, sampler=None, 
                 batch_sampler=None,
                 num_workers=0, collate_fn=default_collate, pin_memory=False, 
                 drop_last=False):
        self.dataset = dataset
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.collate_fn = collate_fn
        self.pin_memory = pin_memory
        self.drop_last = drop_last

        if batch_sampler is not None:
            if batch_size > 1 or shuffle or sampler is not None or drop_last:
                raise ValueError('batch_sampler is mutually exclusive with '
                                 'batch_size, shuffle, sampler, and drop_last')

        if sampler is not None and shuffle:
            raise ValueError('sampler is mutually exclusive with shuffle')

        if batch_sampler is None:
            if sampler is None:
                if shuffle:
                    # dataset.__len__() 在 Sampler 中被使用。
                    # 目的是生成一個 長度為 len(dataset) 的 序列索引(隨機的)。
                    sampler = RandomSampler(dataset)
                else:
                    # dataset.__len__() 在 Sampler 中被使用。
                    # 目的是生成一個 長度為 len(dataset) 的 序列索引(順序的)。
                    sampler = SequentialSampler(dataset)
            # Sampler 是個迭代器,一次之只返回一個 索引
            # BatchSampler 也是個迭代器,但是一次返回 batch_size 個 索引
            batch_sampler = BatchSampler(sampler, batch_size, drop_last)

        self.sampler = sampler
        self.batch_sampler = batch_sampler

    def __iter__(self):
        return DataLoaderIter(self)

    def __len__(self):
        return len(self.batch_sampler) 

可以發現__iter__返回的是DataLoaderIter

3 - DataLoaderIter

先看init初始化:

if self.num_workers > 0:
    self.worker_init_fn = loader.worker_init_fn
# 定義了workers相同數量個Queue並放置在index_queues這個list中, # 這些Queue與worker一一對應,用來給worker傳遞“工作內容” self.index_queues = [multiprocessing.Queue() for _ in range(self.num_workers)]
# worker_queue_idx用於下一個工作的workre序號,主程序輪詢使用不同workers self.worker_queue_idx = 0
# 各個workre將自己所取得的資料傳遞給wokrker_result_queue,供主程序fetch self.worker_result_queue = multiprocessing.SimpleQueue() # 記錄當前時刻分配了多少個任務(可能有處於等待狀態的任務) self.batches_outstanding = 0 self.worker_pids_set = False self.shutdown = False # 傳送出去資料的編號 self.send_idx = 0 # 接受到資料的編號 self.rcvd_idx = 0 # 快取區 self.reorder_dict = {} self.workers = [ multiprocessing.Process( target=_worker_loop, args=(self.dataset, self.index_queues[i], self.worker_result_queue, self.collate_fn, base_seed + i, self.worker_init_fn, i)) for i in range(self.num_workers)] # 初始化相應的程序,目標函式為_worker_loop # 引數:dataset(用於資料讀取),index_queues[i]為worker對應的index_queue # 以及用於輸出的queue # 此處主要用於資料讀取後的pin_memory操作,不影響多程序主邏輯,暫不展開 if self.pin_memory or self.timeout > 0: ... else: self.data_queue = self.worker_result_queue for w in self.workers: w.daemon = True # ensure that the worker exits on process exit # 將父程序設定為守護程序,保證父程序結束後,worker程序也結束,必須設定在start之前 w.start() # 下面是一些系統訊號處理邏輯,對這方面我還不太熟悉就不介紹了。 _update_worker_pids(id(self), tuple(w.pid for w in self.workers)) _set_SIGCHLD_handler() self.worker_pids_set = True # 初始化後生成2*num_workers數量個prefetch的資料,使dataloader提前工作,提升整體效率。 # prime the prefetch loop for _ in range(2 * self.num_workers): self._put_indices()

init過程有兩個函式,一個是worker_loop,另個是put_indices

a. 先看worker_loop:

def _worker_loop(dataset, index_queue, data_queue, collate_fn, seed, init_fn, worker_id):
    global _use_shared_memory
    _use_shared_memory = True

    # Intialize C side signal handlers for SIGBUS and SIGSEGV. Python signal
    # module's handlers are executed after Python returns from C low-level
    # handlers, likely when the same fatal signal happened again already.
    # https://docs.python.org/3/library/signal.html Sec. 18.8.1.1
    _set_worker_signal_handlers()

    torch.set_num_threads(1)
    random.seed(seed)
    torch.manual_seed(seed)

    if init_fn is not None:
        init_fn(worker_id)
    
    # 父程序狀態監測
    watchdog = ManagerWatchdog()
    
    # 死迴圈查詢是否有任務傳進來
    while True:
        try:
            # 從index_queue獲取相應資料
            r = index_queue.get(timeout=MANAGER_STATUS_CHECK_INTERVAL)
        except queue.Empty:
            if watchdog.is_alive():
                continue
            else:
                break
        if r is None:
            break
        idx, batch_indices = r
        try:
            # 獲得以後for迴圈進行讀取資料讀取,此處和單程序的工作原理是一樣的
            # 因此時間花費和batchsize數量呈線性關係
            samples = collate_fn([dataset[i] for i in batch_indices])
            # 經過collate_fn後變成torch.Tensor
        except Exception:
            # 異常處理
            data_queue.put((idx, ExceptionWrapper(sys.exc_info())))
        else:
            # 通過data_queue傳回處理好的batch資料
            data_queue.put((idx, samples))
            # 顯示刪除中間變數,降低記憶體消耗
            del samples

這裡就是不停地輪詢,從index_queues佇列裡獲得索引,然後通過collate_fn函式和索引獲取tensor,然後塞入data_queue

b. 再看put_indices

def _put_indices(self):
    assert self.batches_outstanding < 2 * self.num_workers
    # 預設設定是隻允許分配2*num_workers個任務,保證記憶體等資源不被耗盡
    indices = next(self.sample_iter, None)
    # 從sample_iter中拿到dataset中下一輪次的索引,用於fetch資料
    if indices is None:
        return
    self.index_queues[self.worker_queue_idx].put((self.send_idx, indices))
    # 輪詢選擇worker,找到其對應的佇列,向其中傳送工作內容(資料編號,資料索引)
    self.worker_queue_idx = (self.worker_queue_idx + 1) % self.num_workers
    # worker_queue_idx自增
    self.batches_outstanding += 1
    # 任務分配數+1
    self.send_idx += 1
    # 已傳送任務總數+1(下批資料編號) 

這個就是把索引塞進佇列index_queues

以上就是init,當for迴圈時,會呼叫next:

c. __next__返回一個batch

def __next__(self):
    if self.num_workers == 0:  # same-process loading  (主程序阻塞式讀取資料)
        indices = next(self.sample_iter)  # may raise StopIteration
        batch = self.collate_fn([self.dataset[i] for i in indices])
        if self.pin_memory:
            batch = pin_memory_batch(batch)
        return batch
    
    # check if the next sample has already been generated
    # 先檢視資料是否在快取dict中
    if self.rcvd_idx in self.reorder_dict:
        batch = self.reorder_dict.pop(self.rcvd_idx)
        return self._process_next_batch(batch)
    # 異常處理
    if self.batches_outstanding == 0:
        self._shutdown_workers()
        raise StopIteration
    while True:
        assert (not self.shutdown and self.batches_outstanding > 0)
        # 阻塞式的從data_queue裡面獲取處理好的批資料
        idx, batch = self._get_batch() 
        # 任務數減一
        self.batches_outstanding -= 1
        # 這一步可能會造成的週期阻塞現象
        # 每次獲取data以後,要校驗和rcvd_idx是否一致
        # 若不一致,則先把獲取到的資料放到reorder_dict這個快取dict中,繼續死迴圈
        # 直到獲取到相應的idx編號於rcvd_idx可以對應上,並將資料返回
        if idx != self.rcvd_idx:
            # store out-of-order samples
            self.reorder_dict[idx] = batch
            continue
        return self._process_next_batch(batch)

__next__裡的while True,要從data_queue裡面讀到的資料idx和rcvd_idx一致才將資料返回。因此可能會存在如下這種情況:

假設num_workers=8,現在傳送了8個數據給相應的worker,此時send_idx=8,rcvd_idx=0。過了一段時間以後,{1,2,3,5,6,7}程序資料準備完畢,此時主程序從data_queue讀取到相關的資料,但由於和rcvd_idx不匹配,只能將其放在快取裡。直到send_idx=0資料準備齊以後,才能將資料返回出去,隨後從快取中彈出2,3的資料,之後又阻塞等待idx=4的資料。即輸出的資料必須保持順序性!因此在worker變多,出現這種逆序現象可能性會更大,這種現象也會出現在非num_workrers次迭代,只要相應的rcvd_idx沒有得到相關資料,則主程序就會一直等待。

d.process_next_batch

def _process_next_batch(self, batch):
    # 序號對上以後,rcvd_idx自加1
    self.rcvd_idx += 1
    # 新增一個fetchdata任務給worker
    self._put_indices()
    if isinstance(batch, ExceptionWrapper):
        raise batch.exc_type(batch.exc_msg)
    return batch

  

這個函式注意的是,只有在__next__中,idx == self.rcvd_idx時才會呼叫,也就是可能出現多個worker已經準備好了,但是隻能放在快取區,並且無法向index_queues塞入索引,使worker無法保持活躍狀態。

最後對於for迴圈從dataloader獲取data總體流程:

for epoch in range(num_epoches):
    for data in dataloader:

對於這個for,其實就是呼叫了dataloader 的__iter__() 方法, 產生了一個DataLoaderIter,如果是num_worker>0,init裡就會建立多執行緒,並且有兩個佇列,一個是存放dataset的索引index_queues,一個是從index_queues裡拿到索引,呼叫dataset的__getitem__()方法 (如果num_worker>0就多執行緒呼叫), 然後用collate_fn來把它們打包成batch,放到data_queue佇列裡,反覆呼叫DataLoaderIter 的__next__,從data_queue中獲取batch。

參考:

Pytorch資料讀取(Dataset, DataLoader, DataLoaderIter) https://zhuanlan.zhihu.com/p/30934236

PyTorch 之 Dataset 和 Dataloaderhttps://zhuanlan.zhihu.com/p/339675188

PyTorch36.DataLoader原始碼剖析https://zhuanlan.zhihu.com/p/169497395

PyTorch DataLoader初探https://zhuanlan.zhihu.com/p/91521705

一文弄懂Pytorch的DataLoader, DataSet, Sampler之間的關係 https://zhuanlan.zhihu.com/p/76893455