TensorFlow——二進位制資料讀取
阿新 • • 發佈:2018-12-10
一、CIFAR10二進位制資料集介紹
https://www.cs.toronto.edu/~kriz/cifar.html
- 二進位制版本資料檔案
二進位制版本包含檔案data_batch_1.bin,data_batch_2.bin,...,data_batch_5.bin以及test_batch.bin
。這些檔案中的每一個格式如下,資料中每個樣本包含了特徵值和目標值:
<1×標籤> <3072×畫素>
...
<1×標籤> <3072×畫素>
第一個位元組是第一個影象的標籤,它是一個0-9範圍內的數字。接下來的3072個位元組是影象畫素的值。前1024個位元組是紅色通道值,下1024個綠色,最後1024個藍色。
二、CIFAR10 二進位制資料讀取
1.分析
- 構造檔案佇列
- 讀取二進位制資料並進行解碼
- 處理圖片資料形狀以及資料型別,批處理返回
- 開啟會話執行緒執行
2.程式碼
- 定義CIFAR類,設定圖片相關的屬性
class CifarRead(object): """ 二進位制檔案的讀取,tfrecords儲存讀取 """ def __init__(self): # 定義一些圖片的屬性 self.height = 32 self.width = 32 self.channel = 3 self.label_bytes = 1 self.image_bytes = self.height * self.width * self.channel self.bytes = self.label_bytes + self.image_bytes
-
實現讀取資料方法bytes_read(self, file_list)
- 構造檔案佇列
# 1、構造檔案佇列 file_queue = tf.train.string_input_producer(file_list)
- tf.FixedLengthRecordReader(bytes)讀取
# 2、使用tf.FixedLengthRecordReader(bytes)讀取 # 預設必須指定讀取一個樣本 reader = tf.FixedLengthRecordReader(self.all_bytes) _, value = reader.read(file_queue)
- 進行解碼操作
# 3、解碼操作 # (?, ) (3073, ) = label(1, ) + feature(3072, ) label_image = tf.decode_raw(value, tf.uint8) # 為了訓練方便,一般會把特徵值和目標值分開處理 print(label_image)
- 將資料的標籤和圖片進行分割
# 使用tf.slice進行切片 label = tf.cast(tf.slice(label_image, [0], [self.label_bytes]), tf.int32) image = tf.slice(label_image, [self.label_bytes], [self.image_bytes]) print(label, image)
- 處理資料的形狀,並且進行批處理
# 處理型別和圖片資料的形狀 # 圖片形狀 # reshape (3072, )----[channel, height, width] # transpose [channel, height, width] --->[height, width, channel] depth_major = tf.reshape(image, [self.channel, self.height, self.width]) print(depth_major) image_reshape = tf.transpose(depth_major, [1, 2, 0]) print(image_reshape) # 4、批處理 image_batch, label_batch = tf.train.batch([image_reshape, label], batch_size=10, num_threads=1, capacity=10)
這裡的圖片形狀設定從1維的排列到3維資料的時候,涉及到NHWC與NCHW的概念:
1)NHWC與NCHW
在讀取設定圖片形狀的時候有兩種格式:
設定為 "NHWC" 時,排列順序為 [batch, height, width, channels];
設定為 "NCHW" 時,排列順序為 [batch, channels, height, width]。
其中 N 表示這批影象有幾張,H 表示影象在豎直方向有多少畫素,W 表示水平方向畫素數,C 表示通道數。
Tensorflow預設的[height, width, channel]
假設RGB三通道兩種格式的區別如下圖所示:
1 理解
假設1, 2, 3, 4-紅色 5, 6, 7, 8-綠色 9, 10, 11, 12-藍色
- 如果通道在最低維度0[channel, height, width],RGB三顏色分成三組,在第一維度上找到三個RGB顏色
- 如果通道在最高維度2[height, width, channel],在第三維度上找到RGB三個顏色
# 1、想要變成:[2 height, 2width, 3channel],但是輸出結果不對
In [7]: tf.reshape([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12], [2, 2, 3]).eval()
Out[7]:
array([[[ 1, 2, 3],
[ 4, 5, 6]],
[[ 7, 8, 9],
[10, 11, 12]]], dtype=int32)
# 2、所以要這樣去做
In [8]: tf.reshape([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12], [3, 2, 2]).eval()
Out[8]:
array([[[ 1, 2],
[ 3, 4]],
[[ 5, 6],
[ 7, 8]],
[[ 9, 10],
[11, 12]]], dtype=int32)
# 接著使用tf.transpose ,0,1,2代表三個維度標記
# Convert from [depth, height, width] to [height, width, depth].
# 0,1,2-----> 1, 2, 0
In [17]: tf.transpose(depth_major, [1, 2, 0]).eval()
Out[17]:
array([[[ 1, 5, 9],
[ 2, 6, 10]],
[[ 3, 7, 11],
[ 4, 8, 12]]], dtype=int32)
2 轉換API
- tf.transpose(a, perm=None)
- Transposes
a
. Permutes the dimensions according toperm
.- 修改維度的位置
- a:資料
- perm:形狀的維度值下標列表
- Transposes
2)處理圖片的形狀
所以在讀取資料處理形狀的時候
- 1 image (3072, ) —>tf.reshape(image, [])裡面的shape是[channel, height, width], 所以得先從[depth height width] to [depth, height, width]。
- 2 然後使用tf.transpose, 將剛才的資料[depth, height, width],變成Tensorflow預設的[height, width, channel]
3 完整程式碼
import tensorflow as tf
import os
class Cifar(object):
# 初始化
def __init__(self):
# 影象的大小
self.height = 32
self.width = 32
self.channels = 3
# 影象的位元組數
self.label_bytes = 1
self.image_bytes = self.height * self.width * self.channels
self.bytes = self.label_bytes + self.image_bytes
def read_and_decode(self, file_list):
# 讀取二進位制檔案
# print("read_and_decode:\n", file_list)
# 1、構造檔名佇列
file_queue = tf.train.string_input_producer(file_list)
# 2、構造二進位制檔案閱讀器
reader = tf.FixedLengthRecordReader(self.bytes)
key, value = reader.read(file_queue)
print("key:\n", key)
print("value:\n", value)
# 3、解碼
decoded = tf.decode_raw(value, tf.uint8)
print("decoded:\n", decoded)
# 4、基本的資料處理
# 切片處理,把標籤值和特徵值分開
label = tf.slice(decoded, [0], [self.label_bytes])
image = tf.slice(decoded, [self.label_bytes], [self.image_bytes])
print("label:\n", label)
print("image:\n", image)
# 改變影象的形狀
image_reshaped = tf.reshape(image, [self.channels, self.height, self.width])
# 轉置
image_transposed = tf.transpose(image_reshaped, [1, 2, 0])
print("image_transposed:\n", image_transposed)
# 型別轉換
label_cast = tf.cast(label, tf.float32)
image_cast = tf.cast(image_transposed, tf.float32)
# 5、批處理
label_batch, image_batch = tf.train.batch([label_cast, image_cast], batch_size=10, num_threads=1, capacity=10)
return label_batch, image_batch
if __name__ == "__main__":
# 構造檔名列表
file_name = os.listdir("./cifar-10-batches-bin")
print("file_name:\n", file_name)
file_list = [os.path.join("./cifar-10-batches-bin/", file) for file in file_name if file[-3:] == "bin"]
print("file_list:\n", file_list)
# 呼叫讀取二進位制檔案的方法
cf = Cifar()
label, image = cf.read_and_decode(file_list)
# 開啟會話
with tf.Session() as sess:
# 建立執行緒協調器
coord = tf.train.Coordinator()
# 建立執行緒
threads = tf.train.start_queue_runners(sess=sess, coord=coord)
# 列印結果
print("label:\n", sess.run(label))
print("image:\n", sess.run(image))
# 回收資源
coord.request_stop()
coord.join(threads)