【TensorFlow】TFRecord資料集的製作:讀取、顯示及程式碼詳解
在跑通了官網的mnist和cifar10資料之後,筆者嘗試著製作自己的資料集,並儲存,讀入,顯示。 TensorFlow可以支援cifar10的資料格式, 也提供了標準的TFRecord 格式。
tensorflow 讀取資料, 官網提供了以下三種方法:
1 Feeding: 在tensorflow程式執行的每一步, 用python程式碼線上提供資料;
2 Reader : 在一個計算圖(tf.graph)的開始前,將檔案讀入到流(queue)中;
3 在宣告tf.variable變數或numpy陣列時儲存資料。受限於記憶體大小,適用於資料較小的情況;
在本文,主要介紹第二種方法,利用tf.record標準介面來讀入檔案
準備圖片資料
筆者找了2類狗的圖片, 哈士奇和吉娃娃, 全部 resize成128 * 128大小
如下圖, 儲存地址為D:\Python\data\dog
每類中有10張圖片
現在利用這2 類 20張圖片製作TFRecord檔案
製作TFRECORD檔案
1 先聊一下tfrecord, 這是一種將影象資料和標籤放在一起的二進位制檔案,能更好的利用記憶體,在tensorflow中快速的複製,移動,讀取,儲存 等等..
這裡注意,tfrecord會根據你選擇輸入檔案的類,自動給每一類打上同樣的標籤
如在本例中,只有0,1 兩類
2 先上“製作TFRecord檔案”的程式碼,註釋附詳解
import os import tensorflow as tf from PIL import Image #注意Image,後面會用到 import matplotlib.pyplot as plt import numpy as np cwd='D:\Python\data\dog\\' classes={'husky','chihuahua'} #人為 設定 2 類 writer= tf.python_io.TFRecordWriter("dog_train.tfrecords") #要生成的檔案 for index,name in enumerate(classes): class_path=cwd+name+'\\' for img_name in os.listdir(class_path): img_path=class_path+img_name #每一個圖片的地址 img=Image.open(img_path) img= img.resize((128,128)) img_raw=img.tobytes()#將圖片轉化為二進位制格式 example = tf.train.Example(features=tf.train.Features(feature={ "label": tf.train.Feature(int64_list=tf.train.Int64List(value=[index])), 'img_raw': tf.train.Feature(bytes_list=tf.train.BytesList(value=[img_raw])) })) #example物件對label和image資料進行封裝 writer.write(example.SerializeToString()) #序列化為字串 writer.close()
tf.train.Example 協議記憶體塊包含了Features欄位,通過feature將圖片的二進位制資料和label進行統一封裝, 然後將example協議記憶體塊轉化為字串, tf.python_io.TFRecordWriter 寫入到TFRecords檔案中。執行完這段程式碼後,會生成dog_train.tfrecords 檔案,如下圖
讀取TFRECORD檔案
在製作完tfrecord檔案後, 將該檔案讀入到資料流中。
程式碼如下
def read_and_decode(filename): # 讀入dog_train.tfrecords
filename_queue = tf.train.string_input_producer([filename])#生成一個queue佇列
reader = tf.TFRecordReader()
_, serialized_example = reader.read(filename_queue)#返回檔名和檔案
features = tf.parse_single_example(serialized_example,
features={
'label': tf.FixedLenFeature([], tf.int64),
'img_raw' : tf.FixedLenFeature([], tf.string),
})#將image資料和label取出來
img = tf.decode_raw(features['img_raw'], tf.uint8)
img = tf.reshape(img, [128, 128, 3]) #reshape為128*128的3通道圖片
img = tf.cast(img, tf.float32) * (1. / 255) - 0.5 #在流中丟擲img張量
label = tf.cast(features['label'], tf.int32) #在流中丟擲label張量
return img, label
注意,feature的屬性“label”和“img_raw”名稱要和製作時統一 ,返回的img資料和label資料一一對應。返回的img和label是2個 tf 張量,print出來 如下圖
顯示tfrecord格式的圖片
有些時候我們希望檢查分類是否有誤,或者在之後的網路訓練過程中可以監視,輸出圖片,來觀察分類等操作的結果,那麼我們就可以session回話中,將tfrecord的圖片從流中讀取出來,再儲存。 緊跟著一開始的程式碼寫:
filename_queue = tf.train.string_input_producer(["dog_train.tfrecords"]) #讀入流中
reader = tf.TFRecordReader()
_, serialized_example = reader.read(filename_queue) #返回檔名和檔案
features = tf.parse_single_example(serialized_example,
features={
'label': tf.FixedLenFeature([], tf.int64),
'img_raw' : tf.FixedLenFeature([], tf.string),
}) #取出包含image和label的feature物件
image = tf.decode_raw(features['img_raw'], tf.uint8)
image = tf.reshape(image, [128, 128, 3])
label = tf.cast(features['label'], tf.int32)
with tf.Session() as sess: #開始一個會話
init_op = tf.initialize_all_variables()
sess.run(init_op)
coord=tf.train.Coordinator()
threads= tf.train.start_queue_runners(coord=coord)
for i in range(20):
example, l = sess.run([image,label])#在會話中取出image和label
img=Image.fromarray(example, 'RGB')#這裡Image是之前提到的
img.save(cwd+str(i)+'_''Label_'+str(l)+'.jpg')#存下圖片
print(example, l)
coord.request_stop()
coord.join(threads)
程式碼執行完後, 從tfrecord中取出的檔案被儲存了。如下圖:
在這裡我們可以看到,圖片檔名的第一個數字表示在流中的順序(這裡沒有用shuffle), 第二個數字則是 每個圖片的label,吉娃娃都為0,哈士奇都為1。 由此可見,我們一開始製作tfrecord檔案時,圖片分類正確。