Torch的Dataloader類原始碼以及簡單解析
阿新 • • 發佈:2021-08-04
import torch
import torch.multiprocessing as multiprocessing
from . import SequentialSampler, RandomSampler, BatchSampler
from . import _utils
import threading
from torch._six import queue

default_collate = _utils.collate.default_collate


class DataLoader(object):
    """Combine a dataset with a (batch) sampler and iterate over it in batches.

    Args:
        dataset: object indexed as ``dataset[i]`` for every index the sampler
            yields.
        batch_size: samples per batch. Must stay at its default (1) when
            ``batch_sampler`` is given; the attribute is then reset to None.
        shuffle: when no ``sampler`` is passed, use ``RandomSampler`` instead
            of ``SequentialSampler``. Mutually exclusive with ``sampler``.
        sampler: custom index sampler.
        batch_sampler: yields whole batches of indices; mutually exclusive
            with ``batch_size``, ``shuffle``, ``sampler`` and ``drop_last``.
        num_workers: number of loader worker processes; 0 loads in-process.
        collate_fn: merges a list of samples into one batch.
        pin_memory: stored and read by the iterator.
        drop_last: forwarded to ``BatchSampler``.
        timeout: how long (seconds) the iterator waits for a batch; >= 0.
        worker_init_fn: stored and handed to the iterator's workers.

    Raises:
        ValueError: negative ``timeout`` / ``num_workers``, or mutually
            exclusive options combined.
    """

    # Flipped to True at the very end of __init__; __setattr__ consults it.
    __initialized = False

    def __init__(self, dataset, batch_size=1, shuffle=False, sampler=None,
                 batch_sampler=None, num_workers=0, collate_fn=default_collate,
                 pin_memory=False, drop_last=False, timeout=0,
                 worker_init_fn=None):
        self.dataset = dataset
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.collate_fn = collate_fn
        self.pin_memory = pin_memory
        self.drop_last = drop_last
        self.timeout = timeout
        self.worker_init_fn = worker_init_fn

        if timeout < 0:
            raise ValueError('timeout option should be non-negative')

        if batch_sampler is not None:
            # A batch_sampler decides batching on its own, so none of the
            # per-sample batching options may be customised alongside it.
            if (batch_size > 1 or shuffle or sampler is not None
                    or drop_last):
                msg = ('batch_sampler option is mutually exclusive '
                       'with batch_size, shuffle, sampler, and drop_last')
                raise ValueError(msg)
            self.batch_size = None
            self.drop_last = None

        # An explicit sampler already fixes the order; shuffle can't apply.
        if sampler is not None and shuffle:
            raise ValueError('sampler option is mutually exclusive with '
                             'shuffle')

        if num_workers < 0:
            raise ValueError('num_workers option cannot be negative; '
                             'use num_workers=0 to disable multiprocessing.')

        if batch_sampler is None:
            if sampler is None:
                sampler_cls = RandomSampler if shuffle else SequentialSampler
                sampler = sampler_cls(dataset)
            batch_sampler = BatchSampler(sampler, batch_size, drop_last)

        self.sampler = sampler
        self.batch_sampler = batch_sampler
        # From here on, __setattr__ refuses to rebind the frozen attributes.
        self.__initialized = True

    def __setattr__(self, attr, val):
        """Set ``attr``, forbidding batch_size/sampler/drop_last after init."""
        frozen = self.__initialized and attr in ('batch_size', 'sampler',
                                                 'drop_last')
        if frozen:
            raise ValueError('{} attribute should not be set after {} is '
                             'initialized'.format(attr, self.__class__.__name__))
        super(DataLoader, self).__setattr__(attr, val)

    def __iter__(self):
        """Return a new ``_DataLoaderIter`` (one pass over the data)."""
        return _DataLoaderIter(self)

    def __len__(self):
        """Number of batches per epoch, i.e. ``len(self.batch_sampler)``."""
        return len(self.batch_sampler)
使用方法大致如下:
for i, (input, target) in enumerate(train_data):
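其中最重要的是 _DataLoaderIter 這個類:DataLoader 本身只負責檢查參數並建立 sampler,它的 `__iter__()` 每次都返回一個新的 _DataLoaderIter,真正的取資料邏輯都在 _DataLoaderIter 裡。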
簡單來講,以下幾點比較重要,也比較不容易理解:
- 同時實現了 `__iter__()` 和 `__next__()` 的類就是一個迭代器。`__iter__()` 返回迭代器物件本身(_DataLoaderIter 中就是 `return self`),`__next__()` 每次返回下一個 batch,資料取完時拋出 StopIteration。
- Queue 使用時,若佇列為空,queue.get() 會阻塞;阻塞期間如果其他程序/執行緒執行了 put 操作,本執行緒就會被喚醒並成功取得資料。若佇列設有容量上限且已滿,queue.put() 會阻塞。帶 timeout 參數的 get() 在逾時後會拋出 queue.Empty,_get_batch 和 _worker_loop 正是利用這一點定期檢查對方是否還活著。
- 不使用多程序(num_workers == 0)時,直接在主程序中執行 `batch = self.collate_fn([self.dataset[i] for i in indices])`,把 sampler 給出的 index 轉換成資料,例如 (image, label)。
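- 使用多程序(num_workers > 0)時,為每一個 worker 程序各建立一個 index_queue(全部存放在 self.index_queues 列表中),所有 worker 共用同一個 worker_result_queue 結果佇列。每個 worker 在 _worker_loop 中根據收到的 index 載入資料,再把 (idx, samples) 放進 worker_result_queue。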
class _DataLoaderIter(object):
    """Iterates once over the DataLoader's dataset, as specified by the sampler.

    With ``num_workers == 0`` each batch is loaded synchronously inside
    ``__next__``. Otherwise ``__init__`` starts ``num_workers`` worker
    processes; batches of indices are dispatched to them round-robin through
    per-worker index queues, and results are collected from a shared result
    queue and handed out in the sampler's original order.
    """

    # NOTE [ Data Loader Multiprocessing Shutdown Logic ]
    # Our data model looks like this (queues are indicated with curly brackets):
    #
    #                main process                              ||
    #                     |                                    ||
    #               {index_queue}                              ||
    #                     |                                    ||
    #              worker processes                            ||     DATA
    #                     |                                    ||
    #            {worker_result_queue}                         ||     FLOW
    #                     |                                    ||
    #      pin_memory_thread of main process                   ||   DIRECTION
    #                     |                                    ||
    #               {data_queue}                               ||
    #                     |                                    ||
    #                data output                               \/
    #
    # When pin_memory is off, data_queue is the very same object as
    # worker_result_queue and the pin_memory_thread stage does not exist.

    def __init__(self, loader):
        # Snapshot the loader's configuration.
        self.dataset = loader.dataset
        self.collate_fn = loader.collate_fn
        self.batch_sampler = loader.batch_sampler
        self.num_workers = loader.num_workers
        # Pinning is silently turned off when CUDA is not available.
        self.pin_memory = loader.pin_memory and torch.cuda.is_available()
        self.timeout = loader.timeout

        # Each item drawn from sample_iter is an iterable of sample indices
        # making up one batch.
        self.sample_iter = iter(self.batch_sampler)

        # Worker i receives seed base_seed + i, so every worker gets a
        # distinct seed.
        base_seed = torch.LongTensor(1).random_().item()

        if self.num_workers > 0:
            self.worker_init_fn = loader.worker_init_fn
            # Index of the worker that receives the next batch (round-robin).
            self.worker_queue_idx = 0
            # Shared by all workers: they put (idx, samples) pairs here.
            self.worker_result_queue = multiprocessing.Queue()
            # Number of index batches sent to workers but not yet received.
            self.batches_outstanding = 0
            self.worker_pids_set = False
            self.shutdown = False
            # send_idx tags each batch sent out; rcvd_idx is the tag of the
            # next batch __next__ must return, preserving sampler order.
            self.send_idx = 0
            self.rcvd_idx = 0
            # Results that arrived ahead of rcvd_idx, keyed by their tag.
            self.reorder_dict = {}
            # Set in _shutdown_workers; passed to the workers and the
            # pin_memory thread.
            self.done_event = multiprocessing.Event()

            self.index_queues = []
            self.workers = []
            # Start num_workers worker processes, each with its own
            # index_queue.
            for i in range(self.num_workers):
                index_queue = multiprocessing.Queue()
                index_queue.cancel_join_thread()
                w = multiprocessing.Process(
                    # Every worker process runs _worker_loop.
                    target=_utils.worker._worker_loop,
                    # The worker puts (idx, samples) into the shared
                    # worker_result_queue. idx is the send_idx tag attached in
                    # _put_indices, not the sample indices of the batch.
                    args=(self.dataset, index_queue,
                          self.worker_result_queue, self.done_event,
                          self.collate_fn, base_seed + i,
                          self.worker_init_fn, i))
                w.daemon = True
                # NB: Process.start() — a worker is only appended to
                # self.workers after start() has returned, so the join() loop
                # in _shutdown_workers only ever sees processes that were
                # actually started.
                w.start()
                self.index_queues.append(index_queue)
                self.workers.append(w)

            if self.pin_memory:
                # A background thread runs _utils.pin_memory._pin_memory_loop,
                # reading from worker_result_queue and writing to a
                # thread-local queue.Queue (presumably copying batches into
                # pinned host memory — see _utils.pin_memory).
                self.data_queue = queue.Queue()
                pin_memory_thread = threading.Thread(
                    target=_utils.pin_memory._pin_memory_loop,
                    args=(self.worker_result_queue, self.data_queue,
                          torch.cuda.current_device(), self.done_event))
                pin_memory_thread.daemon = True
                pin_memory_thread.start()
                # Similar to workers (see comment above), we only register
                # pin_memory_thread once it is started.
                self.pin_memory_thread = pin_memory_thread
            else:
                # No pinning stage: read worker results directly.
                self.data_queue = self.worker_result_queue

            # Register the worker pids (keyed by id(self)) with
            # _utils.signal_handling and install its SIGCHLD handler;
            # presumably so an unexpectedly dying worker is reported instead
            # of the main process hanging — see _utils.signal_handling.
            # worker_pids_set tells _shutdown_workers to unregister them.
            _utils.signal_handling._set_worker_pids(id(self), tuple(w.pid for w in self.workers))
            _utils.signal_handling._set_SIGCHLD_handler()
            self.worker_pids_set = True

            # Prime the prefetch loop: put 2 batches in flight per worker.
            # _put_indices asserts batches_outstanding < 2 * num_workers, so
            # this is also the maximum prefetch depth.
            for _ in range(2 * self.num_workers):
                self._put_indices()

    def __len__(self):
        """Number of batches in one pass (same as the batch sampler)."""
        return len(self.batch_sampler)

    def _get_batch(self):
        """Return the next ``(idx, batch)`` item from ``data_queue``.

        Raises:
            RuntimeError: if ``timeout`` > 0 and nothing arrives in time, or
                if the pin_memory thread has died.
        """
        if self.timeout > 0:
            try:
                return self.data_queue.get(timeout=self.timeout)
            except queue.Empty:
                raise RuntimeError('DataLoader timed out after {} seconds'.format(self.timeout))
        elif self.pin_memory:
            # Poll in MP_STATUS_CHECK_INTERVAL steps instead of blocking
            # forever, so a dead pin_memory_thread is noticed.
            while self.pin_memory_thread.is_alive():
                try:
                    return self.data_queue.get(timeout=_utils.MP_STATUS_CHECK_INTERVAL)
                except queue.Empty:
                    continue
            else:
                # while condition is false, i.e., pin_memory_thread died.
                raise RuntimeError('Pin memory thread exited unexpectedly')
        else:
            # Plain blocking get; worker death in this case is presumably
            # surfaced through the SIGCHLD handler installed in __init__.
            return self.data_queue.get()

    def __next__(self):
        """Return the next batch in sampler order; raise StopIteration at the end."""
        if self.num_workers == 0:  # same-process loading
            indices = next(self.sample_iter)  # may raise StopIteration
            batch = self.collate_fn([self.dataset[i] for i in indices])
            if self.pin_memory:
                batch = _utils.pin_memory.pin_memory_batch(batch)
            return batch

        # check if the next sample has already been generated
        # (it may have arrived out of order and been parked in reorder_dict).
        if self.rcvd_idx in self.reorder_dict:
            batch = self.reorder_dict.pop(self.rcvd_idx)
            return self._process_next_batch(batch)

        # Nothing in flight and nothing parked: the pass is over, so shut the
        # workers down and stop.
        if self.batches_outstanding == 0:
            self._shutdown_workers()
            raise StopIteration

        while True:
            assert (not self.shutdown and self.batches_outstanding > 0)
            idx, batch = self._get_batch()
            self.batches_outstanding -= 1
            if idx != self.rcvd_idx:
                # Batches must be returned strictly in rcvd_idx order.
                # store out-of-order samples
                self.reorder_dict[idx] = batch
                continue
            return self._process_next_batch(batch)

    next = __next__  # Python 2 compatibility

    def __iter__(self):
        return self

    def _put_indices(self):
        """Send the next batch of indices to the next worker, round-robin.

        The batch is tagged with ``send_idx`` so its result can be put back in
        order. This is a no-op once the sampler is exhausted.
        """
        # Never more than 2 * num_workers batches in flight.
        assert self.batches_outstanding < 2 * self.num_workers
        indices = next(self.sample_iter, None)
        if indices is None:
            return
        self.index_queues[self.worker_queue_idx].put((self.send_idx, indices))
        self.worker_queue_idx = (self.worker_queue_idx + 1) % self.num_workers
        self.batches_outstanding += 1
        self.send_idx += 1

    def _process_next_batch(self, batch):
        """Advance ``rcvd_idx``, top up the pipeline and unwrap ``batch``.

        Raises:
            The exception type recorded in an ``ExceptionWrapper`` a worker
            produced. One more index batch has already been dispatched before
            the raise.
        """
        self.rcvd_idx += 1
        self._put_indices()
        if isinstance(batch, _utils.ExceptionWrapper):
            raise batch.exc_type(batch.exc_msg)
        return batch

    def __getstate__(self):
        """
        TODO: add limited pickling support so an iterator can be shared across
        multiple threads for HOGWILD. The best approach is probably to push
        samples from a separate thread and share only the data queue, but
        signalling the end is hard without a non-blocking API.
        """
        raise NotImplementedError("_DataLoaderIter cannot be pickled")

    def _shutdown_workers(self):
        # See NOTE [ Data Loader Multiprocessing Shutdown Logic ] for details on
        # the logic of this function. (The numbered points (1)/(2) referenced
        # below belong to the full note in the upstream PyTorch source.)
        python_exit_status = _utils.python_exit_status
        if python_exit_status is True or python_exit_status is None:
            # See (2) of the note. If Python is shutting down, do no-op.
            return
        # Normal exit when last reference is gone / iterator is depleted.
        # See (1) and the second half of the note.
        if not self.shutdown:
            self.shutdown = True
            # Removes pids from the C side data structure first so worker
            # termination afterwards won't trigger false positive error report.
            if self.worker_pids_set:
                _utils.signal_handling._remove_worker_pids(id(self))
                self.worker_pids_set = False

            self.done_event.set()

            # Exit `pin_memory_thread` first because exiting workers may leave
            # corrupted data in `worker_result_queue` which `pin_memory_thread`
            # reads from.
            if hasattr(self, 'pin_memory_thread'):
                # hasattr: the attribute only exists once the thread started.
                # Put a None sentinel on worker_result_queue, then join the
                # thread before closing the queue.
                self.worker_result_queue.cancel_join_thread()
                self.worker_result_queue.put(None)
                self.pin_memory_thread.join()
                self.worker_result_queue.close()

            # Exit workers now: a None on each index_queue is the stop
            # sentinel _worker_loop breaks on.
            for q in self.index_queues:
                q.put(None)
                # Indicate that no more data will be put on this queue by the
                # current process.
                q.close()
            for w in self.workers:
                w.join()

    def __del__(self):
        # Only the multi-process path owns workers/queues that need cleanup.
        if self.num_workers > 0:
            self._shutdown_workers()
def _worker_loop(dataset, index_queue, data_queue, collate_fn, seed, init_fn, worker_id):
    """Main loop of one DataLoader worker process.

    Repeatedly takes ``(idx, batch_indices)`` from ``index_queue``, loads
    ``dataset[i]`` for every index, merges them with ``collate_fn`` and puts
    ``(idx, samples)`` on ``data_queue``. If loading raises, it puts
    ``(idx, ExceptionWrapper(sys.exc_info()))`` instead so the parent can
    re-raise. It returns on a ``None`` sentinel, or when the watchdog reports
    dead after a polling timeout.

    NOTE(review): the _DataLoaderIter in this file starts workers with 8
    positional args (including a ``done_event`` after ``data_queue``), while
    this signature takes 7. This copy looks like it comes from an older
    PyTorch release — confirm which version is meant.
    """
    global _use_shared_memory
    # Module-level flag; presumably read by the collate code to place tensors
    # in shared memory — confirm in this module.
    _use_shared_memory = True

    # Initialize C side signal handlers for SIGBUS and SIGSEGV. Python signal
    # module's handlers are executed after Python returns from C low-level
    # handlers, likely when the same fatal signal happened again already.
    # https://docs.python.org/3/library/signal.html Sec. 18.8.1.1
    _set_worker_signal_handlers()

    torch.set_num_threads(1)
    # Seed both the stdlib and torch RNGs from the per-worker seed.
    random.seed(seed)
    torch.manual_seed(seed)
    # User hook, called after seeding.
    if init_fn is not None:
        init_fn(worker_id)

    watchdog = ManagerWatchdog()
    while True:
        try:
            task = index_queue.get(timeout=MANAGER_STATUS_CHECK_INTERVAL)
        except queue.Empty:
            # Nothing arrived within the polling interval: keep waiting only
            # while the watchdog still reports alive.
            if not watchdog.is_alive():
                break
            continue
        if task is None:
            # Shutdown sentinel.
            break

        batch_idx, sample_indices = task
        try:
            samples = collate_fn([dataset[j] for j in sample_indices])
        except Exception:
            # Ship the failure back tagged with its batch index.
            data_queue.put((batch_idx, ExceptionWrapper(sys.exc_info())))
        else:
            data_queue.put((batch_idx, samples))
            del samples