1. 程式人生 > 實用技巧 >Pytorch--torch.utils.data.DataLoader解讀

Pytorch--torch.utils.data.DataLoader解讀

torch.utils.data.DataLoader是Pytorch中資料讀取的一個重要介面,其在dataloader.py中定義,基本上只要是用oytorch來訓練模型基本都會用到該介面,該介面主要用來將自定義的資料讀取介面的輸出或者PyTorch已有的資料讀取介面的輸入按照batch size封裝成Tensor,後續只需要再包裝成Variable即可作為模型的輸入,主要包括DataLoader和DataLoaderIter兩個類。

dataloader.py指令碼的的github地址:https://github.com/pytorch/pytorch/blob/master/torch/utils/data/dataloader.py

DataLoader類原始碼如下。先看看__init__中的幾個重要的輸入:1、dataset,這個就是PyTorch已有的資料讀取介面(比如torchvision.datasets.ImageFolder)或者自定義的資料介面的輸出,該輸出要麼是torch.utils.data.Dataset類的物件,要麼是繼承自torch.utils.data.Dataset類的自定義類的物件。2、batch_size,根據具體情況設定即可。3、shuffle,一般在訓練資料中會採用。4、collate_fn,是用來處理不同情況下的輸入dataset的封裝,一般採用預設即可,除非你自定義的資料讀取輸出非常少見。5、batch_sampler,從註釋可以看出,其和batch_size、shuffle等引數是互斥的,一般採用預設。6、sampler,從程式碼可以看出,其和shuffle是互斥的,一般預設即可。7、num_workers,從註釋可以看出這個引數必須大於等於0,0的話表示資料匯入在主程序中進行,其他大於0的數表示通過多個程序來匯入資料,可以加快資料匯入速度。8、pin_memory,註釋寫得很清楚了: pin_memory (bool, optional): If True, the data loader will copy tensors into CUDA pinned memory before returning them. 也就是一個數據拷貝的問題。9、timeout,是用來設定資料讀取的超時時間的,但超過這個時間還沒讀取到資料的話就會報錯。

  在__init__中,RandomSampler類表示隨機取樣且不重複,所以起到的就是shuffle的作用。BatchSampler類則是把batch size個RandomSampler類物件封裝成一個,這樣就實現了隨機選取一個batch的目的。這兩個取樣類都是定義在sampler.py指令碼中,地址:https://github.com/pytorch/pytorch/blob/master/torch/utils/data/sampler.py。以上這些都是初始化的時候進行的。當代碼執行到要從torch.utils.data.DataLoader類生成的物件中取資料的時候,比如:
train_data=torch.utils.data.DataLoader(...)


for i, (input, target) in enumerate(train_data):
...
就會呼叫DataLoader類的__iter__方法,__iter__方法就一行程式碼:return DataLoaderIter(self),輸入正是DataLoader類的屬性。因此當呼叫__iter__方法的時候就牽扯到另外一個類:DataLoaderIter

  DataLoaderIter類原始碼如下。self.index_queue = multiprocessing.SimpleQueue()中的multiprocessing是Python中的多程序管理包,而threading則是Python中的多執行緒管理包,二者很大一部分的介面用法類似。還是照例先看看__init__,前面部分都是一些賦值操作,比較特殊的是self.sample_iter = iter(self.batch_sampler),得到的self.sample_iter可以通過next(self.sample_iter)來獲取batch size個數據的index。self.rcvd_idx表示讀取到的一個batch資料的index,初始化為0,該值在迭代讀取資料的時候會用到。if self.num_workers語句是針對多程序或單程序的情況進行初始化,如果不是設定為多程序讀取資料,那麼就不需要這些初始化操作,後面會介紹單程序資料讀取。在if語句中通過multiprocessing.SimpleQueue()類建立了一個簡單的佇列物件。multiprocessing.Process類就是構造程序的類,這裡根據設定的程序數來啟動,然後賦值給self.workers。接下來的一個for迴圈就通過呼叫start方法依次啟動self.workers中的程序。接下來關於self.pin_memory的判斷語句,該判斷語句內部主要是實現了多執行緒操作。self.pin_memory的含義在前面已經介紹過了,當為True的時候,就會把資料拷到CUDA中。self.data_queue = queue.Queue()是通過Python的queue模組初始化得到一個先進先出的佇列(queue模組也可以初始化得到先進後出的佇列,需要用queue.LifoQueue()初始化),queue模組主要應用在多執行緒讀取資料中。在threading.Thread的args引數中,第一個引數in_data就是一個程序的資料,一個程序中不同執行緒的資料也是通過佇列來維護的,這裡採用的是Python的queue模組來初始化得到一個佇列:queue.Queue()。初始化結束後,就會呼叫__next__方法,接下來介紹。
