1. 程式人生 > >TensorFlow——二進位制資料讀取

TensorFlow——二進位制資料讀取

一、CIFAR10二進位制資料集介紹

https://www.cs.toronto.edu/~kriz/cifar.html
  • 二進位制版本資料檔案

二進位制版本包含檔案data_batch_1.bin,data_batch_2.bin,...,data_batch_5.bin以及test_batch.bin

。這些檔案中的每一個格式如下,資料中每個樣本包含了特徵值和目標值:

<1×標籤> <3072×畫素> 
... 
<1×標籤> <3072×畫素>

第一個位元組是第一個影象的標籤,它是一個0-9範圍內的數字。接下來的3072個位元組是影象畫素的值。前1024個位元組是紅色通道值,下1024個綠色,最後1024個藍色。

值以行優先順序儲存,因此前32個位元組是影象第一行的紅色通道值。 每個檔案都包含10000個這樣的3073位元組的“行”影象,但沒有任何分隔行的限制。因此每個檔案應該完全是30730000位元組長。

二、CIFAR10 二進位制資料讀取

1.分析

  • 構造檔案佇列
  • 讀取二進位制資料並進行解碼
  • 處理圖片資料形狀以及資料型別,批處理返回
  • 開啟會話執行緒執行

2.程式碼

  • 定義CIFAR類,設定圖片相關的屬性
class CifarRead(object):
    """
    二進位制檔案的讀取,tfrecords儲存讀取
    """

    def __init__(self):
        # 定義一些圖片的屬性
        self.height = 32
        self.width = 32
        self.channel = 3

        self.label_bytes = 1
        self.image_bytes = self.height * self.width * self.channel
        self.bytes = self.label_bytes + self.image_bytes
  • 實現讀取資料方法bytes_read(self, file_list)

    • 構造檔案佇列
    # 1、構造檔案佇列
    file_queue = tf.train.string_input_producer(file_list)
    
    • tf.FixedLengthRecordReader(bytes)讀取
    # 2、使用tf.FixedLengthRecordReader(bytes)讀取
    # 預設必須指定讀取一個樣本
    reader = tf.FixedLengthRecordReader(self.all_bytes)
    
    _, value = reader.read(file_queue)
    
    • 進行解碼操作
    # 3、解碼操作
    # (?, )   (3073, ) = label(1, ) + feature(3072, )
    label_image = tf.decode_raw(value, tf.uint8)
    # 為了訓練方便,一般會把特徵值和目標值分開處理
    print(label_image)
    
    • 將資料的標籤和圖片進行分割
    # 使用tf.slice進行切片
    label = tf.cast(tf.slice(label_image, [0], [self.label_bytes]), tf.int32)
    
    image = tf.slice(label_image, [self.label_bytes], [self.image_bytes])
    
    print(label, image)
    
    • 處理資料的形狀,並且進行批處理
    # 處理型別和圖片資料的形狀
    # 圖片形狀
    # reshape (3072, )----[channel, height, width]
    # transpose [channel, height, width] --->[height, width, channel]
    depth_major = tf.reshape(image, [self.channel, self.height, self.width])
    print(depth_major)
    
    image_reshape = tf.transpose(depth_major, [1, 2, 0])
    
    print(image_reshape)
    
    # 4、批處理
    image_batch, label_batch = tf.train.batch([image_reshape, label], batch_size=10, num_threads=1, capacity=10)
    

這裡的圖片形狀設定從1維的排列到3維資料的時候,涉及到NHWC與NCHW的概念:

1)NHWC與NCHW

在讀取設定圖片形狀的時候有兩種格式:

設定為 "NHWC" 時,排列順序為 [batch, height, width, channels];

設定為 "NCHW" 時,排列順序為 [batch, channels, height, width]。

其中 N 表示這批影象有幾張,H 表示影象在豎直方向有多少畫素,W 表示水平方向畫素數,C 表示通道數。

Tensorflow預設的[height, width, channel]

假設RGB三通道兩種格式的區別如下圖所示:

1 理解

假設1, 2, 3, 4-紅色 5, 6, 7, 8-綠色 9, 10, 11, 12-藍色

  • 如果通道在最低維度0[channel, height, width],RGB三顏色分成三組,在第一維度上找到三個RGB顏色
  • 如果通道在最高維度2[height, width, channel],在第三維度上找到RGB三個顏色

