自己造輪子:深度學習dataloader自己實現
阿新 • • 發佈:2018-12-03
自己造輪子:深度學習dataloader自己實現
**摘要:**因為計算機效能的限制,所有的深度學習框架都是採用批量隨機梯度下降,所以每次計算都要讀取batch_size的資料。這裡以自己實現的方式介紹深度學習框架實現批量讀取資料的原理,不涉及具體細節和一些邏輯,只注重大體流程和原理。
總體流程:
- 採用yield寫一個生成器函式實現批量圖片/標註資訊的讀取
- 採用multiprocessing/threading加速檔案讀取
- 時間對比
深度學習大體流程
for i in range(epoch):
data, lable = dataloader. next(batch_size=16) # 讀取batch_size的資料
output = model(data) # 前向傳播
loss = crition(output, label) # 求損失函式
loss.backward() # 反向傳播
在dataloader的時候,一般會採用多個程序(num_workers
)加快檔案I/O的速度,避免網路反向傳播過了,還沒有資料。
1. 用yield寫一個生成器函式
# coding:utf-8
# 自己造輪子,實現深度學習批量資料的讀取
import os
import glob
import numpy as np
import cv2
def get_images(path):
files = []
for ext in ['jpg', 'png', 'jpeg', 'JPG']:
files.extend(glob.glob(
os.path.join(path, '*.{}'.format(ext))))
return files
def dataset(batch_size=2, path='/media/chenjun/data/1_deeplearning/7_ammeter_data/test' ):
"""
寫一個讀取圖片的生成器
batch_size:批量大小
path:圖片路徑
"""
# 1. 讀取所有圖片名字
image_list = get_images(path)
index = np.arange(0, len(image_list))
while True:
np.random.shuffle(index)
images = []
image_names = []
for i in index:
try:
im_name = image_list[i]
im = cv2.imread(im_name) # 讀取圖片
# 讀取相應圖片的標註資訊
# text_polys = fun1()
images.append(im[:,:, ::-1].astype(np.float32)) # cv2讀取圖片的順序為BGR,轉換成RGB格式
image_names.append(im_name)
if len(images) == batch_size:
yield images, image_names # 採用函式生成器,生成一個可迭代物件
images = []
image_names = []
except Exception as e:
import traceback
traceback.print_exc()
continue # 所有圖片已經讀完一遍,跳出for迴圈,再打亂圖片的順序進行第二次讀取
2. 使用muitlprocessing加速檔案讀取速度
<!-- 採用正常模式進行圖片讀取,讀取100個batch -->
import time
mydataset = dataset()
start = time.time()
for _ in range(100):
im, im_name = next(mydataset)
# print(im_name)
print('use time:{}'.format(time.time() - start))
>>> use time:0.16786599159240723
<!-- 採用muitlprocessing模式進行圖片讀取,讀取100個batch -->
import multiprocessing
def data_generator(data, q):
for _ in range(100): # 迴圈多少次
generator_output = next(data)
q.put(generator_output)
q = multiprocessing.Queue()
start2 = time.time()
thread = multiprocessing.Process(target=data_generator, args=(dataset(), q))
thread.start() # 多程序開始讀取圖片
print('mulprocess time is:{}'.format(time.time() - start2))
>>> mulprocess time is:0.002292633056640625
可以看到讀取100個batch,時間提高了80倍。
同時,一般的深度學習框架都會使用幾個多程序處理上面的功能。
eg:
for _ in range(workers):
if self._use_multiprocessing:
# Reset random seed else all children processes
# share the same seed
np.random.seed(self.random_seed)
thread = multiprocessing.Process(target=data_generator_task)
網上的資料顯示threading的效率沒有muitlprocessing高,這裡就不測試了。
reference
[1] 莫煩python
[2] argman/EAST