1. 程式人生 > >pytorch學習筆記(十四): DataLoader原始碼閱讀

pytorch學習筆記(十四): DataLoader原始碼閱讀

pytorch 資料載入部分的 介面可以說是現存 深度學習框架中設計的最好的, 給了我們足夠的靈活性。本博文就對 pytorch 的多執行緒載入 模組(DataLoader) 進行原始碼上的註釋。

輸入流水線

pytorch 的輸入流水線的操作順序是這樣的:

  • 建立一個 Dataset 物件
  • 建立一個 DataLoader 物件
  • 不停的 迴圈 這個 DataLoader 物件
dataset = MyDataset()
dataloader = DataLoader(dataset)
num_epoches = 100
for epoch in range(num_epoches):
    for
data in dataloader: ....

在之前文章也提到過,如果現有的 Dataset 不能夠滿足需求,我們也可以自定義 Dataset,通過繼承 torch.utils.data.Dataset。在繼承的時候,需要 override 三個方法。

  • __init__: 用來初始化資料集
  • __getitem__
  • __len__

從本文中,您可以看到 __getitem____len__DataLoader 中是如何被使用的。

DataLoader

DataLoader 看起,下面是原始碼。為了方便起見,採用在原始碼中添加註釋的形式進行解讀。

class DataLoader(object):
    def __init__(self, dataset, batch_size=1, shuffle=False, sampler=None, 
                 batch_sampler=None,
                 num_workers=0, collate_fn=default_collate, pin_memory=False, 
                 drop_last=False):
        self.dataset = dataset
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.collate_fn = collate_fn
        self.pin_memory = pin_memory
        self.drop_last = drop_last

        if
batch_sampler is not None: if batch_size > 1 or shuffle or sampler is not None or drop_last: raise ValueError('batch_sampler is mutually exclusive with ' 'batch_size, shuffle, sampler, and drop_last') if sampler is not None and shuffle: raise ValueError('sampler is mutually exclusive with shuffle') if batch_sampler is None: if sampler is None: if shuffle: # dataset.__len__() 在 Sampler 中被使用。 # 目的是生成一個 長度為 len(dataset) 的 序列索引(隨機的)。 sampler = RandomSampler(dataset) else: # dataset.__len__() 在 Sampler 中被使用。 # 目的是生成一個 長度為 len(dataset) 的 序列索引(順序的)。 sampler = SequentialSampler(dataset) # Sampler 是個迭代器,一次之只返回一個 索引 # BatchSampler 也是個迭代器,但是一次返回 batch_size 個 索引 batch_sampler = BatchSampler(sampler, batch_size, drop_last) self.sampler = sampler self.batch_sampler = batch_sampler def __iter__(self): return DataLoaderIter(self) def __len__(self): return len(self.batch_sampler)
# 以下兩個程式碼是等價的
for data in dataloader:
    ...
# 等價與
iterr = iter(dataloader)
while True:
    try:
        next(iterr)
    except StopIteration:
        break

DataLoader 中,iter(dataloader) 返回的是一個 DataLoaderIter 物件, 這個才是我們一直 next的 物件。

下面會先介紹一下 幾個 Sampler, 然後介紹 核心部分 DataLoaderIter

RandomSampler, SequentialSampler, BatchSampler

首先,是 RandomSampleriter(randomSampler) 會返回一個可迭代物件,這個可迭代物件 每次 next 都會輸出當前要取樣的 indexSequentialSampler也是一樣,只不過她產生的 index順序

class RandomSampler(Sampler):

    def __init__(self, data_source):
        self.data_source = data_source

    def __iter__(self):
        return iter(torch.randperm(len(self.data_source)).long())

    def __len__(self):
        return len(self.data_source)

BatchSampler 是一個普通 Samplerwrapper, 普通Sampler 一次僅產生一個 index, 而 BatchSampler 一次產生一個 batchindices

