1. 程式人生 > >tf.data

tf.data

list func 兩個 format ict 編程 之前 訓練 initial

以往的TensorFLow模型數據的導入方法可以分為兩個主要方法,一種是使用feed_dict另外一種是使用TensorFlow中的Queues。前者使用起來比較靈活,可以利用Python處理各種輸入數據,劣勢也比較明顯,就是程序運行效率較低;後面一種方法的效率較高,但是使用起來較為復雜,靈活性較差。

Dataset作為新的API,比以上兩種方法的速度都快,並且使用難度要遠遠低於使用Queues。tf.data中包含了兩個用於TensorFLow程序的接口:Dataset和Iterator。

Dataset(數據集) API 在 TensorFlow 1.4版本中已經從tf.contrib.data遷移到了tf.data之中,增加了對於Python的生成器的支持,官方強烈建議使用Dataset API 為 TensorFlow模型創建輸入管道,原因如下:


Dataset

Dataset表示一個元素的集合,可以看作函數式編程中的 lazy list, 元素是tensor tuple。創建Dataset的方式可以分為兩種,分別是:

Source

Apply transformation
Source
這裏 source 指的是從tf.Tensor對象創建Dataset,常見的方法又如下幾種:

tf.data.Dataset.from_tensors((features, labels))
tf.data.Dataset.from_tensor_slices((features, labels))
tf.data.TextLineDataset(filenames)
tf.data.TFRecordDataset(filenames)

作用分別為:

  1.從一個tensor tuple創建一個單元素的dataset;

  2.從一個tensor tuple創建一個包含多個元素的dataset;

  3.讀取一個文件名列表,將每個文件中的每一行作為一個元素,構成一個dataset;

  4.讀取硬盤中的TFRecord格式文件,構造dataset。

Apply transformation

第二種方法就是通過轉化已有的dataset來得到新的dataset,TensorFLow tf.data.Dataset支持很多中變換,在這裏介紹常見的幾種:

dataset.map(lambda x: tf.decode_jpeg(x))
dataset.repeat(NUM_EPOCHS)    
dataset.batch(BATCH_SIZE)

以上三種方式分別表示了:使用map對dataset中的每個元素進行處理,這裏的例子是對圖片數據進行解碼;將dataset重復一定數目的次數用於多個epoch的訓練;將原來的dataset中的元素按照某個數量疊在一起,生成mini batch。

將以上代碼組合起來,我們可以得到一個常用的代碼片段:

# 從一個文件名列表讀取 TFRecord 構成 dataset
dataset = TFRecordDataset(["file1.tfrecord", "file2.tfrecord"])
# 處理 string,將 string 轉化為 tf.Tensor 對象
dataset = dataset.map(lambda record: tf.parse_single_example(record))
# buffer 大小設置為 10000,打亂 dataset
dataset = dataset.shuffle(10000)
# dataset 將被用來訓練 100 個 epoch
dataset = dataset.repeat(100)
# 設置 batch size 為 128
dataset = dataset.batch(128)

Iterator

定義好了數據集以後可以通過Iterator接口來訪問數據集中的tensor tuple,iterator保持了數據在數據集中的位置,提供了訪問數據集中數據的方法。

可以通過調用 dataset 的 make iterator 方法來構建 iterator。

替換了place_holder,要用到數據的時候,在最開始定義iterator.get_next(),就取到了一個batch的元素了

API 支持以下四種 iterator,復雜程度遞增:

  • one-shot
  • initializable
  • reinitializable
  • feedable

one-shot

one-shot iterator 誰最簡單的一種 iterator,僅支持對整個數據集訪問一遍,不需要顯式的初始化。one-shot iterator 不支參數化。以下代碼使用tf.data.Dataset.range生成數據集,作用與 python 中的 range 類似。

dataset = tf.data.Dataset.range(100)
iterator = dataset.make_one_shot_iterator()
next_element = iterator.get_next()

for i in range(100):
  value = sess.run(next_element)
  assert i == value

initializable

Initializable iterator 要求在使用之前顯式的通過調用iterator.initializer操作初始化,這使得在定義數據集時可以結合tf.placeholder傳入參數,如:

max_value = tf.placeholder(tf.int64, shape=[])
dataset = tf.data.Dataset.range(max_value)
iterator = dataset.make_initializable_iterator()
next_element = iterator.get_next()