總的來說,如果設定為多程序讀取資料,那麼就會採用佇列的方式來讀,如果不是採用多程序來讀取資料,那就採用普通方式來讀。

  DataLoaderIter類的__next__方法如下,包含3個if語句和1個while語句。
第一個if語句是用來處理self.num_workers等於0的情況,也就是不採用多程序進行資料讀取,可以看出在這個if語句中先通過indices = next(self.sample_iter)獲取長度為batch size的列表:indices,這個列表的每個值表示一個batch中每個資料的index,每執行一次next操作都會讀取一批長度為batch size的indices列表。然後通過self.collate_fn函式將batch size個tuple(每個tuple長度為2,其中第一個值是資料,Tensor型別,第二個值是標籤,int型別)封裝成一個list,這個list長度為2,兩個值都是Tensor,一個是batch size個數據組成的FloatTensor,另一個是batch size個標籤組成的LongTensor。所以簡單講self.collate_fn函式就是將batch size個分散的Tensor封裝成一個Tensor。batch = pin_memory_batch(batch)中pin_memory_batch函式的作用就是將輸入batch的每個Tensor都拷貝到CUDA中,該函式後面會詳細介紹。
第二個if語句是判斷當前想要讀取的batch的index(self.rcvd_idx)是否之前已經讀出來過(已讀出來的index和batch資料儲存在self.reorder_dict字典中,可以結合最後的while語句一起看,因為self.reorder_dict字典的更新是在最後的while語句中),如果之前已經讀取過了,就根據這個index從reorder_dict字典中彈出對應的資料。最後返回batch資料的時候是 return self._process_next_batch(batch),該方法後面會詳細介紹。主要做是獲取下一個batch的資料index資訊。
第三個if語句,self.batches_outstanding的值在前面初始中呼叫self._put_indices()方法時修改了,所以假設你的程序數self.num_workers設定為3,那麼這裡self.batches_outstanding就是3*2=6,可具體看self._put_indices()方法。
最後的while迴圈就是真正用來從佇列中讀取資料的操作,最主要的就是idx, batch = self._get_batch(),通過呼叫_get_batch()方法來讀取,後面有介紹,簡單講就是呼叫了佇列的get方法得到下一個batch的資料,得到的batch一般是長度為2的列表,列表的兩個值都是Tensor,分別表示資料(是一個batch的)和標籤。_get_batch()方法除了返回batch資料外,還得到另一個輸出:idx,這個輸出表示batch的index,這個if idx != self.rcvd_idx條件語句表示如果你讀取到的batch的index不等於當前想要的index:selg,rcvd_idx,那麼就將讀取到的資料儲存在字典self.reorder_dict中:self.reorder_dict[idx] = batch,然後繼續讀取資料,直到讀取到的資料的index等於self.rcvd_idx。

  DataloaderIter類的_get_batch方法。主要根據是否設定了超時時間來操作,如果超過指定的超時時間後沒有從佇列中讀到資料就報錯,如果不設定超時時間且一致沒有從佇列中讀到資料,那麼就會一直卡著且不報錯,這部分是PyTorch後來修的一個bug。

  DataLoaderIter類的_process_next_batch方法。首先對self.rcvd_idx進行加一,也就是更新下下一個要讀取的batch資料的index。然後呼叫_put_indices()方法獲取下一個batch的每個資料的index。

  DataLoaderIter類的_put_indices方法。該方法主要實現從self.sample_iter中讀取下一個batch資料中每個資料的index:indices = next(self.sample_iter, None),注意這裡的index和前面idx是不一樣的,這裡的index是一個batch中每個資料的index,idx是一個batch的index;然後將讀取到的index通過呼叫queue物件的put方法壓到佇列self.index_queue中:self.index_queue.put((self.send_idx, indices))