1. 程式人生 > 程式設計 >tensorflow tf.train.batch之資料批量讀取方式

tensorflow tf.train.batch之資料批量讀取方式

在進行大量資料訓練神經網路的時候,可能需要批量讀取資料。於是參考了這篇文章的程式碼,結果發現數據一直批量迴圈輸出,不會在資料的末尾自動停止。

然後發現這篇博文說slice_input_producer()這個函式有一個形參num_epochs,通過設定它的值就可以控制全部資料迴圈輸出幾次。

於是我設定之後出現以下的報錯:

tensorflow.python.framework.errors_impl.FailedPreconditionError: Attempting to use uninitialized value input_producer/input_producer/limit_epochs/epochs

     [[Node: input_producer/input_producer/limit_epochs/CountUpTo = CountUpTo[T=DT_INT64,_class=["loc:@input_producer/input_producer/limit_epochs/epochs"],limit=2,_device="/job:localhost/replica:0/task:0/cpu:0"](input_producer/input_producer/limit_epochs/epochs)]]

找了好久,都不知道為什麼會錯,於是只好去看看slice_input_producer()函式的原始碼,結果在原始碼中發現作者說這個num_epochs如果不是空的話,就是一個區域性變數,需要先呼叫global_variables_initializer()函式初始化。

於是我呼叫了之後,一切就正常了,特此記錄下來,希望其他人遇到的時候能夠及時找到原因。

哈哈,這是筆者第一次通過閱讀原始碼解決了問題,心情還是有點小激動。啊啊,扯遠了,上最終成功的程式碼:

import pandas as pd
import numpy as np
import tensorflow as tf


def generate_data():
  num = 25
  label = np.asarray(range(0,num))
  images = np.random.random([num,5])
  print('label size :{},image size {}'.format(label.shape,images.shape))
  return images,label

def get_batch_data():
  label,images = generate_data()
  input_queue = tf.train.slice_input_producer([images,label],shuffle=False,num_epochs=2)
  image_batch,label_batch = tf.train.batch(input_queue,batch_size=5,num_threads=1,capacity=64,allow_smaller_final_batch=False)
  return image_batch,label_batch


images,label = get_batch_data()
sess = tf.Session()
sess.run(tf.global_variables_initializer())
sess.run(tf.local_variables_initializer())#就是這一行
coord = tf.train.Coordinator()
threads = tf.train.start_queue_runners(sess,coord)
try:
  while not coord.should_stop():
    i,l = sess.run([images,label])
    print(i)
    print(l)
except tf.errors.OutOfRangeError:
  print('Done training')
finally:
  coord.request_stop()
coord.join(threads)
sess.close()

以上這篇tensorflow tf.train.batch之資料批量讀取方式就是小編分享給大家的全部內容了,希望能給大家一個參考,也希望大家多多支援我們。