1. 程式人生 > >TensorFlow——TFRecords檔案

TensorFlow——TFRecords檔案

一、什麼是TFRecords檔案

TFRecords其實是一種二進位制檔案,雖然它不如其他格式好理解,但是它能更好的利用記憶體,更方便複製和移動,並且不需要單獨的標籤檔案

使用步驟:

1)獲取資料

2)將資料填入到Example協議記憶體塊(protocol buffer)

3)將協議記憶體塊序列化為字串, 並且通過tf.python_io.TFRecordWriter 寫入到TFRecords檔案。

  • 檔案格式 *.tfrecords

二、Example結構解析

  • tf.train.Example 協議記憶體塊(protocol buffer)(協議記憶體塊包含了欄位 Features
    )
  • Features包含了一個Feature欄位
  • Feature中包含要寫入的資料、並指明資料型別。
    • 這是一個樣本的結構,批資料需要迴圈存入這樣的結構
 example = tf.train.Example(features=tf.train.Features(feature={
                "image": tf.train.Feature(bytes_list=tf.train.BytesList(value=[image])),
                "label": tf.train.Feature(int64_list=tf.train.Int64List(value=[label])),
            }))
  • tf.train.Example(features=None)
    • 寫入tfrecords檔案
    • features:tf.train.Features型別的特徵例項
    • return:example格式協議塊
  • tf.train.Features(feature=None)
    • 構建每個樣本的資訊鍵值對
    • feature:字典資料,key為要儲存的名字
    • value為tf.train.Feature例項
    • return:Features型別
  • tf.train.Feature(options)
    • options:例如
      • bytes_list=tf.train. BytesList(value=[Bytes])
      • int64_list=tf.train. Int64List(value=[Value])
    • 支援存入的型別如下
    • tf.train.Int64List(value=[Value])
    • tf.train.BytesList(value=[Bytes])
    • tf.train.FloatList(value=[value])

這種結構很好地實現了資料和標籤(訓練的類別標籤)或者其他屬性資料儲存在同一個檔案中

三、案例:CIFAR10資料存入TFRecords檔案

1.分析

  • 構造儲存例項,tf.python_io.TFRecordWriter(path)

    • 寫入tfrecords檔案
    • path:TFRecords檔案的路徑
    • return:寫檔案
      • method方法
        • write(record):向檔案中寫入一個example
        • close():關閉檔案寫入器
  • 迴圈將資料填入到Example協議記憶體塊(protocol buffer)

2.程式碼

對於每一個圖片樣本資料,都需要寫入到example當中,所以這裡需要取出每一樣本進行構造存入

def write_to_tfrecords(self, image_batch, label_batch):
    """
        將資料存進tfrecords,方便管理每個樣本的屬性
        :param image_batch: 特徵值
        :param label_batch: 目標值
        :return: None
        """
    # 1、構造tfrecords的儲存例項
    writer = tf.python_io.TFRecordWriter(FLAGS.tfrecords_dir)

    # 2、迴圈將每個樣本寫入到檔案當中
    for i in range(10):

        # 一個樣本一個樣本的處理寫入
        # 準備特徵值,特徵值必須是bytes型別 呼叫tostring()函式
        # [10, 32, 32, 3] ,在這裡避免tensorflow的坑,取出來的不是真正的值,而是型別,所以要執行結果才能存入
        # 出現了eval,那就要在會話當中去執行該行數
        image = image_batch[i].eval().tostring()

        # 準備目標值,目標值是一個Int型別
        # eval()-->[6]--->6
        label = label_batch[i].eval()[0]

        # 繫結每個樣本的屬性
        example = tf.train.Example(features=tf.train.Features(feature={
            "image": tf.train.Feature(bytes_list=tf.train.BytesList(value=[image])),
            "label": tf.train.Feature(int64_list=tf.train.Int64List(value=[label])),
        }))

        # 寫入每個樣本的example
        writer.write(example.SerializeToString())

        # 檔案需要關閉
        writer.close()
        return None

    # 開啟會話列印內容
    with tf.Session() as sess:
        # 建立執行緒協調器
        coord = tf.train.Coordinator()

        # 開啟子執行緒去讀取資料
        # 返回子執行緒例項
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)

        # 獲取樣本資料去訓練
        print(sess.run([image_batch, label_batch]))

        # 存入資料
        cr.write_to_tfrecords(image_batch, label_batch )

        # 關閉子執行緒,回收
        coord.request_stop()

        coord.join(threads)