class BatchSampler(object):
    def __init__(self, sampler, batch_size, drop_last):
        # 這裡的 sampler 是 RandomSampler 或者 SequentialSampler
        # 他們每一次吐出一個 idx
        self.sampler = sampler
        self.batch_size = batch_size
        self.drop_last = drop_last

    def __iter__(self):
        batch = []
        for idx in self.sampler:
            batch.append(idx)
            if len(batch) == self.batch_size:
                yield batch
                batch = []
        if len(batch) > 0 and not self.drop_last:
            yield batch

    def __len__(self):
        if self.drop_last:
            return len(self.sampler) // self.batch_size
        else:
            return (len(self.sampler) + self.batch_size - 1) // self.batch_size

DataLoaderIter

  1. self.index_queue 中存放是 (batch_idx, sample_indices) ,其中 batch_idx 是個 int 值, sample_indices 是個 list , 存放了 組成 batchsample indices
  2. self.data_queue 中存放的是 (batch_idx, samples), 其中 samples 是 一個 mini-batch 的樣本
  3. self.send_idx 表示:這次 放到 self.index_queue 中的 batch_id
  4. self.rcvd_idx 表示:這次要取的 batch_id
  5. self.batches_outstanding 表示:
class DataLoaderIter(object):
    "Iterates once over the DataLoader's dataset, as specified by the sampler"

    def __init__(self, loader):
        # loader 是 DataLoader 物件
        self.dataset = loader.dataset
        # 這個留在最後一個部分介紹
        self.collate_fn = loader.collate_fn
        self.batch_sampler = loader.batch_sampler
        # 表示 開 幾個程序。
        self.num_workers = loader.num_workers
        # 是否使用 pin_memory
        self.pin_memory = loader.pin_memory
        self.done_event = threading.Event()

        # 這樣就可以用 next 操作 batch_sampler 了
        self.sample_iter = iter(self.batch_sampler)

        if self.num_workers > 0:
            # 用來放置 batch_idx 的佇列,其中元素的是 一個 list,其中放了一個 batch 內樣本的索引
            self.index_queue = multiprocessing.SimpleQueue()
            # 用來放置 batch_data 的佇列,裡面的 元素的 一個 batch的 資料
            self.data_queue = multiprocessing.SimpleQueue()

            # 當前已經準備好的 batch 的數量(可能有些正在 準備中)
            # 當為 0 時, 說明, dataset 中已經沒有剩餘資料了。
            # 初始值為 0, 在 self._put_indices() 中 +1,在 self.__next__ 中減一
            self.batches_outstanding = 0 
            self.shutdown = False
            # 用來記錄 這次要放到 index_queue 中 batch 的 idx
            self.send_idx = 0
            # 用來記錄 這次要從的 data_queue 中取出 的 batch 的 idx
            self.rcvd_idx = 0
            # 因為多執行緒,可能會導致 data_queue 中的 batch 亂序
            # 用這個來保證 batch 的返回 是 idx 升序出去的。
            self.reorder_dict = {}
            # 這個地方就開始 開多程序了,一共開了 num_workers 個程序
            # 執行 _worker_loop , 下面將介紹 _worker_loop
            self.workers = [
                multiprocessing.Process(
                    target=_worker_loop,
                    args=(self.dataset, self.index_queue, self.data_queue, self.collate_fn))
                for _ in range(self.num_workers)]

            for w in self.workers:
                w.daemon = True  # ensure that the worker exits on process exit
                w.start()

            if self.pin_memory:
                in_data = self.data_queue
                self.data_queue = queue.Queue()
                self.pin_thread = threading.Thread(
                    target=_pin_memory_loop,
                    args=(in_data, self.data_queue, self.done_event))
                self.pin_thread.daemon = True
                self.pin_thread.start()

            # prime the prefetch loop
            # 初始化的時候,就將 2*num_workers 個 (batch_idx, sampler_indices) 放到 index_queue 中。
            for _ in range(2 * self.num_workers):
                self._put_indices()

    def __len__(self):
        return len(self.batch_sampler)

    def __next__(self):
        if self.num_workers == 0:  # same-process loading
            indices = next(self.sample_iter)  # may raise StopIteration
            batch = self.collate_fn([self.dataset[i] for i in indices])
            if self.pin_memory:
                batch = pin_memory_batch(batch)
            return batch

        # check if the next sample has already been generated
        if self.rcvd_idx in self.reorder_dict:
            batch = self.reorder_dict.pop(self.rcvd_idx)
            return self._process_next_batch(batch)

        if self.batches_outstanding == 0:
            # 說明沒有 剩餘 可操作資料了, 可以停止 worker 了
            self._shutdown_workers()
            raise StopIteration

        while True:
            # 這裡的操作就是 給 亂序的 data_queue 排一排 序
            assert (not self.shutdown and self.batches_outstanding > 0)
            idx, batch = self.data_queue.get()
            # 一個 batch 被 返回,batches_outstanding -1
            self.batches_outstanding -= 1
            if idx != self.rcvd_idx:
                # store out-of-order samples
                self.reorder_dict[idx] = batch
                continue
            # 返回的時候,再向 indice_queue 中 放下一個 (batch_idx, sample_indices)
            return self._process_next_batch(batch)

    next = __next__  # Python 2 compatibility

    def __iter__(self):
        return self

    def _put_indices(self):
        assert self.batches_outstanding < 2 * self.num_workers
        indices = next(self.sample_iter, None)
        if indices is None:
            return
        self.index_queue.put((self.send_idx, indices))
        self.batches_outstanding += 1
        self.send_idx += 1

    def _process_next_batch(self, batch):
        self.rcvd_idx += 1
        # 放下一個 (batch_idx, sample_indices)
        self._put_indices()
        if isinstance(batch, ExceptionWrapper):
            raise batch.exc_type(batch.exc_msg)
        return batch

    def __getstate__(self):
        # TODO: add limited pickling support for sharing an iterator
        # across multiple threads for HOGWILD.
        # Probably the best way to do this is by moving the sample pushing
        # to a separate thread and then just sharing the data queue
        # but signalling the end is tricky without a non-blocking API
        raise NotImplementedError("DataLoaderIterator cannot be pickled")

    def _shutdown_workers(self):
        if not self.shutdown:
            self.shutdown = True
            self.done_event.set()
            for _ in self.workers:
                # shutdown 的時候, 會將一個 None 放到 index_queue 中
                # 如果 _worker_loop 獲得了這個 None, _worker_loop 將會跳出無限迴圈,將會結束執行
                self.index_queue.put(None)

    def __del__(self):
        if self.num_workers > 0:
            self._shutdown_workers()

