1. 程式人生 > >TensorFlow入門(十-I)tfrecord 固定維度資料讀寫

TensorFlow入門(十-I)tfrecord 固定維度資料讀寫

關於 tfrecord 的使用,分別介紹 tfrecord 進行三種不同型別資料的處理方法。
- 維度固定的 numpy 矩陣
- 可變長度的 序列 資料
- 圖片資料

在 tf1.3 及以後版本中,推出了新的 Dataset API, 之前趕實驗還沒研究,可能以後都不太會用下面的方式寫了。這些程式碼都是之前寫好的,因為註釋中都寫得比較清楚了,所以直接上程式碼。

tfrecord_1_numpy_writer.py

# -*- coding:utf-8 -*- 

import tensorflow as tf
import numpy as np
from tqdm import
tqdm '''tfrecord 寫入資料. 將固定shape的矩陣寫入 tfrecord 檔案。這種形式的資料寫入 tfrecord 是最簡單的。 refer: http://blog.csdn.net/qq_16949707/article/details/53483493 ''' # **1.建立檔案,可以建立多個檔案,在讀取的時候只需要提供所有檔名列表就行了 writer1 = tf.python_io.TFRecordWriter('../data/test1.tfrecord') writer2 = tf.python_io.TFRecordWriter('../data/test2.tfrecord'
) """ 有一點需要注意的就是我們需要把矩陣轉為陣列形式才能寫入 就是需要經過下面的 reshape 操作 在讀取的時候再 reshape 回原始的 shape 就可以了 """ X = np.arange(0, 100).reshape([50, -1]).astype(np.float32) y = np.arange(50) for i in tqdm(xrange(len(X))): # **2.對於每個樣本 if i >= len(y) / 2: writer = writer2 else: writer = writer1 X_sample = X[i].tolist() y_sample = y[i] # **3.定義資料型別,按照這裡固定的形式寫,有float_list(好像只有32位), int64_list, bytes_list.
example = tf.train.Example( features=tf.train.Features( feature={'X': tf.train.Feature(float_list=tf.train.FloatList(value=X_sample)), 'y': tf.train.Feature(int64_list=tf.train.Int64List(value=[y_sample]))})) # **4.序列化資料並寫入檔案中 serialized = example.SerializeToString() writer.write(serialized) print('Finished.') writer1.close() writer2.close()

tfrecord_1_numpy_reader.py

# -*- coding:utf-8 -*- 

import tensorflow as tf

'''read data
從 tfrecord 檔案中讀取資料,對應資料的格式為固定shape的資料。
'''

# **1.把所有的 tfrecord 檔名列表寫入佇列中
filename_queue = tf.train.string_input_producer(['../data/test1.tfrecord', '../data/test2.tfrecord'], num_epochs=None,
                                                shuffle=True)
# **2.建立一個讀取器
reader = tf.TFRecordReader()
_, serialized_example = reader.read(filename_queue)
# **3.根據你寫入的格式對應說明讀取的格式
features = tf.parse_single_example(serialized_example,
                                   features={
                                       'X': tf.FixedLenFeature([2], tf.float32),  # 注意如果不是標量,需要說明陣列長度
                                       'y': tf.FixedLenFeature([], tf.int64)}     # 而標量就不用說明
                                   )
X_out = features['X']
y_out = features['y']

print(X_out)
print(y_out)
# **4.通過 tf.train.shuffle_batch 或者 tf.train.batch 函式讀取資料
"""
在shuffle_batch 函式中,有幾個引數的作用如下:
capacity: 佇列的容量,容量越大的話,shuffle 得就更加均勻,但是佔用記憶體也會更多
num_threads: 讀取程序數,程序越多,讀取速度相對會快些,根據個人配置決定
min_after_dequeue: 保證佇列中最少的資料量。
   假設我們設定了佇列的容量C,在我們取走部分資料m以後,佇列中只剩下了 (C-m) 個數據。然後佇列會不斷補充資料進來,
   如果後勤供應(CPU效能,執行緒數量)補充速度慢的話,那麼下一次取資料的時候,可能才補充了一點點,如果補充完後的資料個數少於
   min_after_dequeue 的話,不能取走資料,得繼續等它補充超過 min_after_dequeue 個樣本以後才讓取走資料。
   這樣做保證了佇列中混著足夠多的資料,從而才能保證 shuffle 取值更加隨機。
   但是,min_after_dequeue 不能設定太大,否則補充時間很長,讀取速度會很慢。
"""
X_batch, y_batch = tf.train.shuffle_batch([X_out, y_out], batch_size=2,
                                          capacity=200, min_after_dequeue=100, num_threads=2)
sess = tf.Session()
init = tf.global_variables_initializer()
sess.run(init)

# **5.啟動佇列進行資料讀取
# 下面的 coord 是個執行緒協調器,把啟動佇列的時候加上執行緒協調器。
# 這樣,在資料讀取完畢以後,呼叫協調器把執行緒全部都關了。
coord = tf.train.Coordinator()
threads = tf.train.start_queue_runners(sess=sess, coord=coord)
y_outputs = list()
for i in xrange(5):
    _X_batch, _y_batch = sess.run([X_batch, y_batch])
    print('** batch %d' % i)
    print('_X_batch:', _X_batch)
    print('_y_batch:', _y_batch)
    y_outputs.extend(_y_batch.tolist())
print(y_outputs)

# **6.最後記得把佇列關掉
coord.request_stop()
coord.join(threads)