tf.data
Dataset作為新的API,比以上兩種方法的速度都快,並且使用難度要遠遠低於使用Queues。tf.data中包含了兩個用於TensorFLow程序的接口:Dataset和Iterator。
Dataset(數據集) API 在 TensorFlow 1.4版本中已經從tf.contrib.data遷移到了tf.data之中,增加了對於Python的生成器的支持,官方強烈建議使用Dataset API 為 TensorFlow模型創建輸入管道,原因如下:
Dataset
Dataset表示一個元素的集合,可以看作函數式編程中的 lazy list, 元素是tensor tuple。創建Dataset的方式可以分為兩種,分別是:
Source
Apply transformation
Source
這裏 source 指的是從tf.Tensor對象創建Dataset,常見的方法又如下幾種:
tf.data.Dataset.from_tensors((features, labels)) tf.data.Dataset.from_tensor_slices((features, labels)) tf.data.TextLineDataset(filenames) tf.data.TFRecordDataset(filenames)
作用分別為:
1.從一個tensor tuple創建一個單元素的dataset;
2.從一個tensor tuple創建一個包含多個元素的dataset;
3.讀取一個文件名列表,將每個文件中的每一行作為一個元素,構成一個dataset;
4.讀取硬盤中的TFRecord格式文件,構造dataset。
Apply transformation
第二種方法就是通過轉化已有的dataset來得到新的dataset,TensorFLow tf.data.Dataset支持很多中變換,在這裏介紹常見的幾種:
dataset.map(lambda x: tf.decode_jpeg(x)) dataset.repeat(NUM_EPOCHS) dataset.batch(BATCH_SIZE)
以上三種方式分別表示了:使用map對dataset中的每個元素進行處理,這裏的例子是對圖片數據進行解碼;將dataset重復一定數目的次數用於多個epoch的訓練;將原來的dataset中的元素按照某個數量疊在一起,生成mini batch。
將以上代碼組合起來,我們可以得到一個常用的代碼片段:
# 從一個文件名列表讀取 TFRecord 構成 dataset dataset = TFRecordDataset(["file1.tfrecord", "file2.tfrecord"]) # 處理 string,將 string 轉化為 tf.Tensor 對象 dataset = dataset.map(lambda record: tf.parse_single_example(record)) # buffer 大小設置為 10000,打亂 dataset dataset = dataset.shuffle(10000) # dataset 將被用來訓練 100 個 epoch dataset = dataset.repeat(100) # 設置 batch size 為 128 dataset = dataset.batch(128)
Iterator
定義好了數據集以後可以通過Iterator接口來訪問數據集中的tensor tuple,iterator保持了數據在數據集中的位置,提供了訪問數據集中數據的方法。
可以通過調用 dataset 的 make iterator 方法來構建 iterator。
替換了place_holder,要用到數據的時候,在最開始定義iterator.get_next(),就取到了一個batch的元素了
API 支持以下四種 iterator,復雜程度遞增:
- one-shot
- initializable
- reinitializable
- feedable
one-shot
one-shot iterator 誰最簡單的一種 iterator,僅支持對整個數據集訪問一遍,不需要顯式的初始化。one-shot iterator 不支參數化。以下代碼使用tf.data.Dataset.range生成數據集,作用與 python 中的 range 類似。
dataset = tf.data.Dataset.range(100) iterator = dataset.make_one_shot_iterator() next_element = iterator.get_next() for i in range(100): value = sess.run(next_element) assert i == value
initializable
Initializable iterator 要求在使用之前顯式的通過調用iterator.initializer操作初始化,這使得在定義數據集時可以結合tf.placeholder傳入參數,如:
max_value = tf.placeholder(tf.int64, shape=[]) dataset = tf.data.Dataset.range(max_value) iterator = dataset.make_initializable_iterator() next_element = iterator.get_next() # Initialize an iterator over a dataset with 10 elements. sess.run(iterator.initializer, feed_dict={max_value: 10}) for i in range(10): value = sess.run(next_element) assert i == value # Initialize the same iterator over a dataset with 100 elements. sess.run(iterator.initializer, feed_dict={max_value: 100}) for i in range(100): value = sess.run(next_element) assert i == value
reinitializable
reinitializable iterator 可以被不同的 dataset 對象初始化,比如對於訓練集進行了shuffle的操作,對於驗證集則沒有處理,通常這種情況會使用兩個具有相同結構的dataset對象,如:
# Define training and validation datasets with the same structure. training_dataset = tf.data.Dataset.range(100).map( lambda x: x + tf.random_uniform([], -10, 10, tf.int64)) validation_dataset = tf.data.Dataset.range(50) # A reinitializable iterator is defined by its structure. We could use the # `output_types` and `output_shapes` properties of either `training_dataset` # or `validation_dataset` here, because they are compatible. iterator = tf.data.Iterator.from_structure(training_dataset.output_types, training_dataset.output_shapes) next_element = iterator.get_next() training_init_op = iterator.make_initializer(training_dataset) validation_init_op = iterator.make_initializer(validation_dataset) # 如果後面初始化的是這個,那麽就將循環這個數據集 # Run 20 epochs in which the training dataset is traversed, followed by the # validation dataset. for _ in range(20): # Initialize an iterator over the training dataset. sess.run(training_init_op) for _ in range(100): sess.run(next_element) # Initialize an iterator over the validation dataset. sess.run(validation_init_op) # 替換init_op,相當於替換數據集 for _ in range(50): sess.run(next_element)
feedable
feedable iterator 可以通過和tf.placeholder結合在一起,同通過feed_dict機制來選擇在每次調用tf.Session.run的時候選擇哪種Iterator。它提供了與 reinitilizable iterator 類似的功能,並且在切換數據集的時候不需要在開始的時候初始化iterator,還是上面的例子,通過tf.data.Iterator.from_string_handle來定義一個 feedable iterator,達到切換數據集的目的:
# Define training and validation datasets with the same structure. training_dataset = tf.data.Dataset.range(100).map( lambda x: x + tf.random_uniform([], -10, 10, tf.int64)).repeat() validation_dataset = tf.data.Dataset.range(50) # A feedable iterator is defined by a handle placeholder and its structure. We # could use the `output_types` and `output_shapes` properties of either # `training_dataset` or `validation_dataset` here, because they have # identical structure. handle = tf.placeholder(tf.string, shape=[]) iterator = tf.data.Iterator.from_string_handle( handle, training_dataset.output_types, training_dataset.output_shapes) next_element = iterator.get_next() # You can use feedable iterators with a variety of different kinds of iterator # (such as one-shot and initializable iterators). training_iterator = training_dataset.make_one_shot_iterator() validation_iterator = validation_dataset.make_initializable_iterator() # The `Iterator.string_handle()` method returns a tensor that can be evaluated # and used to feed the `handle` placeholder. training_handle = sess.run(training_iterator.string_handle()) validation_handle = sess.run(validation_iterator.string_handle()) # Loop forever, alternating between training and validation. while True: # Run 200 steps using the training dataset. Note that the training dataset is # infinite, and we resume from where we left off in the previous `while` loop # iteration. for _ in range(200): sess.run(next_element, feed_dict={handle: training_handle}) # Run one pass over the validation dataset. sess.run(validation_iterator.initializer) for _ in range(50): sess.run(next_element, feed_dict={handle: validation_handle})
使用實例:
def get_encodes(x): # x is `batch_size` of lines, each of which is a json object samples = [json.loads(l) for l in x] text = [s[‘fact‘] for s in samples] # get a client from available clients bc_client = bc_clients.pop() features = bc_client.encode(text) # after use, put it back bc_clients.append(bc_client) labels = [0 for _ in text] return features, labels data_node = (tf.data.TextLineDataset(train_fp).batch(batch_size) .map(lambda x: tf.py_func(get_encodes, [x], [tf.float32, tf.int64], name=‘bert_client‘), num_parallel_calls=num_parallel_calls) .map(lambda x, y: {‘feature‘: x, ‘label‘: y}) .make_one_shot_iterator().get_next())
tf.data