1. 程式人生 > >tensorflow batch函式實現

tensorflow batch函式實現

def get_batch(image, label, image_W, image_H, batch_size, capacity):
    
    image = tf.cast(image, tf.string)
    label = tf.cast(label, tf.int32)
    #生成列隊
    input_queue = tf.train.slice_input_producer([image, label])
    
    label = input_queue[1]
    image_contents = tf.read_file(input_queue[0])
    image = tf.image.decode_jpeg(image_contents, channels=3)
    
    image = tf.image.resize_image_with_crop_or_pad(image, image_W, image_H)
    a = tf.reduce_mean(image)
    image = tf.subtract(image,a)
    
    image_batch, label_batch = tf.train.batch([image, label],
                                              batch_size= batch_size,
                                              num_threads= 16, 
                                              capacity = capacity)
 
    
    label_batch = tf.reshape(label_batch, [batch_size])
    image_batch = tf.cast(image_batch, tf.float32)
    
    return image_batch, label_batch

 

    sess.run(tf.global_variables_initializer())
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)
    
    try:
        for step in np.arange(MAX_STEP):
            if coord.should_stop():
                    break
            ………………
            ………………
            ………………
    finally:
        # 使用tf.train.Coordinator來停止所有執行緒  
        coord.request_stop()
     
    coord.join(threads)
    sess.close()

 tensorflow資料的讀取形式:

使用一個執行緒將硬碟中的圖片資料讀入到一個佇列中,另一個執行緒負責計算任務,所需資料直接從該佇列中獲取。

存入列隊的函式:

input_queue = tf.train.slice_input_producer(tensor_list, num_epochs=None, shuffle=True,seed=None,
                                capacity=capacity, shared_name=None, name=None)
'''
tensor_list:輸入的tensor列表

num_epochs:可選引數,是一個整數值,代表迭代的次數,如果設定None,生成器可以無限次遍歷tensor列表,如果設定為 num_epochs=N,生成器只能遍歷tensor列表N次。

shuffle:bool型別,設定是否打亂樣本的順序

seed:可選的整數,是生成隨機數的種子

capacity:設定tensor列表的容量

share_name:可選引數,如果設定一個‘shared_name’,則在不同的上下文環境(Session)中可以通過這個名字共享生成的tensor

name:操作的名稱
'''

 

在得到列隊後,再呼叫tf.train.start_queue_runners 函式來啟動執行檔名佇列填充的執行緒,之後計算單元才可以把資料讀出來。

但在此之前,需要先呼叫 tf.train.Coordinator() 來建立一個執行緒協調器,用來管理之後在Session中啟動的所有執行緒。

tf.train.start_queue_runners()

啟動入隊執行緒,這之後才能讀取資料到用於之後的計算。

 

image_batch, label_batch = tf.train.batch(tensors, batch_size, num_threads=1,
                                         capacity=capacity,enqueue_many=False, shapes=None,
                            dynamic_pad=False,allow_smaller_final_batch=False, 
                            shared_name=None, name=None)
'''
tensors:tensor序列或tensor字典,可以是含有單個樣本的序列
batch_size: 生成的batch的大小
num_threads:執行tensor入隊操作的執行緒數量,可以設定使用多個執行緒同時並行執行
capacity: 定義生成的tensor序列的最大容量
enqueue_many: 定義第一個傳入引數tensors是多個tensor組成的序列,還是單個tensor
shapes: 可選引數,預設是推測出的傳入的tensor的形狀
dynamic_pad: 定義是否允許輸入的tensors具有不同的形狀,設定為True,會把輸入的具有不同形狀的tensor歸一化到相同的形狀
allow_smaller_final_batch: 設定為True,表示在tensor佇列中剩下的tensor數量不夠一個batch_size的情況下,允許最後一個batch的數量少於batch_size, 設定為False,則不管什麼情況下,生成的batch都擁有batch_size個樣本
shared_name: 可選引數,設定生成的tensor序列在不同的Session中的共享名稱
name: 操作的名稱
'''

如果tf.train.batch的第一個引數 tensors 傳入的是tenor列表或者字典,返回的是tensor列表或字典,如果傳入的是隻含有一個元素的列表,返回的是單個的tensor,而不是一個列表。