TensorFlow——TFRecords檔案
阿新 • • 發佈:2018-12-10
一、什麼是TFRecords檔案
TFRecords其實是一種二進位制檔案,雖然它不如其他格式好理解,但是它能更好的利用記憶體,更方便複製和移動,並且不需要單獨的標籤檔案。
使用步驟:
1)獲取資料
2)將資料填入到Example
協議記憶體塊(protocol buffer)
3)將協議記憶體塊序列化為字串, 並且通過tf.python_io.TFRecordWriter
寫入到TFRecords檔案。
- 檔案格式 *.tfrecords
二、Example結構解析
tf.train.Example
協議記憶體塊(protocol buffer)(協議記憶體塊包含了欄位Features
Features
包含了一個Feature
欄位Feature
中包含要寫入的資料、並指明資料型別。- 這是一個樣本的結構,批資料需要迴圈存入這樣的結構
example = tf.train.Example(features=tf.train.Features(feature={ "image": tf.train.Feature(bytes_list=tf.train.BytesList(value=[image])), "label": tf.train.Feature(int64_list=tf.train.Int64List(value=[label])), }))
- tf.train.Example(features=None)
- 寫入tfrecords檔案
- features:tf.train.Features型別的特徵例項
- return:example格式協議塊
- tf.train.Features(feature=None)
- 構建每個樣本的資訊鍵值對
- feature:字典資料,key為要儲存的名字
- value為tf.train.Feature例項
- return:Features型別
- tf.train.Feature(options)
- options:例如
- bytes_list=tf.train. BytesList(value=[Bytes])
- int64_list=tf.train. Int64List(value=[Value])
- 支援存入的型別如下
- tf.train.Int64List(value=[Value])
- tf.train.BytesList(value=[Bytes])
- tf.train.FloatList(value=[value])
- options:例如
這種結構很好地實現了資料和標籤(訓練的類別標籤)或者其他屬性資料儲存在同一個檔案中
三、案例:CIFAR10資料存入TFRecords檔案
1.分析
-
構造儲存例項,tf.python_io.TFRecordWriter(path)
- 寫入tfrecords檔案
- path:TFRecords檔案的路徑
- return:寫檔案
- method方法
- write(record):向檔案中寫入一個example
- close():關閉檔案寫入器
- method方法
-
迴圈將資料填入到
Example
協議記憶體塊(protocol buffer)
2.程式碼
對於每一個圖片樣本資料,都需要寫入到example當中,所以這裡需要取出每一樣本進行構造存入
def write_to_tfrecords(self, image_batch, label_batch):
"""
將資料存進tfrecords,方便管理每個樣本的屬性
:param image_batch: 特徵值
:param label_batch: 目標值
:return: None
"""
# 1、構造tfrecords的儲存例項
writer = tf.python_io.TFRecordWriter(FLAGS.tfrecords_dir)
# 2、迴圈將每個樣本寫入到檔案當中
for i in range(10):
# 一個樣本一個樣本的處理寫入
# 準備特徵值,特徵值必須是bytes型別 呼叫tostring()函式
# [10, 32, 32, 3] ,在這裡避免tensorflow的坑,取出來的不是真正的值,而是型別,所以要執行結果才能存入
# 出現了eval,那就要在會話當中去執行該行數
image = image_batch[i].eval().tostring()
# 準備目標值,目標值是一個Int型別
# eval()-->[6]--->6
label = label_batch[i].eval()[0]
# 繫結每個樣本的屬性
example = tf.train.Example(features=tf.train.Features(feature={
"image": tf.train.Feature(bytes_list=tf.train.BytesList(value=[image])),
"label": tf.train.Feature(int64_list=tf.train.Int64List(value=[label])),
}))
# 寫入每個樣本的example
writer.write(example.SerializeToString())
# 檔案需要關閉
writer.close()
return None
# 開啟會話列印內容
with tf.Session() as sess:
# 建立執行緒協調器
coord = tf.train.Coordinator()
# 開啟子執行緒去讀取資料
# 返回子執行緒例項
threads = tf.train.start_queue_runners(sess=sess, coord=coord)
# 獲取樣本資料去訓練
print(sess.run([image_batch, label_batch]))
# 存入資料
cr.write_to_tfrecords(image_batch, label_batch )
# 關閉子執行緒,回收
coord.request_stop()
coord.join(threads)
四、讀取TFRecords檔案API
讀取這種檔案整個過程與其他檔案一樣,只不過需要有個解析Example的步驟。從TFRecords檔案中讀取資料, 可以使用tf.TFRecordReader
的tf.parse_single_example
解析器。這個操作可以將Example
協議記憶體塊(protocol buffer)解析為張量。
# 多瞭解析example的一個步驟
feature = tf.parse_single_example(values, features={
"image": tf.FixedLenFeature([], tf.string),
"label": tf.FixedLenFeature([], tf.int64)
})
-
tf.parse_single_example(serialized, features=None, name=None)
- 解析一個單一的Example原型
- serialized:標量字串Tensor,一個序列化的Example
- features:dict字典資料,鍵為讀取的名字,值為FixedLenFeature
- return:一個鍵值對組成的字典,鍵為讀取的名字
-
tf.FixedLenFeature(shape, dtype)
- shape:輸入資料的形狀,一般不指定,為空列表
- dtype:輸入資料型別,與儲存進檔案的型別要一致
- 型別只能是float32, int64, string
五、案例:讀取CIFAR的TFRecords檔案
1.分析
- 使用tf.train.string_input_producer構造檔案佇列
- tf.TFRecordReader 讀取TFRecords資料並進行解析
- tf.parse_single_example進行解析
- tf.decode_raw解碼
- 型別是bytes型別需要解碼
- 其他型別不需要
- 處理圖片資料形狀以及資料型別,加入批處理佇列
- 開啟會話執行緒執行
2.程式碼
def read_tfrecords(self):
"""
讀取tfrecords的資料
:return: None
"""
# 1、構造檔案佇列
file_queue = tf.train.string_input_producer(["./tmp/cifar.tfrecords"])
# 2、構造tfrecords讀取器,讀取佇列
reader = tf.TFRecordReader()
# 預設也是隻讀取一個樣本
key, values = reader.read(file_queue)
# tfrecords
# 多瞭解析example的一個步驟
feature = tf.parse_single_example(values, features={
"image": tf.FixedLenFeature([], tf.string),
"label": tf.FixedLenFeature([], tf.int64)
})
# 取出feature裡面的特徵值和目標值
# 通過鍵值對獲取
image = feature["image"]
label = feature["label"]
# 3、解碼操作
# 對於image是一個bytes型別,所以需要decode_raw去解碼成uint8張量
# 對於Label:本身是一個int型別,不需要去解碼
image = tf.decode_raw(image, tf.uint8)
print(image, label)
# # 從原來的[32,32,3]的bytes形式直接變成[32,32,3]
# 不存在一開始我們的讀取RGB的問題
# 處理image的形狀和型別
image_reshape = tf.reshape(image, [self.height, self.width, self.channel])
# 處理label的形狀和型別
label_cast = tf.cast(label, tf.int32)
print(image_reshape, label_cast)
# 4、批處理操作
image_batch, label_batch = tf.train.batch([image_reshape, label_cast], batch_size=10, num_threads=1, capacity=10)
print(image_batch, label_batch)
return image_batch, label_batch
# 從tfrecords檔案讀取資料
image_batch, label_batch = cr.read_tfrecords()
# 開啟會話列印內容
with tf.Session() as sess:
# 建立執行緒協調器
coord = tf.train.Coordinator()
完整程式碼:
import tensorflow as tf
import os
class Cifar(object):
# 初始化
def __init__(self):
# 影象的大小
self.height = 32
self.width = 32
self.channels = 3
# 影象的位元組數
self.label_bytes = 1
self.image_bytes = self.height * self.width * self.channels
self.bytes = self.label_bytes + self.image_bytes
def read_and_decode(self, file_list):
# 讀取二進位制檔案
# print("read_and_decode:\n", file_list)
# 1、構造檔名佇列
file_queue = tf.train.string_input_producer(file_list)
# 2、構造二進位制檔案閱讀器
reader = tf.FixedLengthRecordReader(self.bytes)
key, value = reader.read(file_queue)
print("key:\n", key)
print("value:\n", value)
# 3、解碼
decoded = tf.decode_raw(value, tf.uint8)
print("decoded:\n", decoded)
# 4、基本的資料處理
# 切片處理,把標籤值和特徵值分開
label = tf.slice(decoded, [0], [self.label_bytes])
image = tf.slice(decoded, [self.label_bytes], [self.image_bytes])
print("label:\n", label)
print("image:\n", image)
# 改變影象的形狀
image_reshaped = tf.reshape(image, [self.channels, self.height, self.width])
# 轉置
image_transposed = tf.transpose(image_reshaped, [1, 2, 0])
print("image_transposed:\n", image_transposed)
# 型別轉換
label_cast = tf.cast(label, tf.float32)
image_cast = tf.cast(image_transposed, tf.float32)
# 5、批處理
label_batch, image_batch = tf.train.batch([label_cast, image_cast], batch_size=10, num_threads=1, capacity=10)
return label_batch, image_batch
def write_to_tfrecords(self, label_batch, image_batch):
# 進行型別轉換,轉成tf.uint8
# 為了節省空間
label_batch = tf.cast(label_batch, tf.uint8)
image_batch = tf.cast(image_batch, tf.uint8)
# 構造tfrecords儲存器
with tf.python_io.TFRecordWriter("./cifar.tfrecords") as writer:
for i in range(10):
label = label_batch[i].eval()[0]
image = image_batch[i].eval().tostring()
print("tfrecords_label:\n", label)
print("tfrecords_image:\n", image, type(image))
# 構造example協議塊
example = tf.train.Example(features=tf.train.Features(feature={
"label": tf.train.Feature(int64_list=tf.train. Int64List(value=[label])),
"image": tf.train.Feature(bytes_list=tf.train. BytesList(value=[image]))
}))
# 寫入序列化後的example
writer.write(example.SerializeToString())
def read_tfrecords(self):
# 讀取tfrecords檔案
# 1、構造檔名佇列
file_queue = tf.train.string_input_producer(["cifar.tfrecords"])
# 2、構造tfrecords閱讀器
reader = tf.TFRecordReader()
key, value = reader.read(file_queue)
# 3、解析example協議塊
example = tf.parse_single_example(value, features={
"label": tf.FixedLenFeature(shape=[], dtype=tf.int64),
"image": tf.FixedLenFeature(shape=[], dtype=tf.string)
})
label = example["label"]
image = example["image"]
print("read_tfrecords_label:\n", label)
print("read_tfrecords_image:\n", image)
# 4、解碼
image_decoded = tf.decode_raw(image, tf.uint8)
print("read_tfrecords_image_decoded:\n", image_decoded)
# 5、基本的資料處理
# 調整影象形狀
image_reshaped = tf.reshape(image_decoded, [self.height, self.width, self.channels])
# 轉換型別
image_cast = tf.cast(image_reshaped, tf.float32)
label_cast = tf.cast(label, tf.float32)
print("read_records_image_cast:\n", image_cast)
print("read_records_label_cast:\n", label_cast)
# 6、批處理
label_batch, image_batch = tf.train.batch([label_cast, image_cast], batch_size=10, num_threads=1, capacity=10)
return label_batch, image_batch
if __name__ == "__main__":
# 構造檔名列表
file_name = os.listdir("./cifar-10-batches-bin")
print("file_name:\n", file_name)
file_list = [os.path.join("./cifar-10-batches-bin/", file) for file in file_name if file[-3:] == "bin"]
print("file_list:\n", file_list)
# 呼叫讀取二進位制檔案的方法
cf = Cifar()
# label, image = cf.read_and_decode(file_list)
label, image = cf.read_tfrecords()
# 開啟會話
with tf.Session() as sess:
# 建立執行緒協調器
coord = tf.train.Coordinator()
# 建立執行緒
threads = tf.train.start_queue_runners(sess=sess, coord=coord)
# 列印結果
print("label:\n", sess.run(label))
print("image:\n", sess.run(image))
# cf.write_to_tfrecords(label, image)
# 回收資源
coord.request_stop()
coord.join(threads)