1. 程式人生 > >【TensorFlow動手玩】資料匯入2

【TensorFlow動手玩】資料匯入2

簡介

本文介紹TensorFlow的第二種資料匯入方法。

為了保持高效,這種方法稍顯繁瑣。分為如下幾個步驟:
- 把所有樣本寫入二進位制檔案(只執行一次)
- 建立Tensor,從二進位制檔案讀取一個樣本
- 建立Tensor,從二進位制檔案隨機讀取一個mini-batch
- 把mini-batchTensor傳入網路作為輸入節點。

二進位制檔案

使用tf.python_io.TFRecordWriter建立一個專門儲存tensorflow資料的writer,副檔名為’.tfrecord’。
該檔案中依次儲存著序列化的tf.train.Example型別的樣本。

writer = tf.python_io.TFRecordWriter('/tmp/data.tfrecord'
) for i in range(0, 10): # 建立樣本example # ... serialized = example.SerializeToString() # 序列化 writer.write(serialized) # 寫入檔案 writer.close()

每一個examplefeature成員變數是一個dict,儲存一個樣本的不同部分(例如影象畫素+類標)。以下例子的樣本中包含三個鍵a,b,c

    # 建立樣本example
    a_data = 0.618 + i         # float
    b_data = [2016
+ i, 2017+i] # int64 c_data = numpy.array([[0, 1, 2],[3, 4, 5]]) + i # bytes c_data = c_data.astype(numpy.uint8) c_raw = c.tostring() # 轉化成字串 example = tf.train.Example( features=tf.train.Features( feature={ 'a': tf.train.Feature( float_list=tf.train.FloatList(value=[a_data]) # 方括號表示輸入為list
), 'b': tf.train.Feature( int64_list=tf.train.Int64List(value=b_data) # b_data本身就是列表 ), 'c': tf.train.Feature( bytes_list=tf.train.BytesList(value=[c_raw]) ) } ) )

dict成員的值部分接受三種類型資料:
- tf.train.FloatList:列表每個元素為float。例如a
- tf.train.Int64List:列表每個元素為int64。例如b
- tf.train.BytesList:列表每個元素為string。例如c

第三種類型尤其適合影象樣本。注意在轉成字串之前要設定為uint8型別。

讀取一個樣本

接下來,我們定義一個函式,建立“從檔案中讀一個樣本”操作,返回結果Tensor

def read_single_sample(filename):
    # 讀取樣本example的每個成員a,b,c
    # ...
    return a, b, c

首先建立讀檔案佇列,使用tf.TFRecordReader從檔案佇列讀入一個序列化的樣本。

    # 讀取樣本example的每個成員a,b,c
    filename_queue = tf.train.string_input_producer([filename], num_epochs=None)    # 不限定讀取數量
    reader = tf.TFRecordReader()
    _, serialized_example = reader.read(filename_queue)

如果樣本量很大,可以分成若干檔案,把檔名列表傳入tf.train.string_input_producer
和剛才的writer不同,這個reader是符號化的,只有在sess中run才會執行。

接下來解析符號化的樣本

    # get feature from serialized example
    features = tf.parse_single_example(
        serialized_example,
        features={
            'a': tf.FixedLenFeature([], tf.float32),    #0D, 標量
            'b': tf.FixedLenFeature([2], tf.int64),   # 1D,長度為2
            'c': tf.FixedLenFeature([], tf.string)  # 0D, 標量
        }
    )
    a = features['a']
    b = features['b']
    c_raw = features['c']
    c = tf.decode_raw(c_raw, tf.uint8)
    c = tf.reshape(c, [2, 3])

對於BytesList,要重新進行解碼,把string型別的0維Tensor變成uint8型別的1維Tensor

讀取mini-batch

使用tf.train.shuffle_batch將前述a,b,c隨機化,獲得mini-batchTensor

a_batch, b_batch, c_batch = tf.train.shuffle_batch([a, b, c], batch_size=2, capacity=200, min_after_dequeue=100, num_threads=2)

使用

建立一個session並初始化:

# sess
sess = tf.Session()
init = tf.initialize_all_variables()
sess.run(init)
tf.train.start_queue_runners(sess=sess)

由於使用了讀檔案佇列,所以要start_queue_runners

每一次執行,會隨機生成一個mini-batch樣本:

a_val, b_val, c_val = sess.run([a_batch, b_batch, c_batch])
a_val, b_val, c_val = sess.run([a_batch, b_batch, c_batch])

這樣的mini-batch可以作為網路的輸入節點使用。

總結

如果想進一步瞭解例子中的佇列機制,請參看這篇文章

完整程式碼如下:

import tensorflow as tf
import numpy

def write_binary():
    writer = tf.python_io.TFRecordWriter('/tmp/data.tfrecord')

    for i in range(0, 2):
        a = 0.618 + i
        b = [2016 + i, 2017+i]
        c = numpy.array([[0, 1, 2],[3, 4, 5]]) + i
        c = c.astype(numpy.uint8)
        c_raw = c.tostring()

        example = tf.train.Example(
            features=tf.train.Features(
                feature={
                    'a': tf.train.Feature(
                        float_list=tf.train.FloatList(value=[a])
                    ),

                    'b': tf.train.Feature(
                        int64_list=tf.train.Int64List(value=b)
                    ),
                    'c': tf.train.Feature(
                        bytes_list=tf.train.BytesList(value=[c_raw])
                    )
                }
            )
        )
        serialized = example.SerializeToString()
        writer.write(serialized)

    writer.close()

def read_single_sample(filename):
    # output file name string to a queue
    filename_queue = tf.train.string_input_producer([filename], num_epochs=None)

    # create a reader from file queue
    reader = tf.TFRecordReader()
    _, serialized_example = reader.read(filename_queue)

    # get feature from serialized example

    features = tf.parse_single_example(
        serialized_example,
        features={
            'a': tf.FixedLenFeature([], tf.float32),
            'b': tf.FixedLenFeature([2], tf.int64),
            'c': tf.FixedLenFeature([], tf.string)
        }
    )

    a = features['a']

    b = features['b']

    c_raw = features['c']
    c = tf.decode_raw(c_raw, tf.uint8)
    c = tf.reshape(c, [2, 3])

    return a, b, c

#-----main function-----
if 1:
    write_binary()
else:
    # create tensor
    a, b, c = read_single_sample('/tmp/data.tfrecord')
    a_batch, b_batch, c_batch = tf.train.shuffle_batch([a, b, c], batch_size=3, capacity=200, min_after_dequeue=100, num_threads=2)

    queues = tf.get_collection(tf.GraphKeys.QUEUE_RUNNERS)

    # sess
    sess = tf.Session()
    init = tf.initialize_all_variables()
    sess.run(init)

    tf.train.start_queue_runners(sess=sess)
    a_val, b_val, c_val = sess.run([a_batch, b_batch, c_batch])
    print(a_val, b_val, c_val)
    a_val, b_val, c_val = sess.run([a_batch, b_batch, c_batch])
    print(a_val, b_val, c_val)