1. 程式人生 > >自己造輪子:深度學習dataloader自己實現

自己造輪子:深度學習dataloader自己實現

自己造輪子:深度學習dataloader自己實現

**摘要:**因為計算機效能的限制,所有的深度學習框架都是採用批量隨機梯度下降,所以每次計算都要讀取batch_size的資料。這裡以自己實現的方式介紹深度學習框架實現批量讀取資料的原理,不涉及具體細節和一些邏輯,只注重大體流程和原理。

總體流程:

  • 採用yield寫一個生成器函式實現批量圖片/標註資訊的讀取
  • 採用multiprocessing/threading加速檔案讀取
  • 時間對比

深度學習大體流程

for i in range(epoch):
    data, lable = dataloader.
next(batch_size=16) # 讀取batch_size的資料 output = model(data) # 前向傳播 loss = crition(output, label) # 求損失函式 loss.backward() # 反向傳播

在dataloader的時候,一般會採用多個程序(num_workers
)加快檔案I/O的速度,避免網路反向傳播過了,還沒有資料。

1. 用yield寫一個生成器函式

# coding:utf-8
# 自己造輪子,實現深度學習批量資料的讀取
import os import glob import numpy as np import cv2 def get_images(path): files = [] for ext in ['jpg', 'png', 'jpeg', 'JPG']: files.extend(glob.glob( os.path.join(path, '*.{}'.format(ext)))) return files def dataset(batch_size=2, path='/media/chenjun/data/1_deeplearning/7_ammeter_data/test'
): """ 寫一個讀取圖片的生成器 batch_size:批量大小 path:圖片路徑 """ # 1. 讀取所有圖片名字 image_list = get_images(path) index = np.arange(0, len(image_list)) while True: np.random.shuffle(index) images = [] image_names = [] for i in index: try: im_name = image_list[i] im = cv2.imread(im_name) # 讀取圖片 # 讀取相應圖片的標註資訊 # text_polys = fun1() images.append(im[:,:, ::-1].astype(np.float32)) # cv2讀取圖片的順序為BGR,轉換成RGB格式 image_names.append(im_name) if len(images) == batch_size: yield images, image_names # 採用函式生成器,生成一個可迭代物件 images = [] image_names = [] except Exception as e: import traceback traceback.print_exc() continue # 所有圖片已經讀完一遍,跳出for迴圈,再打亂圖片的順序進行第二次讀取

2. 使用muitlprocessing加速檔案讀取速度

<!-- 採用正常模式進行圖片讀取,讀取100個batch -->
import time
mydataset = dataset()
start = time.time()
for _ in range(100):
    im, im_name = next(mydataset)
#     print(im_name)
print('use time:{}'.format(time.time() - start))
>>>  use time:0.16786599159240723


<!-- 採用muitlprocessing模式進行圖片讀取,讀取100個batch -->
import multiprocessing
def data_generator(data, q):
    for _ in range(100):                # 迴圈多少次
        generator_output = next(data)
        q.put(generator_output)

q = multiprocessing.Queue()
start2 = time.time()
thread = multiprocessing.Process(target=data_generator, args=(dataset(), q))
thread.start()              # 多程序開始讀取圖片
print('mulprocess time is:{}'.format(time.time() - start2))
>>>  mulprocess time is:0.002292633056640625

可以看到讀取100個batch,時間提高了80倍。
同時,一般的深度學習框架都會使用幾個多程序處理上面的功能。
eg:

for _ in range(workers):
                if self._use_multiprocessing:
                    # Reset random seed else all children processes
                    # share the same seed
                    np.random.seed(self.random_seed)
                    thread = multiprocessing.Process(target=data_generator_task)
                    

網上的資料顯示threading的效率沒有muitlprocessing高,這裡就不測試了。

reference

[1] 莫煩python
[2] argman/EAST