
TensorFlow data loading methods

TensorFlow currently supports three ways of reading data:
1. Preloading (preloaded): the data is defined as constants directly when the TensorFlow graph is built. Because the data is embedded in the graph itself, this consumes a large amount of memory when the dataset is large.

import tensorflow as tf
a = tf.constant([1, 2, 3], name='input_a')
b = tf.constant([4, 5, 6], name='input_b')
c = tf.add(a, b, name='sums')
sess = tf.Session()
x = sess.run(c)
print(x)

2. Feeding: data produced in Python is fed into placeholders at run time. This approach suffers from the same memory cost when the dataset is large, and the data-type conversion adds some extra overhead.

import tensorflow as tf
a = tf.placeholder(tf.int16)
b = tf.placeholder(tf.int16)
c = tf.add(a,b)
p_a = [1,2,3]
p_b = [4,5,6]
with tf.Session() as sess:
    print(sess.run(c, feed_dict={a:p_a, b:p_b}))

3. Reading from file (reading from file): compared with the two approaches above, this one has a clear advantage when handling large amounts of data. Reading data from files in TensorFlow involves two main steps:
(1) write the data into a TFRecords binary file;

'''Create a conversion function: fill the data into a tf.train.Example protocol
   buffer, serialize the buffer to a string, and write it to a TFRecords file
   with tf.python_io.TFRecordWriter.'''
import os
import tensorflow as tf
def int64_feature(data):
    return tf.train.Feature(int64_list=tf.train.Int64List(value=[data]))
def bytes_feature(data):
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[data]))
def convert_tfrecords(data, name):
    images = data.images
    labels = data.labels
    num_examples = data.num_examples
    if images.shape[0] != num_examples:
        raise ValueError('Image count does not match label count: %d vs %d'
                         % (images.shape[0], num_examples))
    rows = images.shape[1]
    width = images.shape[2]
    depth = images.shape[3]
    filename = os.path.join(os.path.dirname(__file__), name + '.tfrecords')
    writer = tf.python_io.TFRecordWriter(filename)
    for i in range(num_examples):
        image_raw = images[i].tostring()
        example = tf.train.Example(features=tf.train.Features(feature={
            'height': int64_feature(rows), 'width': int64_feature(width),
            'depth': int64_feature(depth), 'label': int64_feature(int(labels[i])),
            'image_raw': bytes_feature(image_raw)}))
        writer.write(example.SerializeToString())
    writer.close()

(2) read the data back from the binary file using a queue.
