1. 程式人生 > >cifar10資料的讀取

cifar10資料的讀取

這裡寫圖片描述

cifar10資料集檔案結構如圖所示,其中data_batch_1~5.bin是訓練集,每個檔案包含10000個樣本,test_batch.bin是測試集,包含10000個樣本。

開啟任意一個檔案,發現是一堆二進位制資料,

這裡寫圖片描述
其中一個樣本由3037個位元組組成,其中第一個位元組是label,剩餘3036(32*32*3)個位元組是image,每個檔案由連續的10000個樣本組成,具體的讀取過程參考下面程式碼及註釋。

#獲取image和label
def get_input():
    #檔名佇列
    filenames = tf.train.match_filenames_once(DATA_DIR+'/data_batch_*'
) filename_queue = tf.train.string_input_producer(filenames) #cifar10的資料格式: #一個樣本由3037個位元組組成,其中第一個位元組是label,剩餘3036(32*32*3)個位元組是image #每個檔案由連續的10000個樣本組成,共5個檔案 image_bytes = IMAGE_SIZE * IMAGE_SIZE * IMAGE_DEPTH record_bytes = image_bytes + LABEL_BYTES #使用FixedLengthRecordReader讀取樣本,每次讀取一個
reader = tf.FixedLengthRecordReader(record_bytes=record_bytes) #獲取樣本的值 _,value = reader.read(filename_queue) #讀出來的樣本為二進位制的字串格式,轉化為uint8的格式 raw_value = tf.decode_raw(value,tf.uint8) #劃分label和image labels = tf.cast(tf.strided_slice(raw_value,[0],[1]),tf.int32) #由於image是按照(depth,height,width)的格式儲存的,因此讀出來後還要將其轉化為(height,width,depth)的格式
images = tf.reshape( tf.strided_slice(raw_value,[LABEL_BYTES],[LABEL_BYTES+image_bytes]), [IMAGE_DEPTH,IMAGE_SIZE,IMAGE_SIZE] ) images = tf.transpose(images,[1,2,0]) images = tf.cast(images,tf.float32) #資料型別:label是int32,image是範圍為0-1的float32 #標準化處理:減去平均值併除以方差,使得樣本均值為0,方差為1 standard_images = tf.image.per_image_standardization(images) #官方bug,得加上 standard_images.set_shape([RESIZE_SIZE,RESIZE_SIZE,3]) labels.set_shape([1]) return standard_images,labels