tensorflow使用range_input_producer多執行緒讀取資料例項
阿新 • • 發佈:2020-01-21
先放關鍵程式碼:
i = tf.train.range_input_producer(NUM_EXPOCHES,num_epochs=1,shuffle=False).dequeue() inputs = tf.slice(array,[i * BATCH_SIZE],[BATCH_SIZE])
原理解析:
第一行會產生一個佇列,佇列包含0到NUM_EXPOCHES-1的元素,如果num_epochs有指定,則每個元素只產生num_epochs次,否則迴圈產生。shuffle指定是否打亂順序,這裡shuffle=False表示佇列的元素是按0到NUM_EXPOCHES-1的順序儲存。在Graph執行的時候,每個執行緒從佇列取出元素,假設值為i,然後按照第二行程式碼切出array的一小段資料作為一個batch。例如NUM_EXPOCHES=3,如果num_epochs=2,則佇列的內容是這樣子;
0,1,2,2
佇列只有6個元素,這樣在訓練的時候只能產生6個batch,迭代6次以後訓練就結束。
如果num_epochs不指定,則佇列內容是這樣子:
0,2...
佇列可以一直生成元素,訓練的時候可以產生無限的batch,需要自己控制什麼時候停止訓練。
下面是完整的演示程式碼。
資料檔案test.txt內容:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35
main.py內容:
import tensorflow as tf import codecs BATCH_SIZE = 6 NUM_EXPOCHES = 5 def input_producer(): array = codecs.open("test.txt").readlines() array = map(lambda line: line.strip(),array) i = tf.train.range_input_producer(NUM_EXPOCHES,shuffle=False).dequeue() inputs = tf.slice(array,[BATCH_SIZE]) return inputs class Inputs(object): def __init__(self): self.inputs = input_producer() def main(*args,**kwargs): inputs = Inputs() init = tf.group(tf.initialize_all_variables(),tf.initialize_local_variables()) sess = tf.Session() coord = tf.train.Coordinator() threads = tf.train.start_queue_runners(sess=sess,coord=coord) sess.run(init) try: index = 0 while not coord.should_stop() and index<10: datalines = sess.run(inputs.inputs) index += 1 print("step: %d,batch data: %s" % (index,str(datalines))) except tf.errors.OutOfRangeError: print("Done traing:-------Epoch limit reached") except KeyboardInterrupt: print("keyboard interrput detected,stop training") finally: coord.request_stop() coord.join(threads) sess.close() del sess if __name__ == "__main__": main()
輸出:
step: 1,batch data: ['1' '2' '3' '4' '5' '6'] step: 2,batch data: ['7' '8' '9' '10' '11' '12'] step: 3,batch data: ['13' '14' '15' '16' '17' '18'] step: 4,batch data: ['19' '20' '21' '22' '23' '24'] step: 5,batch data: ['25' '26' '27' '28' '29' '30'] Done traing:-------Epoch limit reached
如果range_input_producer去掉引數num_epochs=1,則輸出:
step: 1,batch data: ['25' '26' '27' '28' '29' '30'] step: 6,batch data: ['1' '2' '3' '4' '5' '6'] step: 7,batch data: ['7' '8' '9' '10' '11' '12'] step: 8,batch data: ['13' '14' '15' '16' '17' '18'] step: 9,batch data: ['19' '20' '21' '22' '23' '24'] step: 10,batch data: ['25' '26' '27' '28' '29' '30']
有一點需要注意,檔案總共有35條資料,BATCH_SIZE = 6表示每個batch包含6條資料,NUM_EXPOCHES = 5表示產生5個batch,如果NUM_EXPOCHES =6,則總共需要36條資料,就會報如下錯誤:
InvalidArgumentError (see above for traceback): Expected size[0] in [0,5],but got 6 [[Node: Slice = Slice[Index=DT_INT32,T=DT_STRING,_device="/job:localhost/replica:0/task:0/cpu:0"](Slice/input,Slice/begin/_5,Slice/size)]]
錯誤資訊的意思是35/BATCH_SIZE=5,即NUM_EXPOCHES 的取值能只能在0到5之間。
以上這篇tensorflow使用range_input_producer多執行緒讀取資料例項就是小編分享給大家的全部內容了,希望能給大家一個參考,也希望大家多多支援我們。