# Initialize an iterator over a dataset with 10 elements.
sess.run(iterator.initializer, feed_dict={max_value: 10})
for i in range(10):
  value = sess.run(next_element)
  assert i == value

# Initialize the same iterator over a dataset with 100 elements.
sess.run(iterator.initializer, feed_dict={max_value: 100})
for i in range(100):
  value = sess.run(next_element)
  assert i == value

reinitializable

reinitializable iterator 可以被不同的 dataset 對象初始化,比如對於訓練集進行了shuffle的操作,對於驗證集則沒有處理,通常這種情況會使用兩個具有相同結構的dataset對象,如:

# Define training and validation datasets with the same structure.
training_dataset = tf.data.Dataset.range(100).map(
    lambda x: x + tf.random_uniform([], -10, 10, tf.int64))
validation_dataset = tf.data.Dataset.range(50)

# A reinitializable iterator is defined by its structure. We could use the
# `output_types` and `output_shapes` properties of either `training_dataset`
# or `validation_dataset` here, because they are compatible.
iterator = tf.data.Iterator.from_structure(training_dataset.output_types,
                                   training_dataset.output_shapes)
next_element = iterator.get_next()

training_init_op = iterator.make_initializer(training_dataset)
validation_init_op = iterator.make_initializer(validation_dataset) # 如果後面初始化的是這個,那麽就將循環這個數據集

# Run 20 epochs in which the training dataset is traversed, followed by the
# validation dataset.
for _ in range(20):
  # Initialize an iterator over the training dataset.
  sess.run(training_init_op)
  for _ in range(100):
    sess.run(next_element)

  # Initialize an iterator over the validation dataset.
  sess.run(validation_init_op) # 替換init_op,相當於替換數據集
  for _ in range(50):
    sess.run(next_element)

feedable

feedable iterator 可以通過和tf.placeholder結合在一起,同通過feed_dict機制來選擇在每次調用tf.Session.run的時候選擇哪種Iterator。它提供了與 reinitilizable iterator 類似的功能,並且在切換數據集的時候不需要在開始的時候初始化iterator,還是上面的例子,通過tf.data.Iterator.from_string_handle來定義一個 feedable iterator,達到切換數據集的目的:

# Define training and validation datasets with the same structure.
training_dataset = tf.data.Dataset.range(100).map(
    lambda x: x + tf.random_uniform([], -10, 10, tf.int64)).repeat()
validation_dataset = tf.data.Dataset.range(50)

# A feedable iterator is defined by a handle placeholder and its structure. We
# could use the `output_types` and `output_shapes` properties of either
# `training_dataset` or `validation_dataset` here, because they have
# identical structure.
handle = tf.placeholder(tf.string, shape=[])
iterator = tf.data.Iterator.from_string_handle(
    handle, training_dataset.output_types, training_dataset.output_shapes)
next_element = iterator.get_next()

# You can use feedable iterators with a variety of different kinds of iterator
# (such as one-shot and initializable iterators).
training_iterator = training_dataset.make_one_shot_iterator()
validation_iterator = validation_dataset.make_initializable_iterator()

# The `Iterator.string_handle()` method returns a tensor that can be evaluated
# and used to feed the `handle` placeholder.
training_handle = sess.run(training_iterator.string_handle())
validation_handle = sess.run(validation_iterator.string_handle())

# Loop forever, alternating between training and validation.
while True:
  # Run 200 steps using the training dataset. Note that the training dataset is
  # infinite, and we resume from where we left off in the previous `while` loop
  # iteration.
  for _ in range(200):
    sess.run(next_element, feed_dict={handle: training_handle})

  # Run one pass over the validation dataset.
  sess.run(validation_iterator.initializer)
  for _ in range(50):
    sess.run(next_element, feed_dict={handle: validation_handle})

使用實例:

def get_encodes(x):
    # x is `batch_size` of lines, each of which is a json object
    samples = [json.loads(l) for l in x]
    text = [s[fact] for s in samples]
    # get a client from available clients
    bc_client = bc_clients.pop()
    features = bc_client.encode(text)
    # after use, put it back
    bc_clients.append(bc_client)
    labels = [0 for _ in text]
    return features, labels


data_node = (tf.data.TextLineDataset(train_fp).batch(batch_size)
             .map(lambda x: tf.py_func(get_encodes, [x], [tf.float32, tf.int64], name=bert_client), num_parallel_calls=num_parallel_calls)
             .map(lambda x, y: {feature: x, label: y})
             .make_one_shot_iterator().get_next())

tf.data