1. 程式人生 > 程式設計 >TensorFlow dataset.shuffle、batch、repeat的使用詳解

TensorFlow dataset.shuffle、batch、repeat的使用詳解

直接看程式碼例子,有詳細註釋!!

import tensorflow as tf
import numpy as np


d = np.arange(0,60).reshape([6,10])

# 將array轉化為tensor
data = tf.data.Dataset.from_tensor_slices(d)

# 從data資料集中按順序抽取buffer_size個樣本放在buffer中,然後打亂buffer中的樣本
# buffer中樣本個數不足buffer_size,繼續從data資料集中安順序填充至buffer_size,
# 此時會再次打亂
data = data.shuffle(buffer_size=3)

# 每次從buffer中抽取4個樣本
data = data.batch(4)

# 將data資料集重複,其實就是2個epoch資料集
data = data.repeat(2)

# 構造獲取資料的迭代器
iters = data.make_one_shot_iterator()

# 每次從迭代器中獲取一批資料
batch = iters.get_next()

sess = tf.Session()

sess.run(batch)
# 資料集完成遍歷完之後,繼續抽取的話會報錯:OutOfRangeError
In [21]: d
Out[21]: 
array([[ 0,1,2,3,4,5,6,7,8,9],[10,11,12,13,14,15,16,17,18,19],[20,21,22,23,24,25,26,27,28,29],[30,31,32,33,34,35,36,37,38,39],[40,41,42,43,44,45,46,47,48,49],[50,51,52,53,54,55,56,57,58,59]])
In [22]: sess.run(batch)
Out[22]: 
array([[ 0,19]])

In [23]: sess.run(batch)
Out[23]: 
array([[40,59]])

從輸出結果可以看出:

shuffle是按順序將資料放入buffer裡面的;

當repeat函式在shuffle之後的話,是將一個epoch的資料集抽取完畢,再進行下一個epoch的。

那麼,當repeat函式在shuffle之前會怎麼樣呢?如下:

data = data.repeat(2)

data = data.shuffle(buffer_size=3)

data = data.batch(4)
In [25]: sess.run(batch)
Out[25]: 
array([[10,[ 0,49]])

In [26]: sess.run(batch)
Out[26]: 
array([[50,59],39]])

In [27]: sess.run(batch)
Out[27]: 
array([[10,49]])

可以看出,其實它就是先將資料集複製一遍,然後把兩個epoch當成同一個新的資料集,一直shuffle和batch下去。

以上這篇TensorFlow dataset.shuffle、batch、repeat的使用詳解就是小編分享給大家的全部內容了,希望能給大家一個參考,也希望大家多多支援我們。