Tensorflow中使用tfrecord方式讀取資料
前言
本部落格預設讀者對神經網路與Tensorflow有一定了解,對其中的一些術語不再做具體解釋。並且本部落格主要以圖片資料為例進行介紹,如有錯誤,敬請斧正。
使用Tensorflow訓練神經網路時,我們可以用多種方式來讀取自己的資料。如果資料集比較小,而且記憶體足夠大,可以選擇直接將所有資料讀進記憶體,然後每次取一個batch的資料出來。如果資料較多,可以每次直接從硬碟中進行讀取,不過這種方式的讀取效率就比較低了。此篇部落格就主要講一下Tensorflow官方推薦的一種較為高效的資料讀取方式——tfrecord。
從巨集觀來講,tfrecord其實是一種資料儲存形式。使用tfrecord時,實際上是先讀取原生資料,然後轉換成tfrecord格式,再儲存在硬碟上。而使用時,再把資料從相應的tfrecord檔案中解碼讀取出來。那麼使用tfrecord和直接從硬碟讀取原生資料相比到底有什麼優勢呢?其實,Tensorflow有和tfrecord配套的一些函式,可以加快資料的處理。實際讀取tfrecord資料時,先以相應的tfrecord檔案為引數,建立一個輸入佇列,這個佇列有一定的容量(視具體硬體限制,使用者可以設定不同的值),在一部分資料出佇列時,tfrecord中的其他資料就可以通過預取進入佇列,並且這個過程和網路的計算是獨立進行的。也就是說,網路每一個iteration的訓練不必等待資料佇列準備好再開始,佇列中的資料始終是充足的,而往佇列中填充資料時,也可以使用多執行緒加速。
下面,本文將從以下4個方面對tfrecord進行介紹:
- tfrecord格式簡介
- 利用自己的資料生成tfrecord檔案
- 從tfrecord檔案讀取資料
- 例項測試
1. tfrecord格式簡介
tfecord檔案中的資料是通過tf.train.Example Protocol Buffer的格式儲存的,下面是tf.train.Example的定義
message Example {
Features features = 1;
};
message Features{
map<string,Feature> featrue = 1;
};
message Feature{
oneof kind{
BytesList bytes_list = 1 ;
FloatList float_list = 2;
Int64List int64_list = 3;
}
};
從上述程式碼可以看出,tf.train.Example 的資料結構很簡單。tf.train.Example中包含了一個從屬性名稱到取值的字典,其中屬性名稱為一個字串,屬性的取值可以為字串(BytesList ),浮點數列表(FloatList )或整數列表(Int64List )。例如我們可以將圖片轉換為字串進行儲存,影象對應的類別標號作為整數儲存,而用於迴歸任務的ground-truth可以作為浮點數儲存。通過後面的程式碼我們會對tfrecord的這種字典形式有更直觀的認識。
2. 利用自己的資料生成tfrecord檔案
先上一段程式碼,然後我再針對程式碼進行相關介紹。
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
from scipy import misc
import scipy.io as sio
def _bytes_feature(value):
return tf.train.Feature(bytes_list = tf.train.BytesList(value=[value]))
def _int64_feature(value):
return tf.train.Feature(int64_list = tf.train.Int64List(value=[value]))
root_path = '/mount/temp/WZG/Multitask/Data/'
tfrecords_filename = root_path + 'tfrecords/train.tfrecords'
writer = tf.python_io.TFRecordWriter(tfrecords_filename)
height = 300
width = 300
meanfile = sio.loadmat(root_path + 'mats/mean300.mat')
meanvalue = meanfile['mean']
txtfile = root_path + 'txt/train.txt'
fr = open(txtfile)
for i in fr.readlines():
item = i.split()
img = np.float64(misc.imread(root_path + '/images/train_images/' + item[0]))
img = img - meanvalue
maskmat = sio.loadmat(root_path + '/mats/train_mats/' + item[1])
mask = np.float64(maskmat['seg_mask'])
label = int(item[2])
img_raw = img.tostring()
mask_raw = mask.tostring()
example = tf.train.Example(features=tf.train.Features(feature={
'height': _int64_feature(height),
'width': _int64_feature(width),
'name': _bytes_feature(item[0]),
'image_raw': _bytes_feature(img_raw),
'mask_raw': _bytes_feature(mask_raw),
'label': _int64_feature(label)}))
writer.write(example.SerializeToString())
writer.close()
fr.close()
程式碼中前兩個函式(_bytes_feature和_int64_feature)是將我們的原生資料進行轉換用的,尤其是圖片要轉換成字串再進行儲存。這兩個函式的定義來自官方的示例。
接下來,我定義了資料的(路徑-label檔案)txtfile,它大概長這個樣子:
這裡稍微囉嗦下,介紹一下我的實驗內容。我做的是一個multi-task的實驗,一支task做分割,一支task做分類。所以txtfile中每一行是一個樣本,每個樣本又包含3項,第一項為圖片名稱,第二項為相應的ground-truth segmentation mask的名稱,第三項是圖片的標籤。(txtfile中內容形式無所謂,只要能讀到想讀的資料就可以)
接著回到主題繼續講程式碼,之後我又定義了即將生成的tfrecord的檔案路徑和名稱,即tfrecord_filename,還有一個writer,這個writer是進行寫操作用的。
接下來是圖片的高度、寬度以及我事先在整個資料集上計算好的影象均值檔案。高度、寬度其實完全沒必要引入,這裡只是為了說明tfrecord的生成而寫的。而均值檔案是為了對影象進行事先的去均值化操作而引入的,在大多數機器學習任務中,影象去均值化對提高演算法的效能還是很有幫助的。
最後就是根據txtfile中的每一行進行相關資料的讀取、轉換以及tfrecord的生成了。首先是根據圖片路徑讀取圖片內容,然後影象減去之前讀入的均值,接著根據segmentation mask的路徑讀取mask(如果只是影象分類任務,那麼就不會有這些額外的mask),txtfile中的label讀出來是string格式,這裡要轉換成int。然後影象和mask資料也要用相應的tosring函式轉換成string。
真正的核心是下面這一小段程式碼:
example = tf.train.Example(features=tf.train.Features(feature={
'height': _int64_feature(height),
'width': _int64_feature(width),
'name': _bytes_feature(item[0]),
'image_raw': _bytes_feature(img_raw),
'mask_raw': _bytes_feature(mask_raw),
'label': _int64_feature(label)}))
writer.write(example.SerializeToString())
這裡很好地體現了tfrecord的字典特性,tfrecord中每一個樣本都是一個小字典,這個字典可以包含任意多個鍵值對。比如我這裡就儲存了圖片的高度、寬度、圖片名稱、圖片內容、mask內容以及圖片的label。對於我的任務來說,其實height、width、name都不是必需的,這裡僅僅是為了展示。鍵值對的鍵全都是字串,鍵起什麼名字都可以,只要能方便以後使用就可以。
定義好一個example後就可以用之前的writer來把它真正寫入tfrecord檔案了,這其實就跟把一行內容寫入一個txt檔案一樣。程式碼的最後就是writer和txt檔案物件的關閉了。
最後在指定資料夾下,就得到了指定名字的tfrecord檔案,如下所示:
需要注意的是,生成的tfrecord檔案比原生資料的大小還要大,這是正常現象。這種現象可能是因為圖片一般都儲存為jpg等壓縮格式,而tfrecord檔案儲存的是解壓後的資料。
3. 從tfrecord檔案讀取資料
還是程式碼先行。
from scipy import misc
import tensorflow as tf
import numpy as np
import scipy.io as sio
import matplotlib.pyplot as plt
root_path = '/mount/temp/WZG/Multitask/Data/'
tfrecord_filename = root_path + 'tfrecords/test.tfrecords'
def read_and_decode(filename_queue, random_crop=False, random_clip=False, shuffle_batch=True):
reader = tf.TFRecordReader()
_, serialized_example = reader.read(filename_queue)
features = tf.parse_single_example(
serialized_example,
features={
'height': tf.FixedLenFeature([], tf.int64),
'width': tf.FixedLenFeature([], tf.int64),
'name': tf.FixedLenFeature([], tf.string),
'image_raw': tf.FixedLenFeature([], tf.string),
'mask_raw': tf.FixedLenFeature([], tf.string),
'label': tf.FixedLenFeature([], tf.int64)
})
image = tf.decode_raw(features['image_raw'], tf.float64)
image = tf.reshape(image, [300,300,3])
mask = tf.decode_raw(features['mask_raw'], tf.float64)
mask = tf.reshape(mask, [300,300])
name = features['name']
label = features['label']
width = features['width']
height = features['height']
# if random_crop:
# image = tf.random_crop(image, [227, 227, 3])
# else:
# image = tf.image.resize_image_with_crop_or_pad(image, 227, 227)
# if random_clip:
# image = tf.image.random_flip_left_right(image)
if shuffle_batch:
images, masks, names, labels, widths, heights = tf.train.shuffle_batch([image, mask, name, label, width, height],
batch_size=4,
capacity=8000,
num_threads=4,
min_after_dequeue=2000)
else:
images, masks, names, labels, widths, heights = tf.train.batch([image, mask, name, label, width, height],
batch_size=4,
capacity=8000,
num_threads=4)
return images, masks, names, labels, widths, heights
讀取tfrecord檔案中的資料主要是應用read_and_decode()這個函式,可以看到其中有個引數是filename_queue,其實我們並不是直接從tfrecord檔案進行讀取,而是要先利用tfrecord檔案建立一個輸入佇列,如本文開頭所述那樣。關於這點,到後面真正的測試程式碼我再介紹。
在read_and_decode()中,一上來我們先定義一個reader物件,然後使用reader得到serialized_example,這是一個序列化的物件,接著使用tf.parse_single_example()函式對此物件進行初步解析。從程式碼中可以看到,解析時,我們要用到之前定義的那些鍵。對於影象、mask這種轉換成字串的資料,要進一步使用tf.decode_raw()函式進行解析,這裡要特別注意函式裡的第二個引數,也就是解析後的型別。之前圖片在轉成字串之前是什麼型別的資料,那麼這裡的引數就要填成對應的型別,否則會報錯。對於name、label、width、height這樣的資料就不用再解析了,我們得到的features物件就是個字典,利用鍵就可以拿到對應的值,如程式碼所示。
我註釋掉的部分是用來做資料增強的,比如隨機的裁剪與翻轉,除了這兩種,其他形式的資料增強也可以寫在這裡,讀者可以根據自己的需要,決定是否使用各種資料增強方式。
函式最後就是使用解析出來的資料生成batch了。Tensorflow提供了兩種方式,一種是shuffle_batch,這種主要是用在訓練中,隨機選取樣本組成batch。另外一種就是按照資料在tfrecord中的先後順序生成batch。對於生成batch的函式,建議讀者去官網檢視API文件進行細緻瞭解。這裡稍微做一下介紹,batch的大小,即batch_size就需要在生成batch的函式裡指定。另外,capacity引數指定資料佇列一次效能放多少個樣本,此引數設定什麼值需要視硬體環境而定。num_threads引數指定可以開啟幾個執行緒來向資料佇列中填充資料,如果硬體效能不夠強,最好設小一點,否則容易崩。
4. 例項測試
實際使用時先指定好我們需要使用的tfrecord檔案:
root_path = '/mount/temp/WZG/Multitask/Data/'
tfrecord_filename = root_path + 'tfrecords/test.tfrecords'
然後用該tfrecord檔案建立一個輸入佇列:
filename_queue = tf.train.string_input_producer([tfrecord_filename],
num_epochs=3)
這裡有個引數是num_epochs,指定好之後,Tensorflow自然知道如何讀取資料,保證在遍歷資料集的一個epoch中樣本不會重複,也知道資料讀取何時應該停止。
下面我將完整的測試程式碼貼出:
def test_run(tfrecord_filename):
filename_queue = tf.train.string_input_producer([tfrecord_filename],
num_epochs=3)
images, masks, names, labels, widths, heights = read_and_decode(filename_queue)
init_op = tf.group(tf.global_variables_initializer(),
tf.local_variables_initializer())
meanfile = sio.loadmat(root_path + 'mats/mean300.mat')
meanvalue = meanfile['mean']
with tf.Session() as sess:
sess.run(init_op)
coord = tf.train.Coordinator()
threads = tf.train.start_queue_runners(coord=coord)
for i in range(1):
imgs, msks, nms, labs, wids, heis = sess.run([images, masks, names, labels, widths, heights])
print 'batch' + str(i) + ': '
#print type(imgs[0])
for j in range(4):
print nms[j] + ': ' + str(labs[j]) + ' ' + str(wids[j]) + ' ' + str(heis[j])
img = np.uint8(imgs[j] + meanvalue)
msk = np.uint8(msks[j])
plt.subplot(4,2,j*2+1)
plt.imshow(img)
plt.subplot(4,2,j*2+2)
plt.imshow(msk, vmin=0, vmax=5)
plt.show()
coord.request_stop()
coord.join(threads)
函式中接下來就是利用之前定義的read_and_decode()來得到一個batch的資料,此後我又讀入了均值檔案,這是因為之前做了去均值處理,如果要正常顯示圖片需要再把均值加回來。
再之後就是建立一個Tensorflow session,然後初始化物件。這些是Tensorflow基本操作,不再贅述。下面的這兩句程式碼非常重要,是讀取資料必不可少的。
coord = tf.train.Coordinator()
threads = tf.train.start_queue_runners(coord=coord)
然後是執行sess.run()拿到實際資料,之前只是相當於定義好了,並沒有得到真實數值。為了簡單起見,我在之後的迴圈裡只測試了一個batch的資料,關於tfrecord的標準使用我也建議讀者去官網的資料讀取部分看看示例。迴圈裡對資料的各種資訊進行了展示,結果如下:
從圖片的名字可以看出,資料的確是進行了shuffle的,標籤、寬度、高度、圖片本身以及對應的mask影象也全部展示出來了。
測試函式的最後,要使用以下兩句程式碼進行停止,就如同檔案需要close()一樣:
coord.request_stop()
coord.join(threads)