1. 程式人生 > >TensorFlow中的TFRecord檔案(轉)

TensorFlow中的TFRecord檔案(轉)

轉自:https://yinguobing.com/tfrecord-in-tensorflow/

背景:最近在學習TensorFlow,需要將自定義影象資料作為訓練資料。


標準TensorFlow格式

TensorFlow的訓練過程其實就是大量的資料在網路中不斷流動的過程,而資料的來源在官方文件[^1](API r1.2)中介紹了三種方式,分別是:

  • Feeding。通過Python直接注入資料。
  • Reading from files。從檔案讀取資料,本文中的TFRecord屬於此類方式。
  • Preloaded data。將資料以constant或者variable的方式直接儲存在運算圖中。

當資料量較大時,官方推薦採用標準TensorFlow格式[^2](Standard TensorFlow format)來儲存訓練與驗證資料,該格式的字尾名為tfrecord。官方介紹如下:

A TFRecords file represents a sequence of (binary) strings. The format is not random access, so it is suitable for streaming large amounts of data but not suitable if fast sharding or other non-sequential access is desired.

從介紹不難看出,TFRecord檔案適用於大量資料的順序讀取。而這正好是神經網路在訓練過程中發生的事情。


如何使用TFRecord檔案

對於TFRecord檔案的使用,官方給出了兩份示例程式碼,分別展示瞭如何生成與讀取該格式的檔案。

生成TFRecord檔案

第一份程式碼convert_to_records.py [^3]將MNIST裡的影象資料轉換為了TFRecord格式 。仔細研讀程式碼,可以發現TFRecord檔案中的影象資料儲存在Feature下的image_raw裡。image_raw來自於data_set.images,而後者又來自mnist.read_data_sets()

。因此images的真身藏在mnist.py這個檔案裡。

mnist.py並不難找,在Pycharm裡按下ctrl後單擊滑鼠左鍵即可開啟原始碼。

繼續追蹤,可以在mnist裡發現影象來自extract_images()函式。該函式的說明裡清晰的寫明:

Extract the images into a 4D uint8 numpy array [index, y, x, depth].
  Args:
    f: A file object that can be passed into a gzip reader.
  Returns:
    data: A 4D uint8 numpy array [index, y, x, depth].
  Raises:
    ValueError: If the bytestream does not start with 2051.

很明顯,返回值變數名為data,是一個4D Numpy矩陣,儲存值為uint8型別,即影象畫素的灰度值(MNIST全部為灰度影象)。四個維度分別代表了:影象的個數,每個影象行數,每個影象列數,每個影象通道數。

在獲得這個儲存著畫素灰度值的Numpy矩陣後,使用numpy的tostring()函式將其轉換為Python bytes格式[^4],再使用tf.train.BytesList()函式封裝為tf.train.BytesList類,名字為image_raw。最後使用tf.train.Example()image_raw和其它屬性一遍打包,並呼叫tf.python_io.TFRecordWriter將其寫入到檔案中。

至此,TFRecord檔案生成完畢。

可見,將自定義影象轉換為TFRecord的過程本質上是將大量影象的畫素灰度值轉換為Python bytes,並與其它Feature組合在一起,最終拼接成一個檔案的過程

需要注意的是其它Feature的型別不一定必須是BytesList,還可以是Int64List或者FloatList。

讀取TFRecord檔案

第二份程式碼fully_connected_reader.py [1]展示瞭如何從TFRecord檔案中讀取資料。

讀取資料的函式名為input()。函式內部首先通過tf.train.string_input_producer()函式讀取TFRecord檔案,並返回一個queue;然後使用read_and_decode()讀取一份資料,函式內部用tf.decode_raw()解析出影象的灰度值,用tf.cast()解析出label的值。之後通過tf.train.shuffle_batch()的方法生成一批用來訓練的資料。並最終返回可供訓練的imageslabels,並送入inference部分進行計算。

在這個過程中,有以下幾點需要留意:

  1. tf.decode_raw()解析出的資料是沒有shape的,因此需要呼叫set_shape()函式來給出tensor的維度。
  2. read_and_decode()函式返回的是單個的資料,但是後邊的tf.train.shuffle_batch()卻能夠生成批量資料。
  3. 如果需要對影象進行處理的話,需要放在第二項提到的兩個函式中間。

其中第2點的原理我暫時沒有弄懂。從程式碼上看read_and_decode()返回的是單個數據,shuffle_batch接收到的也是單個數據,不知道是如何生成批量資料的,猜測與queue有關係。

所以,讀取TFRecord檔案的本質,就是通過佇列的方式依次將資料解碼,並按需要進行資料隨機化、影象隨機化的過程