TensorFlow讀寫數據
前言
只有光頭才能變強。
文本已收錄至我的GitHub倉庫,歡迎Star:https://github.com/ZhongFuCheng3y/3y
回顧前面:
- 從零開始學TensorFlow【01-搭建環境、HelloWorld篇】
- 什麽是TensorFlow?
眾所周知,要訓練出一個模型,首先我們得有數據。我們第一個例子中,直接使用dataset的api去加載mnist的數據。(minst的數據要麽我們是提前下載好,放在對應的目錄上,要麽就根據他給的url直接從網上下載)。
一般來說,我們使用TensorFlow是從TFRecord文件中讀取數據的。
TFRecord 文件格式是一種面向記錄的簡單二進制格式
,很多 TensorFlow 應用采用此格式來訓練數據
所以,這篇文章來聊聊怎麽讀取TFRecord文件的數據。
一、入門對數據集的數據進行讀和寫
首先,我們來體驗一下怎麽造一個TFRecord文件,怎麽從TFRecord文件中讀取數據,遍歷(消費)這些數據。
1.1 造一個TFRecord文件
現在,我們還沒有TFRecord文件,我們可以自己簡單寫一個:
def write_sample_to_tfrecord(): gmv_values = np.arange(10) click_values = np.arange(10) label_values = np.arange(10) with tf.python_io.TFRecordWriter("/Users/zhongfucheng/data/fashin/demo.tfrecord", options=None) as writer: for _ in range(10): feature_internal = { "gmv": tf.train.Feature(float_list=tf.train.FloatList(value=[gmv_values[_]])), "click": tf.train.Feature(int64_list=tf.train.Int64List(value=[click_values[_]])), "label": tf.train.Feature(int64_list=tf.train.Int64List(value=[label_values[_]])) } features_extern = tf.train.Features(feature=feature_internal) # 使用tf.train.Example將features編碼數據封裝成特定的PB協議格式 # example = tf.train.Example(features=tf.train.Features(feature=features_extern)) example = tf.train.Example(features=features_extern) # 將example數據系列化為字符串 example_str = example.SerializeToString() # 將系列化為字符串的example數據寫入協議緩沖區 writer.write(example_str) if __name__ == '__main__': write_sample_to_tfrecord()
我相信大家代碼應該是能夠看得懂的,其實就是分了幾步:
- 生成TFRecord Writer
- tf.train.Feature生成協議信息
- 使用tf.train.Example將features編碼數據封裝成特定的PB協議格式
- 將example數據系列化為字符串
- 將系列化為字符串的example數據寫入協議緩沖區
參考資料:
- https://zhuanlan.zhihu.com/p/31992460
ok,現在我們就有了一個TFRecord文件啦。
1.2 讀取TFRecord文件
其實就是通過
tf.data.TFRecordDataset
這個api來讀取到TFRecord文件,生成處dataset對象- 對dataset進行處理(shape處理,格式處理...等等)
使用叠代器對dataset進行消費(遍歷)
demo代碼如下:
import tensorflow as tf
def read_tensorflow_tfrecord_files():
# 定義消費緩沖區協議的parser,作為dataset.map()方法中傳入的lambda:
def _parse_function(single_sample):
features = {
"gmv": tf.FixedLenFeature([1], tf.float32),
"click": tf.FixedLenFeature([1], tf.int64), # ()或者[]沒啥影響
"label": tf.FixedLenFeature([1], tf.int64)
}
parsed_features = tf.parse_single_example(single_sample, features=features)
# 對parsed 之後的值進行cast.
gmv = tf.cast(parsed_features["gmv"], tf.float64)
click = tf.cast(parsed_features["click"], tf.float64)
label = tf.cast(parsed_features["label"], tf.float64)
return gmv, click, label
# 開始定義dataset以及解析tfrecord格式
filenames = tf.placeholder(tf.string, shape=[None])
# 定義dataset 和 一些列trasformation method
dataset = tf.data.TFRecordDataset(filenames)
parsed_dataset = dataset.map(_parse_function) # 消費緩沖區需要定義在dataset 的map 函數中
batchd_dataset = parsed_dataset.batch(3)
# 創建Iterator
sample_iter = batchd_dataset.make_initializable_iterator()
# 獲取next_sample
gmv, click, label = sample_iter.get_next()
training_filenames = [
"/Users/zhongfucheng/data/fashin/demo.tfrecord"]
with tf.Session() as session:
# 初始化帶參數的Iterator
session.run(sample_iter.initializer, feed_dict={filenames: training_filenames})
# 讀取文件
print(session.run(gmv))
if __name__ == '__main__':
read_tensorflow_tfrecord_files()
無意外的話,我們可以輸出這樣的結果:
[[0.]
[1.]
[2.]]
ok,現在我們已經大概知道怎麽寫一個TFRecord文件,以及怎麽讀取TFRecord文件的數據,並且消費這些數據了。
二、epoch和batchSize術語解釋
我在學習TensorFlow翻閱資料時,經常看到一些機器學習的術語,由於自己沒啥機器學習的基礎,所以很多時候看到一些專業名詞就開始懵逼了。
2.1epoch
當一個完整的數據集通過了神經網絡一次並且返回了一次,這個過程稱為一個epoch。
這可能使我們跟dataset.repeat()
方法聯系起來,這個方法可以使當前數據集重復一遍。比如說,原有的數據集是[1,2,3,4,5]
,如果我調用dataset.repeat(2)
的話,那麽我們的數據集就變成了[1,2,3,4,5],[1,2,3,4,5]
- 所以會有個說法:假設原先的數據是一個epoch,使用repeat(5)就可以將之變成5個epoch
2.2batchSize
一般來說我們的數據集都是比較大的,無法一次性將整個數據集的數據餵進神經網絡中,所以我們會將數據集分成好幾個部分。每次餵多少條樣本進神經網絡,這個叫做batchSize。
在TensorFlow也提供了方法給我們設置:dataset.batch()
,在API中是這樣介紹batchSize的:
representing the number of consecutive elements of this dataset to combine in a single batch
我們一般在每次訓練之前,會將整個數據集的順序打亂,提高我們模型訓練的效果。這裏我們用到的api是:dataset.shffle();
三、再來聊聊dataset
我從官網的介紹中截了一個dataset的方法圖(部分):
dataset的功能主要有以下三種:
- 創建dataset實例
- 通過文件創建(比如TFRecord)
- 通過內存創建
- 對數據集的數據進行變換
- 比如上面的batch(),常見的
map(),flat_map(),zip(),repeat()
等等 - 文檔中一般都有給出例子,跑一下一般就知道對應的意思了。
- 比如上面的batch(),常見的
- 創建叠代器,遍歷數據集的數據
3.1 聊聊叠代器
叠代器可以分為四種:
- 單次。對數據集進行一次叠代,不支持參數化
- 可初始化叠代
- 使用前需要進行初始化,支持傳入參數。面向的是同一個DataSet
- 可重新初始化:同一個Iterator從不同的DataSet中讀取數據
- DataSet的對象具有相同的結構,可以使用
tf.data.Iterator.from_structure
來進行初始化 - 問題:每次 Iterator 切換時,數據都從頭開始打印了
- DataSet的對象具有相同的結構,可以使用
- 可饋送(也是通過對象相同的結果來創建的叠代器)
- 可讓您在兩個數據集之間切換的可饋送叠代器
- 通過一個string handler來實現。
- 可饋送的 Iterator 在不同的 Iterator 切換的時候,可以做到不從頭開始。
簡單總結:
- 1、 單次 Iterator ,它最簡單,但無法重用,無法處理數據集參數化的要求。
- 2、 可以初始化的 Iterator ,它可以滿足 Dataset 重復加載數據,滿足了參數化要求。
- 3、可重新初始化的 Iterator,它可以對接不同的 Dataset,也就是可以從不同的 Dataset 中讀取數據。
- 4、可饋送的 Iterator,它可以通過 feeding 的方式,讓程序在運行時候選擇正確的 Iterator,它和可重新初始化的 Iterator 不同的地方就是它的數據在不同的 Iterator 切換時,可以做到不重頭開始讀取數據。
string handler(可饋送的 Iterator)這種方式是最常使用的,我當時也寫了一個Demo來使用了一下,代碼如下:
def read_tensorflow_tfrecord_files():
# 開始定義dataset以及解析tfrecord格式.
train_filenames = tf.placeholder(tf.string, shape=[None])
vali_filenames = tf.placeholder(tf.string, shape=[None])
# 加載train_dataset batch_inputs這個方法每個人都不一樣的,這個方法我就不給了。
train_dataset = batch_inputs([
train_filenames], batch_size=5, type=False,
num_epochs=2, num_preprocess_threads=3)
# 加載validation_dataset batch_inputs這個方法每個人都不一樣的,這個方法我就不給了。
validation_dataset = batch_inputs([vali_filenames
], batch_size=5, type=False,
num_epochs=2, num_preprocess_threads=3)
# 創建出string_handler()的叠代器(通過相同數據結構的dataset來構建)
handle = tf.placeholder(tf.string, shape=[])
iterator = tf.data.Iterator.from_string_handle(
handle, train_dataset.output_types, train_dataset.output_shapes)
# 有了叠代器就可以調用next方法了。
itemid = iterator.get_next()
# 指定哪種具體的叠代器,有單次叠代的,有初始化的。
training_iterator = train_dataset.make_initializable_iterator()
validation_iterator = validation_dataset.make_initializable_iterator()
# 定義出placeholder的值
training_filenames = [
"/Users/zhongfucheng/tfrecord_test/data01aa"]
validation_filenames = ["/Users/zhongfucheng/tfrecord_validation/part-r-00766"]
with tf.Session() as sess:
# 初始化叠代器
training_handle = sess.run(training_iterator.string_handle())
validation_handle = sess.run(validation_iterator.string_handle())
for _ in range(2):
sess.run(training_iterator.initializer, feed_dict={train_filenames: training_filenames})
print("this is training iterator ----")
for _ in range(5):
print(sess.run(itemid, feed_dict={handle: training_handle}))
sess.run(validation_iterator.initializer,
feed_dict={vali_filenames: validation_filenames})
print("this is validation iterator ")
for _ in range(5):
print(sess.run(itemid, feed_dict={vali_filenames: validation_filenames, handle: validation_handle}))
if __name__ == '__main__':
read_tensorflow_tfrecord_files()
參考資料:
- https://blog.csdn.net/briblue/article/details/80962728
3.2 dataset參考資料
在翻閱資料時,發現寫得不錯的一些博客:
- https://www.jianshu.com/p/91803a119f18
- https://irvingzhang0512.github.io/2018/04/19/tensorflow-api-2/
- http://www.feiguyunai.com/index.php/2017/12/25/pyhtonai-ml-dataprocess-datasetapi/
最後
樂於輸出幹貨的Java技術公眾號:Java3y。公眾號內有200多篇原創技術文章、海量視頻資源、精美腦圖,不妨來關註一下!
下一篇文章打算講講如何理解axis~
覺得我的文章寫得不錯,不妨點一下贊!
TensorFlow讀寫數據