四、讀取TFRecords檔案API

讀取這種檔案整個過程與其他檔案一樣,只不過需要有個解析Example的步驟。從TFRecords檔案中讀取資料, 可以使用tf.TFRecordReadertf.parse_single_example解析器。這個操作可以將Example協議記憶體塊(protocol buffer)解析為張量。

# 多瞭解析example的一個步驟
feature = tf.parse_single_example(values, features={
    "image": tf.FixedLenFeature([], tf.string),
    "label": tf.FixedLenFeature([], tf.int64)
})
  • tf.parse_single_example(serialized, features=None, name=None)

    • 解析一個單一的Example原型
    • serialized:標量字串Tensor,一個序列化的Example
    • features:dict字典資料,鍵為讀取的名字,值為FixedLenFeature
    • return:一個鍵值對組成的字典,鍵為讀取的名字
  • tf.FixedLenFeature(shape, dtype)

    • shape:輸入資料的形狀,一般不指定,為空列表
    • dtype:輸入資料型別,與儲存進檔案的型別要一致
    • 型別只能是float32, int64, string

五、案例:讀取CIFAR的TFRecords檔案

1.分析

  • 使用tf.train.string_input_producer構造檔案佇列
  • tf.TFRecordReader 讀取TFRecords資料並進行解析
    • tf.parse_single_example進行解析
  • tf.decode_raw解碼
    • 型別是bytes型別需要解碼
    • 其他型別不需要
  • 處理圖片資料形狀以及資料型別,加入批處理佇列
  • 開啟會話執行緒執行

2.程式碼

def read_tfrecords(self):
    """
        讀取tfrecords的資料
        :return: None
        """
    # 1、構造檔案佇列
    file_queue = tf.train.string_input_producer(["./tmp/cifar.tfrecords"])

    # 2、構造tfrecords讀取器,讀取佇列
    reader = tf.TFRecordReader()

    # 預設也是隻讀取一個樣本
    key, values = reader.read(file_queue)

    # tfrecords
    # 多瞭解析example的一個步驟
    feature = tf.parse_single_example(values, features={
        "image": tf.FixedLenFeature([], tf.string),
        "label": tf.FixedLenFeature([], tf.int64)
    })

    # 取出feature裡面的特徵值和目標值
    # 通過鍵值對獲取
    image = feature["image"]

    label = feature["label"]

    # 3、解碼操作
    # 對於image是一個bytes型別,所以需要decode_raw去解碼成uint8張量
    # 對於Label:本身是一個int型別,不需要去解碼
    image = tf.decode_raw(image, tf.uint8)

    print(image, label)

    # # 從原來的[32,32,3]的bytes形式直接變成[32,32,3]
    # 不存在一開始我們的讀取RGB的問題
    # 處理image的形狀和型別
    image_reshape = tf.reshape(image, [self.height, self.width, self.channel])

    # 處理label的形狀和型別
    label_cast = tf.cast(label, tf.int32)

    print(image_reshape, label_cast)

    # 4、批處理操作
    image_batch, label_batch = tf.train.batch([image_reshape, label_cast], batch_size=10, num_threads=1, capacity=10)

    print(image_batch, label_batch)
    return image_batch, label_batch

# 從tfrecords檔案讀取資料
image_batch, label_batch = cr.read_tfrecords()

# 開啟會話列印內容
with tf.Session() as sess:
    # 建立執行緒協調器
    coord = tf.train.Coordinator()

完整程式碼:

import tensorflow as tf
import os