__worker_loop

這部分是 多程序 執行的程式碼:他從index_queue 中 取索引,然後處理資料,然後再將 處理好的 batch 資料放到 data_queue 中。

def _worker_loop(dataset, index_queue, data_queue, collate_fn):
    global _use_shared_memory
    _use_shared_memory = True

    torch.set_num_threads(1)
    while True:
        r = index_queue.get()
        if r is None:
            # 想 data_queue 中放 None
            data_queue.put(None)
            break
        idx, batch_indices = r
        try:
            # 這裡就可以看到 dataset.__getiterm__ 的作用了。
            # 傳到 collate_fn 的資料是 list of ...
            samples = collate_fn([dataset[i] for i in batch_indices])
        except Exception:
            data_queue.put((idx, ExceptionWrapper(sys.exc_info())))
        else:
            data_queue.put((idx, samples))

collate_fn

  • 我們 __getiterm__ 經常返回的是 (img_tensor, label),

  • 所以 放入 collate_fn 的 引數就是 [(img_tensor, label), ....] .

  • batch[0] 就是 (img_tensor, label) , 也就是 collections.Sequence 型別。
def default_collate(batch):
    "Puts each data field into a tensor with outer dimension batch size"
    if torch.is_tensor(batch[0]):
        out = None
        if _use_shared_memory:
            # If we're in a background process, concatenate directly into a
            # shared memory tensor to avoid an extra copy
            # 計算 batch 中所有 元素的個數 
            numel = sum([x.numel() for x in batch])
            # 沒有找到對應的 api 。。。。。。
            storage = batch[0].storage()._new_shared(numel)
            out = batch[0].new(storage)
        return torch.stack(batch, 0, out=out)
    elif type(batch[0]).__module__ == 'numpy':
        elem = batch[0]
        if type(elem).__name__ == 'ndarray':
            return torch.stack([torch.from_numpy(b) for b in batch], 0)
        if elem.shape == ():  # scalars
            py_type = float if elem.dtype.name.startswith('float') else int
            return numpy_type_map[elem.dtype.name](list(map(py_type, batch)))
    elif isinstance(batch[0], int):
        return torch.LongTensor(batch)
    elif isinstance(batch[0], float):
        return torch.DoubleTensor(batch)
    elif isinstance(batch[0], string_classes):
        return batch
    elif isinstance(batch[0], collections.Mapping):
        return {key: default_collate([d[key] for d in batch]) for key in batch[0]}
    elif isinstance(batch[0], collections.Sequence):
        transposed = zip(*batch)
        return [default_collate(samples) for samples in transposed]

    raise TypeError(("batch must contain tensors, numbers, dicts or lists; found {}"
                     .format(type(batch[0]))))

