編寫基於TensorFlow的應用之構建資料pipeline
本文主要以MNIST資料集為例介紹TFRecords檔案如何製作以及載入使用。所講內容可以在SIGAI 線上程式設計功能中的sharedata/intro_to_tf資料夾中可以免費獲取。此項功能對所有註冊使用者免費開放.
官網地址:www.sigai.cn,
推薦使用chrome瀏覽器
線上程式設計功能使用指南見SIGAI官網->知識庫->線上程式設計使用說明小視訊
圖1 典型的基於TensorFlow 的應用的workflow
通常情況下,一個基於TensorFlow 的應用訓練過程中所
採用的workflow 如圖1 所示。針對與原始資料的格式,首先採用不同的轉換方式在執行過程中生成Tensor格式的資料,然後將其送到TensorFlowGraph中執行,根據設定的目標函式,不斷的在訓練資料上迭代並週期性地儲存checkpoint到檔案中,checkpoint檔案可以用於後續的模型持久化操作。TensorFlow框架下訓練輸入pipeline是一個標準的ETL過程:
1. 提取資料(Extract): 從儲存空間內部讀取原始資料
2. 資料轉換(Transform): 使用CPU解析原始資料並執行一些預處理的操作: 文字資料轉換為陣列,圖片大小變換,圖片資料增強操作等等
3. 資料載入(Load): 載入轉換後的資料並傳給GPU,FPGA,ASIC等加速晶片進行計算
在TensorFlow框架之下,使用 tf.dataset API 可以完成上述過程中所需的所有操作,其過程如下圖所示:
圖2 TensorFlow中的ETL過程
相較於TFRecords檔案,文字檔案,numpy陣列,csv檔案等檔案格式更為常見。接下來,本文將以常用的MNIST資料集為例簡要介紹TFRecord檔案如何生成以及如何從TFrecord構建資料pipeline。
Record檔案簡介
TFRecord檔案是基於Google Protocol Buffers的一種儲存資料的格式,我們推薦在資料預處理過程中儘可能使用這種方式將訓練資料儲存成這種格式。Protocol Buffers 是一種簡潔高效的序列化格式化的方法,其採用了語言無關,平臺無關且可擴充套件的機制。 採用這種方式的優勢在於:
1. 採用二進位制格式儲存,減少儲存空間,提高讀取效率
2. 針對TensorFlow框架進行優化,支援合併多個數據源,並且支援TensorFlow內建的其他資料預處理方式
3. 支援序列化資料的儲存(時序資料或者詞向量)
圖3 TFRecord檔案中儲存內容結構
TFRecords中儲存的層級如圖3所示,從圖中可以看到:
‣ 一個TFRecord檔案中包含了多個tf.train.Example, 每個tf.train.Example是一個Protocol Buffer
‣ 每個tf.train.Example包含了tf.train.Features
‣ 每個tf.train.Features是由多個feature 構成的feature set
以MNIST為例生成TFRecord檔案
圖4 TFRecord檔案製作和載入過程
從原始檔案生成TFRecord的過程如圖4所示:
1. 從檔案中讀取資料資訊,如果是類別,長度,高度等數值型資料就轉換成Int64List, FloatList格式的特徵,如果是圖片等raw data,則直接讀取其二進位制編碼內容,再轉換成BytesList即可
2. 將多個特徵合併為 tf.train.Features,並傳遞到tf.train.Example中
3. 最後使用TFRecordWriter寫入到檔案中
對於MNIST檔案,從http://yann.lecun.com/exdb/mnist/網站下載下來的是以二進位制方式儲存的資料集,本文略過下載並讀取MNIST為numpy 陣列的過程,有興趣的讀者可以檢視mnist_data.py中的read_mnist函式。接下來我們重要講解從一個numpy 陣列到tfrecord檔案需要執行的主要步驟:
1. 對於整個陣列,需要遍歷整個陣列並依次將其轉換成一個tf.train.Example(protocol buffer)
def feature_to_example(img, label): """convert numpy array to a `tf.train.example` Args: img : An `np.ndarray`. Img in numpy array format label : An `np.int32`. label of the image """ # convert raw data corresponding to the numpy array in memory into pytho bytes img = img.tostring() return tf.train.Example( features=tf.train.Features( feature={ 'img': bytes_feature(img), 'label': int_feature(label) } ) ) 這其中使用到的bytes_feature和int_feature分別是用來將圖片和標籤轉換成二進位制的feature和int列表的特徵的函式
def int_feature(value): return tf.train.Feature(int64_list=tf.train.Int64List(value=[value])) def bytes_feature(value): return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))
3. 在使用SerializeToString函式將protocol buffer中的內容序列化之後, 將其內容寫入到檔案中
至此,MNIST的tfrecord檔案就製作完成了。以上步驟各位讀者可以在sharedat/intro_to_tf路徑下的tfrecords.ipynb 檔案中進行實驗。下載MNIST資料集過程需要消耗一定時間,請各位耐心等待。[1] 由於MNIST中涉及到的特徵僅有陣列和標籤兩類內容,對於讀者在使用TensorFlow過程中可能會遇到的其他資料格式,建議參考 https://github.com/tensorflow/models/blob/master/research/object_detection/dataset_tools/create_pascal_tf_record.py 檔案編寫適合自己資料集內容的函式
載入TFRecord檔案並構建資料pipeline
從圖4中,可以看到載入一個TFRrecord檔案需要執行的步驟,其過程中使用了TensorFlow dataset類提供的函式:
1. shuffle:打亂輸入資料的順序
2. repeat: 重複資料集內容若干次
3. map: 對資料集中的每個資料使用map函式中傳入的方法進行變換,這個過程中可以包含解析tf.train.Example內容,資料歸一化以及以及data augmentation等其他操作
4. batch: 根據需要設定每次訓練採用多少資料
5. prefetch:提前載入n個數據,保證每個session執行之前資料是可以立即使用的
在mnist_tfrecords.py檔案中有兩個不同的載入資料的方式,我們建議使用第二種優化過的載入方式,其優點在於:
1. shuffle_and_repeat可以保證載入資料的速度以及確保資料之間的順序正確
2. map_and_batch 整合了map和batch 過程,提高了載入效率
經過優化過的載入TFRecord檔案函式如下:
def load_data_optimized(cache_dir='data/cache', split='train', batch_size=64, epochs_between_evals=3): tfrecord_file = os.path.join(cache_dir, 'mnist_{}.tfrecord'.format(split))
# load the tfrecord data dataset = tf.data.TFRecordDataset(tfrecord_file)
# shuffle and repeat data if split == 'train': dataset = dataset.apply(shuffle_and_repeat(60000, epochs_between_evals)) else: dataset = dataset.apply(shuffle_and_repeat(10000, epochs_between_evals))
# fuse map and batch dataset = dataset.apply(map_and_batch(parse_example, batch_size=batch_size, drop_remainder=True, num_parallel_calls=8))
dataset = dataset.prefetch(1)
return dataset
在SIGAI提供的實驗過程中,驗證讀取資料的內容如下圖所示:
圖5 驗證載入MNIST資料集
本文主要介紹了TFRecord檔案,然後以MNIST資料集為例講解了如何製作MNIST資料集的TFRecord檔案,接著講述瞭如何載入檔案並構建資料 pipeline。大家在做實驗過程中使用了Eager模式,我們將在下一篇文章中介紹Eager 模式的使用。