TensorFlow入門(十-I)tfrecord 固定維度資料讀寫
阿新 • • 發佈:2019-02-13
關於 tfrecord 的使用,分別介紹 tfrecord 進行三種不同型別資料的處理方法。
- 維度固定的 numpy 矩陣
- 可變長度的 序列 資料
- 圖片資料
在 tf1.3 及以後版本中,推出了新的 Dataset API, 之前趕實驗還沒研究,可能以後都不太會用下面的方式寫了。這些程式碼都是之前寫好的,因為註釋中都寫得比較清楚了,所以直接上程式碼。
tfrecord_1_numpy_writer.py
# -*- coding:utf-8 -*-
import tensorflow as tf
import numpy as np
from tqdm import tqdm
'''tfrecord 寫入資料.
將固定shape的矩陣寫入 tfrecord 檔案。這種形式的資料寫入 tfrecord 是最簡單的。
refer: http://blog.csdn.net/qq_16949707/article/details/53483493
'''
# **1.建立檔案,可以建立多個檔案,在讀取的時候只需要提供所有檔名列表就行了
writer1 = tf.python_io.TFRecordWriter('../data/test1.tfrecord')
writer2 = tf.python_io.TFRecordWriter('../data/test2.tfrecord' )
"""
有一點需要注意的就是我們需要把矩陣轉為陣列形式才能寫入
就是需要經過下面的 reshape 操作
在讀取的時候再 reshape 回原始的 shape 就可以了
"""
X = np.arange(0, 100).reshape([50, -1]).astype(np.float32)
y = np.arange(50)
for i in tqdm(xrange(len(X))): # **2.對於每個樣本
if i >= len(y) / 2:
writer = writer2
else:
writer = writer1
X_sample = X[i].tolist()
y_sample = y[i]
# **3.定義資料型別,按照這裡固定的形式寫,有float_list(好像只有32位), int64_list, bytes_list.
example = tf.train.Example(
features=tf.train.Features(
feature={'X': tf.train.Feature(float_list=tf.train.FloatList(value=X_sample)),
'y': tf.train.Feature(int64_list=tf.train.Int64List(value=[y_sample]))}))
# **4.序列化資料並寫入檔案中
serialized = example.SerializeToString()
writer.write(serialized)
print('Finished.')
writer1.close()
writer2.close()
tfrecord_1_numpy_reader.py
# -*- coding:utf-8 -*-
import tensorflow as tf
'''read data
從 tfrecord 檔案中讀取資料,對應資料的格式為固定shape的資料。
'''
# **1.把所有的 tfrecord 檔名列表寫入佇列中
filename_queue = tf.train.string_input_producer(['../data/test1.tfrecord', '../data/test2.tfrecord'], num_epochs=None,
shuffle=True)
# **2.建立一個讀取器
reader = tf.TFRecordReader()
_, serialized_example = reader.read(filename_queue)
# **3.根據你寫入的格式對應說明讀取的格式
features = tf.parse_single_example(serialized_example,
features={
'X': tf.FixedLenFeature([2], tf.float32), # 注意如果不是標量,需要說明陣列長度
'y': tf.FixedLenFeature([], tf.int64)} # 而標量就不用說明
)
X_out = features['X']
y_out = features['y']
print(X_out)
print(y_out)
# **4.通過 tf.train.shuffle_batch 或者 tf.train.batch 函式讀取資料
"""
在shuffle_batch 函式中,有幾個引數的作用如下:
capacity: 佇列的容量,容量越大的話,shuffle 得就更加均勻,但是佔用記憶體也會更多
num_threads: 讀取程序數,程序越多,讀取速度相對會快些,根據個人配置決定
min_after_dequeue: 保證佇列中最少的資料量。
假設我們設定了佇列的容量C,在我們取走部分資料m以後,佇列中只剩下了 (C-m) 個數據。然後佇列會不斷補充資料進來,
如果後勤供應(CPU效能,執行緒數量)補充速度慢的話,那麼下一次取資料的時候,可能才補充了一點點,如果補充完後的資料個數少於
min_after_dequeue 的話,不能取走資料,得繼續等它補充超過 min_after_dequeue 個樣本以後才讓取走資料。
這樣做保證了佇列中混著足夠多的資料,從而才能保證 shuffle 取值更加隨機。
但是,min_after_dequeue 不能設定太大,否則補充時間很長,讀取速度會很慢。
"""
X_batch, y_batch = tf.train.shuffle_batch([X_out, y_out], batch_size=2,
capacity=200, min_after_dequeue=100, num_threads=2)
sess = tf.Session()
init = tf.global_variables_initializer()
sess.run(init)
# **5.啟動佇列進行資料讀取
# 下面的 coord 是個執行緒協調器,把啟動佇列的時候加上執行緒協調器。
# 這樣,在資料讀取完畢以後,呼叫協調器把執行緒全部都關了。
coord = tf.train.Coordinator()
threads = tf.train.start_queue_runners(sess=sess, coord=coord)
y_outputs = list()
for i in xrange(5):
_X_batch, _y_batch = sess.run([X_batch, y_batch])
print('** batch %d' % i)
print('_X_batch:', _X_batch)
print('_y_batch:', _y_batch)
y_outputs.extend(_y_batch.tolist())
print(y_outputs)
# **6.最後記得把佇列關掉
coord.request_stop()
coord.join(threads)