對於 image captioning 任務,既有圖片,又有文字,pytorch 官方開源的工具箱 torchtext 使得文字資料的處理非常簡單,可以通過自定義 collate_fn 的方式將 DataLoadertorchtext 完美的整合起來。

總結

  • data_queue 中最多有 2*num_workerbatch

注意

Queue的特點

  • 當裡面沒有資料時: queue.get() 會阻塞, 阻塞的時候,其它 程序/執行緒 如果有 queue.put() 操作,本 執行緒/程序 會被通知, 然後就可以 get 成功。
  • 當資料滿了: queue.put() 會阻塞

相關推薦

pytorch學習筆記 DataLoader原始碼閱讀

pytorch 資料載入部分的 介面可以說是現存 深度學習框架中設計的最好的, 給了我們足夠的靈活性。本博文就對 pytorch 的多執行緒載入 模組(DataLoader) 進行原始碼上的註釋。 輸入流水線 pytorch 的輸入流水線的操作順序是這

Java框架spring Boot學習筆記log4j介紹

inf alt 技術分享 images 使用 image 詳細 配置文件 -128 功能 日誌功能,通過log4j可以看到程序運行過程的詳細信息。 使用 導入log4j的jar包 復制log4j的配置文件,復制到src下面         3.設置日誌級別    

javaweb學習筆記JSP4

目錄   製作高仿的JSTL標籤庫之核心標籤庫 《1》xiaohua.tld檔案: 《2》依附的各個類: 《3》imitate.core.jsp檔案: 《4》瀏覽器檢視:   製作高仿的JSTL標籤庫之核心標籤庫 通過自定義標籤,製

機器學習筆記TensorFlow實戰六經典卷積神經網路AlexNet

1 - 引言 2012年,Imagenet比賽冠軍的model——Alexnet [2](以第一作者alex命名)。這個網路算是一個具有突破性意義的模型 首先它證明了CNN在複雜模型下的有效性,然後GPU實現使得訓練在可接受的時間範圍內得到結果,讓之後的網路模型構建變得更加複雜,並且通過

機器學習筆記異常檢測

目錄 1)Problem motivation 2)Gaussian distribution 3)Algorithm 4)Developing and evaluating an anomaly detection system 5)Anomaly detection vs

pytorch學習筆記python 端擴充套件 pytorch

