pytorch學習筆記(十四): DataLoader原始碼閱讀
pytorch
資料載入部分的 介面可以說是現存 深度學習框架中設計的最好的, 給了我們足夠的靈活性。本博文就對 pytorch 的多執行緒載入 模組(DataLoader
) 進行原始碼上的註釋。
輸入流水線
pytorch
的輸入流水線的操作順序是這樣的:
- 建立一個 Dataset 物件
- 建立一個 DataLoader 物件
- 不停的 迴圈 這個 DataLoader 物件
dataset = MyDataset()
dataloader = DataLoader(dataset)
num_epoches = 100
for epoch in range(num_epoches):
for data in dataloader:
....
在之前文章也提到過,如果現有的 Dataset
不能夠滿足需求,我們也可以自定義 Dataset
,通過繼承 torch.utils.data.Dataset
。在繼承的時候,需要 override
三個方法。
__init__
: 用來初始化資料集__getitem__
__len__
從本文中,您可以看到 __getitem__
和 __len__
在 DataLoader
中是如何被使用的。
DataLoader
從DataLoader
看起,下面是原始碼。為了方便起見,採用在原始碼中添加註釋的形式進行解讀。
class DataLoader(object):
def __init__(self, dataset, batch_size=1, shuffle=False, sampler=None,
batch_sampler=None,
num_workers=0, collate_fn=default_collate, pin_memory=False,
drop_last=False):
self.dataset = dataset
self.batch_size = batch_size
self.num_workers = num_workers
self.collate_fn = collate_fn
self.pin_memory = pin_memory
self.drop_last = drop_last
if batch_sampler is not None:
if batch_size > 1 or shuffle or sampler is not None or drop_last:
raise ValueError('batch_sampler is mutually exclusive with '
'batch_size, shuffle, sampler, and drop_last')
if sampler is not None and shuffle:
raise ValueError('sampler is mutually exclusive with shuffle')
if batch_sampler is None:
if sampler is None:
if shuffle:
# dataset.__len__() 在 Sampler 中被使用。
# 目的是生成一個 長度為 len(dataset) 的 序列索引(隨機的)。
sampler = RandomSampler(dataset)
else:
# dataset.__len__() 在 Sampler 中被使用。
# 目的是生成一個 長度為 len(dataset) 的 序列索引(順序的)。
sampler = SequentialSampler(dataset)
# Sampler 是個迭代器,一次之只返回一個 索引
# BatchSampler 也是個迭代器,但是一次返回 batch_size 個 索引
batch_sampler = BatchSampler(sampler, batch_size, drop_last)
self.sampler = sampler
self.batch_sampler = batch_sampler
def __iter__(self):
return DataLoaderIter(self)
def __len__(self):
return len(self.batch_sampler)
# 以下兩個程式碼是等價的
for data in dataloader:
...
# 等價與
iterr = iter(dataloader)
while True:
try:
next(iterr)
except StopIteration:
break
在 DataLoader
中,iter(dataloader)
返回的是一個 DataLoaderIter
物件, 這個才是我們一直 next
的 物件。
下面會先介紹一下 幾個 Sampler
, 然後介紹 核心部分 DataLoaderIter
。
RandomSampler, SequentialSampler, BatchSampler
首先,是 RandomSampler
, iter(randomSampler)
會返回一個可迭代物件,這個可迭代物件 每次 next
都會輸出當前要取樣的 index
,SequentialSampler
也是一樣,只不過她產生的 index
是順序的
class RandomSampler(Sampler):
def __init__(self, data_source):
self.data_source = data_source
def __iter__(self):
return iter(torch.randperm(len(self.data_source)).long())
def __len__(self):
return len(self.data_source)
BatchSampler
是一個普通 Sampler
的 wrapper
, 普通Sampler
一次僅產生一個 index
, 而 BatchSampler
一次產生一個 batch
的 indices
。
class BatchSampler(object):
def __init__(self, sampler, batch_size, drop_last):
# 這裡的 sampler 是 RandomSampler 或者 SequentialSampler
# 他們每一次吐出一個 idx
self.sampler = sampler
self.batch_size = batch_size
self.drop_last = drop_last
def __iter__(self):
batch = []
for idx in self.sampler:
batch.append(idx)
if len(batch) == self.batch_size:
yield batch
batch = []
if len(batch) > 0 and not self.drop_last:
yield batch
def __len__(self):
if self.drop_last:
return len(self.sampler) // self.batch_size
else:
return (len(self.sampler) + self.batch_size - 1) // self.batch_size
DataLoaderIter
self.index_queue
中存放是(batch_idx, sample_indices)
,其中batch_idx
是個int
值,sample_indices
是個list
, 存放了 組成batch
的sample indices
。self.data_queue
中存放的是(batch_idx, samples)
, 其中samples
是 一個mini-batch
的樣本self.send_idx
表示:這次 放到self.index_queue
中的batch_id
self.rcvd_idx
表示:這次要取的batch_id
self.batches_outstanding
表示:
class DataLoaderIter(object):
"Iterates once over the DataLoader's dataset, as specified by the sampler"
def __init__(self, loader):
# loader 是 DataLoader 物件
self.dataset = loader.dataset
# 這個留在最後一個部分介紹
self.collate_fn = loader.collate_fn
self.batch_sampler = loader.batch_sampler
# 表示 開 幾個程序。
self.num_workers = loader.num_workers
# 是否使用 pin_memory
self.pin_memory = loader.pin_memory
self.done_event = threading.Event()
# 這樣就可以用 next 操作 batch_sampler 了
self.sample_iter = iter(self.batch_sampler)
if self.num_workers > 0:
# 用來放置 batch_idx 的佇列,其中元素的是 一個 list,其中放了一個 batch 內樣本的索引
self.index_queue = multiprocessing.SimpleQueue()
# 用來放置 batch_data 的佇列,裡面的 元素的 一個 batch的 資料
self.data_queue = multiprocessing.SimpleQueue()
# 當前已經準備好的 batch 的數量(可能有些正在 準備中)
# 當為 0 時, 說明, dataset 中已經沒有剩餘資料了。
# 初始值為 0, 在 self._put_indices() 中 +1,在 self.__next__ 中減一
self.batches_outstanding = 0
self.shutdown = False
# 用來記錄 這次要放到 index_queue 中 batch 的 idx
self.send_idx = 0
# 用來記錄 這次要從的 data_queue 中取出 的 batch 的 idx
self.rcvd_idx = 0
# 因為多執行緒,可能會導致 data_queue 中的 batch 亂序
# 用這個來保證 batch 的返回 是 idx 升序出去的。
self.reorder_dict = {}
# 這個地方就開始 開多程序了,一共開了 num_workers 個程序
# 執行 _worker_loop , 下面將介紹 _worker_loop
self.workers = [
multiprocessing.Process(
target=_worker_loop,
args=(self.dataset, self.index_queue, self.data_queue, self.collate_fn))
for _ in range(self.num_workers)]
for w in self.workers:
w.daemon = True # ensure that the worker exits on process exit
w.start()
if self.pin_memory:
in_data = self.data_queue
self.data_queue = queue.Queue()
self.pin_thread = threading.Thread(
target=_pin_memory_loop,
args=(in_data, self.data_queue, self.done_event))
self.pin_thread.daemon = True
self.pin_thread.start()
# prime the prefetch loop
# 初始化的時候,就將 2*num_workers 個 (batch_idx, sampler_indices) 放到 index_queue 中。
for _ in range(2 * self.num_workers):
self._put_indices()
def __len__(self):
return len(self.batch_sampler)
def __next__(self):
if self.num_workers == 0: # same-process loading
indices = next(self.sample_iter) # may raise StopIteration
batch = self.collate_fn([self.dataset[i] for i in indices])
if self.pin_memory:
batch = pin_memory_batch(batch)
return batch
# check if the next sample has already been generated
if self.rcvd_idx in self.reorder_dict:
batch = self.reorder_dict.pop(self.rcvd_idx)
return self._process_next_batch(batch)
if self.batches_outstanding == 0:
# 說明沒有 剩餘 可操作資料了, 可以停止 worker 了
self._shutdown_workers()
raise StopIteration
while True:
# 這裡的操作就是 給 亂序的 data_queue 排一排 序
assert (not self.shutdown and self.batches_outstanding > 0)
idx, batch = self.data_queue.get()
# 一個 batch 被 返回,batches_outstanding -1
self.batches_outstanding -= 1
if idx != self.rcvd_idx:
# store out-of-order samples
self.reorder_dict[idx] = batch
continue
# 返回的時候,再向 indice_queue 中 放下一個 (batch_idx, sample_indices)
return self._process_next_batch(batch)
next = __next__ # Python 2 compatibility
def __iter__(self):
return self
def _put_indices(self):
assert self.batches_outstanding < 2 * self.num_workers
indices = next(self.sample_iter, None)
if indices is None:
return
self.index_queue.put((self.send_idx, indices))
self.batches_outstanding += 1
self.send_idx += 1
def _process_next_batch(self, batch):
self.rcvd_idx += 1
# 放下一個 (batch_idx, sample_indices)
self._put_indices()
if isinstance(batch, ExceptionWrapper):
raise batch.exc_type(batch.exc_msg)
return batch
def __getstate__(self):
# TODO: add limited pickling support for sharing an iterator
# across multiple threads for HOGWILD.
# Probably the best way to do this is by moving the sample pushing
# to a separate thread and then just sharing the data queue
# but signalling the end is tricky without a non-blocking API
raise NotImplementedError("DataLoaderIterator cannot be pickled")
def _shutdown_workers(self):
if not self.shutdown:
self.shutdown = True
self.done_event.set()
for _ in self.workers:
# shutdown 的時候, 會將一個 None 放到 index_queue 中
# 如果 _worker_loop 獲得了這個 None, _worker_loop 將會跳出無限迴圈,將會結束執行
self.index_queue.put(None)
def __del__(self):
if self.num_workers > 0:
self._shutdown_workers()
__worker_loop
這部分是 多程序 執行的程式碼:他從index_queue
中 取索引,然後處理資料,然後再將 處理好的 batch 資料放到 data_queue
中。
def _worker_loop(dataset, index_queue, data_queue, collate_fn):
global _use_shared_memory
_use_shared_memory = True
torch.set_num_threads(1)
while True:
r = index_queue.get()
if r is None:
# 想 data_queue 中放 None
data_queue.put(None)
break
idx, batch_indices = r
try:
# 這裡就可以看到 dataset.__getiterm__ 的作用了。
# 傳到 collate_fn 的資料是 list of ...
samples = collate_fn([dataset[i] for i in batch_indices])
except Exception:
data_queue.put((idx, ExceptionWrapper(sys.exc_info())))
else:
data_queue.put((idx, samples))
collate_fn
我們
__getiterm__
經常返回的是 (img_tensor, label),所以 放入
collate_fn
的 引數就是[(img_tensor, label), ....]
.batch[0]
就是(img_tensor, label)
, 也就是collections.Sequence
型別。
def default_collate(batch):
"Puts each data field into a tensor with outer dimension batch size"
if torch.is_tensor(batch[0]):
out = None
if _use_shared_memory:
# If we're in a background process, concatenate directly into a
# shared memory tensor to avoid an extra copy
# 計算 batch 中所有 元素的個數
numel = sum([x.numel() for x in batch])
# 沒有找到對應的 api 。。。。。。
storage = batch[0].storage()._new_shared(numel)
out = batch[0].new(storage)
return torch.stack(batch, 0, out=out)
elif type(batch[0]).__module__ == 'numpy':
elem = batch[0]
if type(elem).__name__ == 'ndarray':
return torch.stack([torch.from_numpy(b) for b in batch], 0)
if elem.shape == (): # scalars
py_type = float if elem.dtype.name.startswith('float') else int
return numpy_type_map[elem.dtype.name](list(map(py_type, batch)))
elif isinstance(batch[0], int):
return torch.LongTensor(batch)
elif isinstance(batch[0], float):
return torch.DoubleTensor(batch)
elif isinstance(batch[0], string_classes):
return batch
elif isinstance(batch[0], collections.Mapping):
return {key: default_collate([d[key] for d in batch]) for key in batch[0]}
elif isinstance(batch[0], collections.Sequence):
transposed = zip(*batch)
return [default_collate(samples) for samples in transposed]
raise TypeError(("batch must contain tensors, numbers, dicts or lists; found {}"
.format(type(batch[0]))))
對於 image captioning
任務,既有圖片,又有文字,pytorch 官方開源的工具箱 torchtext
使得文字資料的處理非常簡單,可以通過自定義 collate_fn
的方式將 DataLoader
與 torchtext
完美的整合起來。
總結
data_queue
中最多有2*num_worker
個batch
-
注意
Queue的特點
- 當裡面沒有資料時:
queue.get()
會阻塞, 阻塞的時候,其它 程序/執行緒 如果有queue.put()
操作,本 執行緒/程序 會被通知, 然後就可以get
成功。 - 當資料滿了:
queue.put()
會阻塞
相關推薦
pytorch學習筆記(十四): DataLoader原始碼閱讀
pytorch 資料載入部分的 介面可以說是現存 深度學習框架中設計的最好的, 給了我們足夠的靈活性。本博文就對 pytorch 的多執行緒載入 模組(DataLoader) 進行原始碼上的註釋。 輸入流水線 pytorch 的輸入流水線的操作順序是這
Java框架spring Boot學習筆記(十四):log4j介紹
inf alt 技術分享 images 使用 image 詳細 配置文件 -128 功能 日誌功能,通過log4j可以看到程序運行過程的詳細信息。 使用 導入log4j的jar包 復制log4j的配置文件,復制到src下面 3.設置日誌級別
javaweb學習筆記(十四):JSP(4)
目錄 製作高仿的JSTL標籤庫之核心標籤庫 《1》xiaohua.tld檔案: 《2》依附的各個類: 《3》imitate.core.jsp檔案: 《4》瀏覽器檢視: 製作高仿的JSTL標籤庫之核心標籤庫 通過自定義標籤,製
機器學習筆記(十四):TensorFlow實戰六(經典卷積神經網路:AlexNet )
1 - 引言 2012年,Imagenet比賽冠軍的model——Alexnet [2](以第一作者alex命名)。這個網路算是一個具有突破性意義的模型 首先它證明了CNN在複雜模型下的有效性,然後GPU實現使得訓練在可接受的時間範圍內得到結果,讓之後的網路模型構建變得更加複雜,並且通過
機器學習筆記(十四):異常檢測
目錄 1)Problem motivation 2)Gaussian distribution 3)Algorithm 4)Developing and evaluating an anomaly detection system 5)Anomaly detection vs
pytorch學習筆記(十七):python 端擴充套件 pytorch
pytorch 雖然提供了很多的 op 使得我們很容易的使用。但是當已有的 op 無法滿足我們的要求的時候,那就需要自己動手來擴充套件。 pytorch 提供了兩種方式來擴充套件 pytorch 的基礎
pytorch學習筆記(十一):fine-tune 預訓練的模型
torchvision 中包含了很多預訓練好的模型,這樣就使得 fine-tune 非常容易。本文主要介紹如何 fine-tune torchvision 中預訓練好的模型。 安裝 pip install torchvision 如何 fine
pytorch學習筆記(十二):詳解 Module 類
Module 是 pytorch 提供的一個基類,每次我們要 搭建 自己的神經網路的時候都要繼承這個類,繼承這個類會使得我們 搭建網路的過程變得異常簡單。 本文主要關注 Module 類的內部是怎麼樣
pytorch學習筆記(十六):pytorch 寫程式碼時應該注意
當網路中有 dropout,bn 的時候。訓練的要記得 net.train(), 測試 要記得 net.eval() 在測試的時候 建立輸入 Variable 的時候 要記得 volatile=Tru
機器學習筆記(十二):TensorFlow實戰四(影象識別與卷積神經網路)
1 - 卷積神經網路常用結構 1.1 - 卷積層 我們先來介紹卷積層的結構以及其前向傳播的演算法。 一個卷積層模組,包含以下幾個子模組: 使用0擴充邊界(padding) 卷積視窗過濾器(filter) 前向卷積 反向卷積(可選) 1.1
機器學習筆記(十二):TensorFlow實現四(影象識別與卷積神經網路)
1 - 卷積神經網路常用結構 1.1 - 卷積層 我們先來介紹卷積層的結構以及其前向傳播的演算法。 一個卷積層模組,包含以下幾個子模組: 使用0擴充邊界(padding) 卷積視窗過濾器(filter) 前向卷積 反向卷積(可選) 1.1.2 - 邊界填充
OpenCV2學習筆記(十五):利用Cmake高速查找OpenCV函數源代碼
one 生成 img log 分享 lan 學習筆記 全部 modules 在使用OpenCV時,在對一個函數的調用不是非常了解的情況下,通常希望查到該函數的官方聲明。而假設想進一步研究OpenCV的函數,則必須深入到源碼。在VS中我們能夠選中想要查
如鵬網學習筆記(十四)ASP.NET
表單參數 form表單 web服務 exp 序列化 date 文字 arr 處理程序 Asp.net筆記 一、Socket類 進行網絡編程的類,可以在兩臺計算機之間進行網絡通訊 過程: 向服務器發送指令: GET /index.html HTTP
EF學習筆記(十一):實施繼承
long cannot oid data- turn cati com list pac 學習總目錄:ASP.NET MVC5 及 EF6 學習筆記 - (目錄整理) 上篇鏈接:EF學習筆記(十) 處理並發 本篇原文鏈接:Implementing Inheritance 面
Java學習筆記(十五):import關鍵字
http 技術分享 import logs java學習筆記 .cn 關鍵字 blog ava Java學習筆記(十五):import關鍵字
Java學習筆記(十五):this關鍵字
bsp java image nbsp his this mage 學習筆記 筆記 Java學習筆記(十五):this關鍵字
Java學習筆記(十六):static關鍵字
ima 關鍵字 static關鍵字 es2017 java學習筆記 sta com 筆記 nbsp Java學習筆記(十六):static關鍵字
Java學習筆記(十七):super關鍵字
mage cnblogs 分享 關鍵字 super關鍵字 log .cn nbsp java Java學習筆記(十七):super關鍵字
R語言學習筆記(十一):廣義線性模型
學習筆記 Education 5.0 1.3 style only 可能性 div erro #Logistic 回歸 install.packages("AER") data(Affairs,package="AER") summary(Affairs) a
R語言學習筆記(十六):處理缺失值
ima 結果 cti img dataset case prop .com log #識別缺失值 install.packages("VIM") data(sleep,package="VIM") #列出沒有缺失值的行 sleep[complete.case