【pytorch原始碼賞析】Dataset in pytorch
1. 原始碼概覽
pytorch是眾多dl工具中,比較python風格化的一種,另一個完全python化的dl工具是chainer,它的構建語言中只有python,甚至cuda也是從python端呼叫的。python風格化的好處是,使用了很多python的語言特性,讓程式碼更加簡潔,更高效。《python高階程式設計》的第2、3章,描述了部分python的高階語言特性,比如:列表推導,迭代器和生成器,裝飾器等。這些trick讓程式碼更加python化,可讀性更強,也更健壯。
pytorch的資料集部分,從原始碼可以看出,提供了2個主要的類:Dataset,DataLoader。
Dataset為抽象類,定義了兩個行為:__getitem__和__len__。也就是任何資料集,都可以len(dataset)獲得樣本的數量,dataset[i]獲得其中第i個樣本。派生了兩個類:TensorDataset,當x和y是pytorch的tensor時,可以方便地匯入;另一個ConcatDataset,用於合併多個數據集(對於實際應用特別有用)。
DataLoader是更核心的類,使用者用它來獲得每次batch的訓練資料。
dataloader.py中有2個類,DataLoader和DataLoaderIter。
DataLoader提供如下功能:
1. 儲存了dataset
2. 具有sample行為
3. 提供單執行緒/多執行緒來獲取資料集中的資料(程式碼主要實現的功能)
DataLoader有2個行為:__iter__和__len__。而__iter__這個迭代器,程式碼如下:
def __iter__(self):
return DataLoaderIter(self)
返回的正是DataLoaderIter。DataLoaderIter的功能是,根據sample指定的方法,獲取訓練樣本。sample方法有SequentialSampler, RandomSampler, BatchSampler這三種,其實是兩種:SequentialSampler和RandomSampler。如果指定了shuffle,則是隨機取樣,否則是序列取樣,然後都會使用BatchSample。
DataLoaderIter具有3個行為:__iter__,__len__和__next__。每次使用next(dataLoaderIter)來獲得一個batch。
__iter__總是和__next__一起使用,__iter__表明這個類是可以迭代的,__next__表明每次迭代的具體行為,一個例子如下:
class Testing:
def __init__(self,a,b):
self.a = a
self.b = b
def __iter__ (self):
print('itering')
return self
def next(self):
print('nexting')
if self.a <= self.b:
self.a += 1
return self.a-1
else:
raise StopIteration
myObj = Testing(1,5)
for i in myObj:
print i
itering
nexting
1
nexting
2
nexting
3
nexting
4
nexting
5
nexting
2. 使用方法
使用pytorch提供的方法操作資料集,一般分兩步:
1. 繼承Dataset,實現__getitem__和__len__方法。
2. 例項化DataLoader,一般需要指定自己的collate_fn方法。
而這正是程式碼優美的地方,把“讀取資料集”這個任務完美地解耦和,使用者只需要針對不同的資料集派生Dataset類,實現2個方法。DataLoader負責瞭如何讀取訓練樣本的行為,只需要例項化即可,還可以通過設定collate_fn定製化自己的具體讀取行為。