# 1、想要變成:[2 height, 2width,  3channel],但是輸出結果不對
In [7]: tf.reshape([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12], [2, 2, 3]).eval()
Out[7]:
array([[[ 1,  2,  3],
        [ 4,  5,  6]],

       [[ 7,  8,  9],
        [10, 11, 12]]], dtype=int32)

# 2、所以要這樣去做
In [8]: tf.reshape([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12], [3, 2, 2]).eval()
Out[8]:
array([[[ 1,  2],
        [ 3,  4]],

       [[ 5,  6],
        [ 7,  8]],

       [[ 9, 10],
        [11, 12]]], dtype=int32)
# 接著使用tf.transpose ,0,1,2代表三個維度標記
# Convert from [depth, height, width] to [height, width, depth].
# 0,1,2-----> 1, 2, 0
In [17]: tf.transpose(depth_major, [1, 2, 0]).eval()
Out[17]:
array([[[ 1,  5,  9],
        [ 2,  6, 10]],

       [[ 3,  7, 11],
        [ 4,  8, 12]]], dtype=int32)

2 轉換API

  • tf.transpose(a, perm=None)
    • Transposes a. Permutes the dimensions according to perm.
      • 修改維度的位置
    • a:資料
    • perm:形狀的維度值下標列表

2)處理圖片的形狀

所以在讀取資料處理形狀的時候

  • 1 image (3072, ) —>tf.reshape(image, [])裡面的shape是[channel, height, width], 所以得先從[depth height width] to [depth, height, width]
  • 2 然後使用tf.transpose, 將剛才的資料[depth, height, width],變成Tensorflow預設的[height, width, channel]

3 完整程式碼

import tensorflow as tf
import os


class Cifar(object):

    # 初始化
    def __init__(self):
        # 影象的大小
        self.height = 32
        self.width = 32
        self.channels = 3

        # 影象的位元組數
        self.label_bytes = 1
        self.image_bytes = self.height * self.width * self.channels
        self.bytes = self.label_bytes + self.image_bytes

    def read_and_decode(self, file_list):
        # 讀取二進位制檔案
        # print("read_and_decode:\n", file_list)
        # 1、構造檔名佇列
        file_queue = tf.train.string_input_producer(file_list)

        # 2、構造二進位制檔案閱讀器
        reader = tf.FixedLengthRecordReader(self.bytes)
        key, value = reader.read(file_queue)

        print("key:\n", key)
        print("value:\n", value)
        # 3、解碼
        decoded = tf.decode_raw(value, tf.uint8)
        print("decoded:\n", decoded)

        # 4、基本的資料處理
        # 切片處理,把標籤值和特徵值分開
        label = tf.slice(decoded, [0], [self.label_bytes])
        image = tf.slice(decoded, [self.label_bytes], [self.image_bytes])

        print("label:\n", label)
        print("image:\n", image)
        # 改變影象的形狀
        image_reshaped = tf.reshape(image, [self.channels, self.height, self.width])
        # 轉置
        image_transposed = tf.transpose(image_reshaped, [1, 2, 0])
        print("image_transposed:\n", image_transposed)

        # 型別轉換
        label_cast = tf.cast(label, tf.float32)
        image_cast = tf.cast(image_transposed, tf.float32)

        # 5、批處理
        label_batch, image_batch = tf.train.batch([label_cast, image_cast], batch_size=10, num_threads=1, capacity=10)
        return label_batch, image_batch


if __name__ == "__main__":
    # 構造檔名列表
    file_name = os.listdir("./cifar-10-batches-bin")
    print("file_name:\n", file_name)
    file_list = [os.path.join("./cifar-10-batches-bin/", file) for file in file_name if file[-3:] == "bin"]
    print("file_list:\n", file_list)

    # 呼叫讀取二進位制檔案的方法
    cf = Cifar()
    label, image = cf.read_and_decode(file_list)

    # 開啟會話
    with tf.Session() as sess:
        # 建立執行緒協調器
        coord = tf.train.Coordinator()
        # 建立執行緒
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)

        # 列印結果
        print("label:\n", sess.run(label))
        print("image:\n", sess.run(image))

        # 回收資源
        coord.request_stop()
        coord.join(threads)