class Cifar(object):

    # 初始化
    def __init__(self):
        # 影象的大小
        self.height = 32
        self.width = 32
        self.channels = 3

        # 影象的位元組數
        self.label_bytes = 1
        self.image_bytes = self.height * self.width * self.channels
        self.bytes = self.label_bytes + self.image_bytes

    def read_and_decode(self, file_list):
        # 讀取二進位制檔案
        # print("read_and_decode:\n", file_list)
        # 1、構造檔名佇列
        file_queue = tf.train.string_input_producer(file_list)

        # 2、構造二進位制檔案閱讀器
        reader = tf.FixedLengthRecordReader(self.bytes)
        key, value = reader.read(file_queue)

        print("key:\n", key)
        print("value:\n", value)
        # 3、解碼
        decoded = tf.decode_raw(value, tf.uint8)
        print("decoded:\n", decoded)

        # 4、基本的資料處理
        # 切片處理,把標籤值和特徵值分開
        label = tf.slice(decoded, [0], [self.label_bytes])
        image = tf.slice(decoded, [self.label_bytes], [self.image_bytes])

        print("label:\n", label)
        print("image:\n", image)
        # 改變影象的形狀
        image_reshaped = tf.reshape(image, [self.channels, self.height, self.width])
        # 轉置
        image_transposed = tf.transpose(image_reshaped, [1, 2, 0])
        print("image_transposed:\n", image_transposed)

        # 型別轉換
        label_cast = tf.cast(label, tf.float32)
        image_cast = tf.cast(image_transposed, tf.float32)

        # 5、批處理
        label_batch, image_batch = tf.train.batch([label_cast, image_cast], batch_size=10, num_threads=1, capacity=10)
        return label_batch, image_batch


    def write_to_tfrecords(self, label_batch, image_batch):
        # 進行型別轉換,轉成tf.uint8
        # 為了節省空間
        label_batch = tf.cast(label_batch, tf.uint8)
        image_batch = tf.cast(image_batch, tf.uint8)
        # 構造tfrecords儲存器
        with tf.python_io.TFRecordWriter("./cifar.tfrecords") as writer:
            for i in range(10):
                label = label_batch[i].eval()[0]
                image = image_batch[i].eval().tostring()
                print("tfrecords_label:\n", label)
                print("tfrecords_image:\n", image, type(image))
                # 構造example協議塊
                example = tf.train.Example(features=tf.train.Features(feature={
                    "label": tf.train.Feature(int64_list=tf.train. Int64List(value=[label])),
                    "image": tf.train.Feature(bytes_list=tf.train. BytesList(value=[image]))
                }))
                # 寫入序列化後的example
                writer.write(example.SerializeToString())


    def read_tfrecords(self):
        # 讀取tfrecords檔案
        # 1、構造檔名佇列
        file_queue = tf.train.string_input_producer(["cifar.tfrecords"])

        # 2、構造tfrecords閱讀器
        reader = tf.TFRecordReader()
        key, value = reader.read(file_queue)

        # 3、解析example協議塊
        example = tf.parse_single_example(value, features={
            "label": tf.FixedLenFeature(shape=[], dtype=tf.int64),
            "image": tf.FixedLenFeature(shape=[], dtype=tf.string)
        })
        label = example["label"]
        image = example["image"]
        print("read_tfrecords_label:\n", label)
        print("read_tfrecords_image:\n", image)

        # 4、解碼
        image_decoded = tf.decode_raw(image, tf.uint8)
        print("read_tfrecords_image_decoded:\n", image_decoded)

        # 5、基本的資料處理
        # 調整影象形狀
        image_reshaped = tf.reshape(image_decoded, [self.height, self.width, self.channels])
        # 轉換型別
        image_cast = tf.cast(image_reshaped, tf.float32)
        label_cast = tf.cast(label, tf.float32)
        print("read_records_image_cast:\n", image_cast)
        print("read_records_label_cast:\n", label_cast)

        # 6、批處理
        label_batch, image_batch = tf.train.batch([label_cast, image_cast], batch_size=10, num_threads=1, capacity=10)

        return label_batch, image_batch

if __name__ == "__main__":
    # 構造檔名列表
    file_name = os.listdir("./cifar-10-batches-bin")
    print("file_name:\n", file_name)
    file_list = [os.path.join("./cifar-10-batches-bin/", file) for file in file_name if file[-3:] == "bin"]
    print("file_list:\n", file_list)

    # 呼叫讀取二進位制檔案的方法
    cf = Cifar()
    # label, image = cf.read_and_decode(file_list)
    label, image = cf.read_tfrecords()

    # 開啟會話
    with tf.Session() as sess:
        # 建立執行緒協調器
        coord = tf.train.Coordinator()
        # 建立執行緒
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)

        # 列印結果
        print("label:\n", sess.run(label))
        print("image:\n", sess.run(image))

        # cf.write_to_tfrecords(label, image)
        # 回收資源
        coord.request_stop()
        coord.join(threads)