pytorch :: Dataloader中的迭代器和生成器應用
在使用pytorch訓練模型,經常需要載入大量圖片資料,因此pytorch提供了好用的資料載入工具Dataloader。
為了實現小批量迴圈讀取大型資料集,在Dataloader類具體實現中,使用了迭代器和生成器。
這一應用場景正是python中迭代器模式的意義所在,因此本文對Dataloader中程式碼進行解讀,可以更好的理解python中迭代器和生成器的概念。
本文的內容主要有:
- 解釋python中的迭代器和生成器概念
- 解讀pytorch中Dataloader程式碼,如何使用迭代器和生成器實現資料載入
python迭代基礎
python中圍繞著迭代有以下概念:
- 可迭代物件 iterables
- 迭代器 iterator
- 生成器 generator
這三個概念互相關聯,並不是孤立的。在可迭代物件的基礎上發展了迭代器,在迭代器的基礎上又發展了生成器。
學習這些概念的名詞解釋沒有多大意義。程式設計中很多的抽象概念都是為了更好的實現某些功能,才去人為創造的協議和模式。
因此,要理解它們,需要探究概念背後的邏輯,為什麼這樣設計?要解決的真正問題是什麼?在哪些場景下應用是最好的?
迭代模式首先要解決的基礎問題是,需要按一定順序獲取集合內部資料,比如迴圈某個list。
當資料很小時,不會有問題。但當讀取大量資料時,一次性讀取會超出記憶體限制,因此想出以下方法:
- 把大的資料分成幾個小塊,分批處理
- 惰性的取值方式,按需取值
迴圈讀資料可分為下面三種應用場景,對應著容器(可迭代物件),迭代器和生成器:
for x in container
: 為了遍歷python內部序列容器(如list), 這些型別內部實現了__getitem__() 方法,可以從0開始按順序遍歷序列容器中的元素。for x in iterator
: 為了迴圈使用者自定義的迭代器,需要實現__iter__和__next__方法,__iter__是迭代協議,具體每次迭代的執行邏輯在 __next__或next方法裡for x in generator
: 為了節省迴圈的記憶體和加速,使用生成器來實現惰性載入,在迭代器的基礎上加入了yield語句,最簡單的例子是 range(5)
程式碼示例:
# 普通迴圈 for x in list
numbers = [1, 2, 3,]
for n in numbers:
print(n) # 1,2,3
# for迴圈實際乾的事情
# iter輸入一個可迭代物件list,返回迭代器
# next方法取資料
my_iterator = iter(numbers)
next(my_iterator) # 1
next(my_iterator) # 2
next(my_iterator) # 3
next(my_iterator) # StopIteration exception
# 迭代器迴圈 for x in iterator
for i,n in enumerate(numbers):
print(i,n) # 0,1 / 1,3 / 2,3
# 生成器迴圈 for x in generator
for i in range(3):
print(i) # 0,1,2
上面示例程式碼中python內建函式iter和next的用法:
- iter函式,呼叫__iter__,返回一個迭代器
- next函式,輸入迭代器,呼叫__next__,取出資料
比較容易混淆的是__iter__和__next__兩個方法。它們的區別是:
- __iter__是為了可以迭代,真正執行取資料的邏輯是__next__方法實現的,實際呼叫是通過next(iterator)完成
- __iter__可以返回自身(return self),實際讀取資料的實現放在__next__方法
- __iter__可以和yield搭配,返回生成器物件
__iter__返回自身的做法有點類似 python中的型別系統。為了保持一致性,python中一切皆物件。
每個物件建立後,都有型別指標,而型別物件的指標指向元物件,元物件的指標指向自身。
生成器,是在__iter__方法中加入yield語句,好處有:
- 減少迴圈判斷邏輯的複雜度
- 惰性取值,節省記憶體和時間
yield作用:
- 代替函式中的return語句
- 記住上一次迴圈迭代器內部元素的位置
三種迴圈模式常用函式
for x in container
方法:
list, deque, …
set, frozensets, …
dict, defaultdict, OrderedDict, Counter, …
tuple, namedtuple, …
str
for x in iterator
方法:
enumerate()
# 加上list的indexsorted()
# 排序listreversed()
# 倒序listzip()
# 合併list
for x in generator
方法:
range()
map()
filter()
reduce()
[x for x in list(...)]
Dataloder原始碼分析
pytorch採用for x in iterator
模式,從Dataloader類中讀取資料。
- 為了實現該迭代模式,在Dataloader內部實現__iter__方法,實際返回的是_DataLoaderIter類。
- _DataLoaderIter類裡面,實現了 __iter__方法,返回自身,具體執行讀資料的邏輯,在__next__方法中。
以下程式碼只截取了單執行緒下的資料讀取。
class DataLoader(object):
r"""
Data loader. Combines a dataset and a sampler, and provides
single- or multi-process iterators over the dataset.
"""
def __init__(self, dataset, batch_size=1, shuffle=False, ...):
self.dataset = dataset
self.batch_sampler = batch_sampler
...
def __iter__(self):
return _DataLoaderIter(self)
def __len__(self):
return len(self.batch_sampler)
class _DataLoaderIter(object):
r"""Iterates once over the DataLoader's dataset, as specified by the sampler"""
def __init__(self, loader):
self.sample_iter = iter(self.batch_sampler)
...
def __next__(self):
if self.num_workers == 0: # same-process loading
indices = next(self.sample_iter) # may raise StopIteration
batch = self.collate_fn([self.dataset[i] for i in indices])
if self.pin_memory:
batch = pin_memory_batch(batch)
return batch
...
def __iter__(self):
return self
Dataloader類中讀取資料Index的方法,採用了 for x in generator
方式,但是呼叫採用iter和next函式
- 構建隨機取樣類RandomSampler,內部實現了 __iter__方法
- __iter__方法內部使用了 yield,迴圈遍歷資料集,當數量達到batch_size大小時,就返回
- 例項化隨機取樣類,傳入iter函式,返回一個迭代器
- next會呼叫隨機取樣類中生成器,返回相應的index資料
class RandomSampler(object):
"""random sampler to yield a mini-batch of indices."""
def __init__(self, batch_size, dataset, drop_last=False):
self.dataset = dataset
self.batch_size = batch_size
self.num_imgs = len(dataset)
self.drop_last = drop_last
def __iter__(self):
indices = np.random.permutation(self.num_imgs)
batch = []
for i in indices:
batch.append(i)
if len(batch) == self.batch_size:
yield batch
batch = []
## if images not to yield a batch
if len(batch)>0 and not self.drop_last:
yield batch
def __len__(self):
if self.drop_last:
return self.num_imgs // self.batch_size
else:
return (self.num_imgs + self.batch_size - 1) // self.batch_size
batch_sampler = RandomSampler(batch_size. dataset)
sample_iter = iter(batch_sampler)
indices = next(sample_iter)
總結
本文總結了python中迴圈的三種模式:
for x in container
可迭代物件for x in iterator
迭代器for x in generator
生成器
pytorch中的資料載入模組 Dataloader,使用生成器來返回資料的索引,使用迭代器來返回需要的張量資料,可以在大量資料情況下,實現小批量迴圈迭代式的讀取,避免了記憶體不足問題。
參考文章
- Looping Like a Pro in Python PyCon 2017
- 迭代器和生成器
- 流暢的Python-第14章:可迭代的物件、迭代器和生成器
- pytorch-dataloader原始碼