tensorflow batch函式實現
阿新 • • 發佈:2018-11-16
def get_batch(image, label, image_W, image_H, batch_size, capacity): image = tf.cast(image, tf.string) label = tf.cast(label, tf.int32) #生成列隊 input_queue = tf.train.slice_input_producer([image, label]) label = input_queue[1] image_contents = tf.read_file(input_queue[0]) image = tf.image.decode_jpeg(image_contents, channels=3) image = tf.image.resize_image_with_crop_or_pad(image, image_W, image_H) a = tf.reduce_mean(image) image = tf.subtract(image,a) image_batch, label_batch = tf.train.batch([image, label], batch_size= batch_size, num_threads= 16, capacity = capacity) label_batch = tf.reshape(label_batch, [batch_size]) image_batch = tf.cast(image_batch, tf.float32) return image_batch, label_batch
sess.run(tf.global_variables_initializer()) coord = tf.train.Coordinator() threads = tf.train.start_queue_runners(sess=sess, coord=coord) try: for step in np.arange(MAX_STEP): if coord.should_stop(): break ……………… ……………… ……………… finally: # 使用tf.train.Coordinator來停止所有執行緒 coord.request_stop() coord.join(threads) sess.close()
tensorflow資料的讀取形式:
使用一個執行緒將硬碟中的圖片資料讀入到一個佇列中,另一個執行緒負責計算任務,所需資料直接從該佇列中獲取。
存入列隊的函式:
input_queue = tf.train.slice_input_producer(tensor_list, num_epochs=None, shuffle=True,seed=None, capacity=capacity, shared_name=None, name=None) ''' tensor_list:輸入的tensor列表 num_epochs:可選引數,是一個整數值,代表迭代的次數,如果設定None,生成器可以無限次遍歷tensor列表,如果設定為 num_epochs=N,生成器只能遍歷tensor列表N次。 shuffle:bool型別,設定是否打亂樣本的順序 seed:可選的整數,是生成隨機數的種子 capacity:設定tensor列表的容量 share_name:可選引數,如果設定一個‘shared_name’,則在不同的上下文環境(Session)中可以通過這個名字共享生成的tensor name:操作的名稱 '''
在得到列隊後,再呼叫tf.train.start_queue_runners 函式來啟動執行檔名佇列填充的執行緒,之後計算單元才可以把資料讀出來。
但在此之前,需要先呼叫 tf.train.Coordinator() 來建立一個執行緒協調器,用來管理之後在Session中啟動的所有執行緒。
tf.train.start_queue_runners()
啟動入隊執行緒,這之後才能讀取資料到用於之後的計算。
image_batch, label_batch = tf.train.batch(tensors, batch_size, num_threads=1,
capacity=capacity,enqueue_many=False, shapes=None,
dynamic_pad=False,allow_smaller_final_batch=False,
shared_name=None, name=None)
'''
tensors:tensor序列或tensor字典,可以是含有單個樣本的序列
batch_size: 生成的batch的大小
num_threads:執行tensor入隊操作的執行緒數量,可以設定使用多個執行緒同時並行執行
capacity: 定義生成的tensor序列的最大容量
enqueue_many: 定義第一個傳入引數tensors是多個tensor組成的序列,還是單個tensor
shapes: 可選引數,預設是推測出的傳入的tensor的形狀
dynamic_pad: 定義是否允許輸入的tensors具有不同的形狀,設定為True,會把輸入的具有不同形狀的tensor歸一化到相同的形狀
allow_smaller_final_batch: 設定為True,表示在tensor佇列中剩下的tensor數量不夠一個batch_size的情況下,允許最後一個batch的數量少於batch_size, 設定為False,則不管什麼情況下,生成的batch都擁有batch_size個樣本
shared_name: 可選引數,設定生成的tensor序列在不同的Session中的共享名稱
name: 操作的名稱
'''
如果tf.train.batch的第一個引數 tensors 傳入的是tenor列表或者字典,返回的是tensor列表或字典,如果傳入的是隻含有一個元素的列表,返回的是單個的tensor,而不是一個列表。