TensorFlow中的TFRecord檔案(轉)
轉自:https://yinguobing.com/tfrecord-in-tensorflow/
背景:最近在學習TensorFlow,需要將自定義影象資料作為訓練資料。
標準TensorFlow格式
TensorFlow的訓練過程其實就是大量的資料在網路中不斷流動的過程,而資料的來源在官方文件[^1](API r1.2)中介紹了三種方式,分別是:
- Feeding。通過Python直接注入資料。
- Reading from files。從檔案讀取資料,本文中的TFRecord屬於此類方式。
- Preloaded data。將資料以constant或者variable的方式直接儲存在運算圖中。
當資料量較大時,官方推薦採用標準TensorFlow格式[^2](Standard TensorFlow format)來儲存訓練與驗證資料,該格式的字尾名為tfrecord
。官方介紹如下:
A TFRecords file represents a sequence of (binary) strings. The format is not random access, so it is suitable for streaming large amounts of data but not suitable if fast sharding or other non-sequential access is desired.
從介紹不難看出,TFRecord檔案適用於大量資料的順序讀取。而這正好是神經網路在訓練過程中發生的事情。
如何使用TFRecord檔案
對於TFRecord檔案的使用,官方給出了兩份示例程式碼,分別展示瞭如何生成與讀取該格式的檔案。
生成TFRecord檔案
第一份程式碼convert_to_records.py
[^3]將MNIST裡的影象資料轉換為了TFRecord格式 。仔細研讀程式碼,可以發現TFRecord檔案中的影象資料儲存在Feature
下的image_raw
裡。image_raw
來自於data_set.images
,而後者又來自mnist.read_data_sets()
images
的真身藏在mnist.py
這個檔案裡。
mnist.py
並不難找,在Pycharm裡按下ctrl
後單擊滑鼠左鍵即可開啟原始碼。
繼續追蹤,可以在mnist裡發現影象來自extract_images()
函式。該函式的說明裡清晰的寫明:
Extract the images into a 4D uint8 numpy array [index, y, x, depth].
Args:
f: A file object that can be passed into a gzip reader.
Returns:
data: A 4D uint8 numpy array [index, y, x, depth].
Raises:
ValueError: If the bytestream does not start with 2051.
很明顯,返回值變數名為data
,是一個4D Numpy矩陣,儲存值為uint8
型別,即影象畫素的灰度值(MNIST全部為灰度影象)。四個維度分別代表了:影象的個數,每個影象行數,每個影象列數,每個影象通道數。
在獲得這個儲存著畫素灰度值的Numpy矩陣後,使用numpy的tostring()
函式將其轉換為Python bytes格式[^4],再使用tf.train.BytesList()
函式封裝為tf.train.BytesList
類,名字為image_raw
。最後使用tf.train.Example()
將image_raw
和其它屬性一遍打包,並呼叫tf.python_io.TFRecordWriter
將其寫入到檔案中。
至此,TFRecord檔案生成完畢。
可見,將自定義影象轉換為TFRecord的過程本質上是將大量影象的畫素灰度值轉換為Python bytes,並與其它Feature
組合在一起,最終拼接成一個檔案的過程。
需要注意的是其它Feature
的型別不一定必須是BytesList,還可以是Int64List或者FloatList。
讀取TFRecord檔案
第二份程式碼fully_connected_reader.py
[1]展示瞭如何從TFRecord檔案中讀取資料。
讀取資料的函式名為input()
。函式內部首先通過tf.train.string_input_producer()
函式讀取TFRecord檔案,並返回一個queue
;然後使用read_and_decode()
讀取一份資料,函式內部用tf.decode_raw()
解析出影象的灰度值,用tf.cast()
解析出label
的值。之後通過tf.train.shuffle_batch()
的方法生成一批用來訓練的資料。並最終返回可供訓練的images
和labels
,並送入inference
部分進行計算。
在這個過程中,有以下幾點需要留意:
tf.decode_raw()
解析出的資料是沒有shape
的,因此需要呼叫set_shape()
函式來給出tensor的維度。read_and_decode()
函式返回的是單個的資料,但是後邊的tf.train.shuffle_batch()
卻能夠生成批量資料。- 如果需要對影象進行處理的話,需要放在第二項提到的兩個函式中間。
其中第2點的原理我暫時沒有弄懂。從程式碼上看read_and_decode()
返回的是單個數據,shuffle_batch
接收到的也是單個數據,不知道是如何生成批量資料的,猜測與queue
有關係。
所以,讀取TFRecord檔案的本質,就是通過佇列的方式依次將資料解碼,並按需要進行資料隨機化、影象隨機化的過程。