pytorch 雖然提供了很多的 op 使得我們很容易的使用。但是當已有的 op 無法滿足我們的要求的時候,那就需要自己動手來擴充套件。 pytorch 提供了兩種方式來擴充套件 pytorch 的基礎

pytorch學習筆記fine-tune 預訓練的模型

torchvision 中包含了很多預訓練好的模型,這樣就使得 fine-tune 非常容易。本文主要介紹如何 fine-tune torchvision 中預訓練好的模型。 安裝 pip install torchvision 如何 fine

pytorch學習筆記詳解 Module 類

Module 是 pytorch 提供的一個基類,每次我們要 搭建 自己的神經網路的時候都要繼承這個類,繼承這個類會使得我們 搭建網路的過程變得異常簡單。 本文主要關注 Module 類的內部是怎麼樣

pytorch學習筆記pytorch 寫程式碼時應該注意

當網路中有 dropout,bn 的時候。訓練的要記得 net.train(), 測試 要記得 net.eval() 在測試的時候 建立輸入 Variable 的時候 要記得 volatile=Tru

機器學習筆記TensorFlow實戰影象識別與卷積神經網路

1 - 卷積神經網路常用結構 1.1 - 卷積層 我們先來介紹卷積層的結構以及其前向傳播的演算法。 一個卷積層模組,包含以下幾個子模組: 使用0擴充邊界(padding) 卷積視窗過濾器(filter) 前向卷積 反向卷積(可選) 1.1

機器學習筆記TensorFlow實現影象識別與卷積神經網路

1 - 卷積神經網路常用結構 1.1 - 卷積層 我們先來介紹卷積層的結構以及其前向傳播的演算法。 一個卷積層模組,包含以下幾個子模組: 使用0擴充邊界(padding) 卷積視窗過濾器(filter) 前向卷積 反向卷積(可選) 1.1.2 - 邊界填充

OpenCV2學習筆記利用Cmake高速查找OpenCV函數源代碼

one 生成 img log 分享 lan 學習筆記 全部 modules 在使用OpenCV時,在對一個函數的調用不是非常了解的情況下,通常希望查到該函數的官方聲明。而假設想進一步研究OpenCV的函數,則必須深入到源碼。在VS中我們能夠選中想要查

如鵬網學習筆記ASP.NET

表單參數 form表單 web服務 exp 序列化 date 文字 arr 處理程序 Asp.net筆記 一、Socket類   進行網絡編程的類,可以在兩臺計算機之間進行網絡通訊   過程:     向服務器發送指令:     GET /index.html HTTP

EF學習筆記實施繼承

long cannot oid data- turn cati com list pac 學習總目錄:ASP.NET MVC5 及 EF6 學習筆記 - (目錄整理) 上篇鏈接:EF學習筆記(十) 處理並發 本篇原文鏈接:Implementing Inheritance 面

Java學習筆記import關鍵字

http 技術分享 import logs java學習筆記 .cn 關鍵字 blog ava Java學習筆記(十五):import關鍵字

Java學習筆記this關鍵字

bsp java image nbsp his this mage 學習筆記 筆記 Java學習筆記(十五):this關鍵字

Java學習筆記static關鍵字

ima 關鍵字 static關鍵字 es2017 java學習筆記 sta com 筆記 nbsp Java學習筆記(十六):static關鍵字

Java學習筆記super關鍵字

mage cnblogs 分享 關鍵字 super關鍵字 log .cn nbsp java Java學習筆記(十七):super關鍵字

R語言學習筆記廣義線性模型

學習筆記 Education 5.0 1.3 style only 可能性 div erro #Logistic 回歸 install.packages("AER") data(Affairs,package="AER") summary(Affairs) a

R語言學習筆記處理缺失值

ima 結果 cti img dataset case prop .com log #識別缺失值 install.packages("VIM") data(sleep,package="VIM") #列出沒有缺失值的行 sleep[complete.case