1. 程式人生 > 實用技巧 >Tensorflow踩坑系列---資料讀取檔案佇列

Tensorflow踩坑系列---資料讀取檔案佇列

一:總括檔案讀取方式

1.供給資料(Feeding): 由佔位符placeholder代替資料,執行時使用feed_dict填入資料

2.預載入資料: 資料直接嵌入graph,由graph傳入session中執行

3.從檔案讀取資料: 在TensorFlow圖的起始, 讓一個輸入管線從檔案中讀取資料,這就是這篇文將要講的內容。

前兩種方法很方便,但是遇到大型資料的時候就會很吃力,即使是Feeding,中間環節的增加也是不小的開銷,比如資料型別轉換等等。 最優的方案就是在Graph定義好檔案讀取的方法,讓TF自己去從檔案中讀取資料,並解碼成可使用的樣本集。

對於大的資料集很難用numpy陣列儲存,所以這裡介紹一下Tensorflow讀取很大資料集的方法:string_input_producer()和slice_input_producer()。

這種直接從檔案中讀取資料的方式需要設計成Queue的方式才能較好的解決IO瓶頸的問題。
Queue機制有如下三個特點:

(1)producer-consumer pattern(生產消費模式)
(2)獨立於主執行緒執行
(3)非同步IO: reader.read(queue) tf.train.batch()

一:string_input_producer佇列使用(單個Reader、單檔案讀取)

import tensorflow as tf
IMAGE_DIR = "./Images/SourceImgs/"
QUEUE_DIR = "./Images/QueueImgs/"
FILELIST 
= ["1100.jpg","1101.jpg","1102.jpg","1104.jpg","1105.jpg", "1110.jpg","1114.jpg","1115.jpg","1116.jpg","1118.jpg"]

(一)獲取檔案列表

def getFileList(rootDir=IMAGE_DIR,files=FILELIST):
    fsl = []
    for fn in files:
        fsl.append(IMAGE_DIR+fn)
    return fsl

(二)使用佇列讀取檔案

with tf.Session() as sess:
    files_list 
= getFileList() #string_input_producer產生檔名佇列 filename_queue = tf.train.string_input_producer(files_list,shuffle=True,num_epochs=3) #reader從檔名佇列中讀取資料 reader = tf.WholeFileReader() key,value = reader.read(filename_queue) #返回檔名和檔案內容 sess.run(tf.local_variables_initializer()) #初始化上面的區域性變數 #啟動start_queue_runners之後,才會開始填充佇列 threads = tf.train.start_queue_runners(sess=sess) i = 1 while True: try: image_data = sess.run(value) with open(QUEUE_DIR+"%d.jpg"%i,"wb") as f: f.write(image_data) i+=1 except BaseException: print("read all files, numbers:%d"%i) break

(三)引數說明

tf.train.string_input_producer(files_list,shuffle=False,num_epochs=2)

shuffle=False:表示按序獲得檔案

num_epochs=2:表示會遍歷兩遍全部檔案,當我們不設定數值的時候,表示我們可以一直遍歷下去,會迴圈所有檔案

tf.train.string_input_producer(files_list,shuffle=True,num_epochs=3)

shuffle=False:表示打亂順序獲得檔案(是本輪所有檔案列表中亂序,不是全域性)

num_epochs=2:表示會遍歷三遍全部檔案

二:string_input_producer佇列使用(單個Reader、批檔案讀取)

import tensorflow as tf
IMAGE_DIR = "./Images/SourceImgs/"
QUEUE_DIR = "./Images/QueueImgs/"
FILELIST = ["1100.jpg","1101.jpg","1102.jpg","1104.jpg","1105.jpg",
           "1110.jpg","1114.jpg","1115.jpg","1116.jpg","1118.jpg"]
def getFileList(rootDir=IMAGE_DIR,files=FILELIST):
    fsl = []
    for fn in files:
        fsl.append(IMAGE_DIR+fn)
    return fsl

(一)按批次獲取檔案

files_list = getFileList()
#string_input_producer產生檔名佇列
filename_queue = tf.train.string_input_producer(files_list,shuffle=False,num_epochs=1)

def decode_img(fileQueue):
    #reader從檔名佇列中讀取資料
    reader = tf.WholeFileReader()
    key,value = reader.read(fileQueue) #返回檔名和檔案內容
    return value #返回一個檔案

img = decode_img(filename_queue)

image_batch = tf.train.batch([img],batch_size=8,num_threads=2,allow_smaller_final_batch=True) 

(二)執行緒呼叫

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer()) #初始化上面的全域性變數
    sess.run(tf.local_variables_initializer()) #初始化上面的區域性變數
    
    coord = tf.train.Coordinator()
    #啟動start_queue_runners之後,才會開始填充佇列
    threads = tf.train.start_queue_runners(sess=sess,coord=coord)
    j = 1
    try:
        while not coord.should_stop():
            images_data = sess.run(image_batch)
            print(images_data.shape)
            for img_data in images_data:
                with open(QUEUE_DIR+"%d.jpg"%j,"wb") as f:
                    f.write(img_data)
                j+=1
    except BaseException:
            print("read all files")
    finally:
        coord.request_stop() #將讀取檔案的執行緒關閉
    coord.join(threads) #執行緒回收,將讀取檔案的子執行緒加入主執行緒

(三)引數說明

tf.train.batch([img],batch_size=8,num_threads=2,allow_smaller_final_batch=True)

使用tf.train.batch,按序獲取:

batch_size每一個批次大小為8,
num_threads使用2執行緒讀取資料,雖然這裡只有一個Reader,但可以設定多執行緒,相應增加執行緒數會提高讀取速度,但並不是執行緒越多越好。
allow_smaller_final_batch,預設為false,剩餘資料小於batch_size則會被丟棄。

tf.train.shuffle_batch() 將佇列中資料打亂後再讀取出來,其他與batch方法類似。

需要設定:
capacity:佇列中元素的最大數量。
min_after_dequeue:出隊後佇列中元素的最小數量,用於確保元素的混合級別。

補充:

TensorFlow學習--tf.train.batch與tf.train.shuffle_batch

tf.train.string_input_producer()和tf.train.slice_input_producer()

string_input_producer:

載入圖片的reader是reader = tf.WholeFileReader()

key,value = reader.read(path_queue)其中key是檔名,value是byte型別的檔案流二進位制。

slice_input_producer:

載入圖片的reader使用tf.read_file(filename)直接讀取。這是兩者的一個不同之處!!!

TensorFlow基礎3:資料讀取的三種方式