TensorFlow TFRecord資料集的生成與顯示
TFRecord
TensorFlow提供了TFRecord的格式來統一儲存資料,TFRecord格式是一種將影象資料和標籤放在一起的二進位制檔案,能更好的利用記憶體,在tensorflow中快速的複製,移動,讀取,儲存 等等。
TFRecords檔案包含了tf.train.Example 協議記憶體塊(protocol buffer)(協議記憶體塊包含了欄位 Features)。我們可以寫一段程式碼獲取你的資料, 將資料填入到Example協議記憶體塊(protocol buffer),將協議記憶體塊序列化為一個字串, 並且通過tf.python_io.TFRecordWriter 寫入到TFRecords檔案。
從TFRecords檔案中讀取資料, 可以使用tf.TFRecordReader的tf.parse_single_example解析器。這個操作可以將Example協議記憶體塊(protocol buffer)解析為Tensor。
Image to TFRecord
首先我們使用TensorFlow提供的Flowers資料集做這個實驗,資料集在我本地的路徑為:
這是一個五分類的資料,以類別的形式組織資料,這非常符合我們自己組織資料集的習慣。其中一個分類中大概有700張左右的圖片:
現在我們就把上面的資料製作出TFRecord,在這裡需要說明下,TFRecord的生成要注意兩點:
1.很多時候,我們的圖片尺寸並不是統一的,所以在生成的TFRecord中需要包含影象的width和height這兩個資訊,這樣在解析圖片的時候,我們才能把二進位制的資料重新reshape成圖片;
2.TensorFlow官方的建議是一個TFRecord中最好圖片的數量為1000張左右,這個很好理解,如果我們有上萬張圖片,卻只打成一個包,這樣是很不利於多執行緒讀取的。所以我們需要根據影象資料自動去選擇到底打包幾個TFRecord出來。
我們可以用下面的程式碼實現這兩個目的:
import os
import tensorflow as tf
from PIL import Image
#圖片路徑
cwd = 'F:\\flowersdata\\trainimages\\'
#檔案路徑
filepath = 'F:\\flowersdata\\tfrecord\\'
#存放圖片個數
bestnum = 1000
#第幾個圖片
num = 0
#第幾個TFRecord檔案
recordfilenum = 0
#類別
classes=['daisy',
'dandelion',
'roses' ,
'sunflowers',
'tulips']
#tfrecords格式檔名
ftrecordfilename = ("traindata.tfrecords-%.3d" % recordfilenum)
writer= tf.python_io.TFRecordWriter(filepath+ftrecordfilename)
#類別和路徑
for index,name in enumerate(classes):
print(index)
print(name)
class_path=cwd+name+'\\'
for img_name in os.listdir(class_path):
num=num+1
if num>bestnum:
num = 1
recordfilenum = recordfilenum + 1
#tfrecords格式檔名
ftrecordfilename = ("traindata.tfrecords-%.3d" % recordfilenum)
writer= tf.python_io.TFRecordWriter(filepath+ftrecordfilename)
#print('路徑',class_path)
#print('第幾個圖片:',num)
#print('檔案的個數',recordfilenum)
#print('圖片名:',img_name)
img_path = class_path+img_name #每一個圖片的地址
img=Image.open(img_path,'r')
size = img.size
print(size[1],size[0])
print(size)
#print(img.mode)
img_raw=img.tobytes()#將圖片轉化為二進位制格式
example = tf.train.Example(
features=tf.train.Features(feature={
'label': tf.train.Feature(int64_list=tf.train.Int64List(value=[index])),
'img_raw': tf.train.Feature(bytes_list=tf.train.BytesList(value=[img_raw])),
'img_width':tf.train.Feature(int64_list=tf.train.Int64List(value=[size[0]])),
'img_height':tf.train.Feature(int64_list=tf.train.Int64List(value=[size[1]]))
}))
writer.write(example.SerializeToString()) #序列化為字串
writer.close()
在上面的程式碼中,我們規定了一個TFRecord中只放1000張圖:
bestnum = 1000
並且將一張圖的4個資訊打包到TFRecord中,分別是:
example = tf.train.Example(
features=tf.train.Features(feature={
'label': tf.train.Feature(int64_list=tf.train.Int64List(value=[index])),
'img_raw': tf.train.Feature(bytes_list=tf.train.BytesList(value=[img_raw])),
'img_width':tf.train.Feature(int64_list=tf.train.Int64List(value=[size[0]])),
'img_height':tf.train.Feature(int64_list=tf.train.Int64List(value=[size[1]]))
}))
TFRecord to Image
在上面我們打包了四個TFRecord檔案,下面我們把這些資料讀取並顯示出來,看看製作的效果,這個過程很大一部分是和TensorFlow組織batch是一樣的了。
import tensorflow as tf
from PIL import Image
import matplotlib.pyplot as plt
#寫入圖片路徑
swd = 'F:\\flowersdata\\show\\'
#TFRecord檔案路徑
data_path = 'F:\\flowersdata\\tfrecord\\traindata.tfrecords-003'
# 獲取檔名列表
data_files = tf.gfile.Glob(data_path)
print(data_files)
# 檔名列表生成器
filename_queue = tf.train.string_input_producer(data_files,shuffle=True)
reader = tf.TFRecordReader()
_, serialized_example = reader.read(filename_queue) #返回檔名和檔案
features = tf.parse_single_example(serialized_example,
features={
'label': tf.FixedLenFeature([], tf.int64),
'img_raw' : tf.FixedLenFeature([], tf.string),
'img_width': tf.FixedLenFeature([], tf.int64),
'img_height': tf.FixedLenFeature([], tf.int64),
}) #取出包含image和label的feature物件
#tf.decode_raw可以將字串解析成影象對應的畫素陣列
image = tf.decode_raw(features['img_raw'], tf.uint8)
height = tf.cast(features['img_height'],tf.int32)
width = tf.cast(features['img_width'],tf.int32)
label = tf.cast(features['label'], tf.int32)
channel = 3
image = tf.reshape(image, [height,width,channel])
with tf.Session() as sess: #開始一個會話
init_op = tf.initialize_all_variables()
sess.run(init_op)
#啟動多執行緒
coord=tf.train.Coordinator()
threads= tf.train.start_queue_runners(coord=coord)
for i in range(15):
#image_down = np.asarray(image_down.eval(), dtype='uint8')
plt.imshow(image.eval())
plt.show()
single,l = sess.run([image,label])#在會話中取出image和label
img=Image.fromarray(single, 'RGB')#這裡Image是之前提到的
img.save(swd+str(i)+'_''Label_'+str(l)+'.jpg')#存下圖片
#print(single,l)
coord.request_stop()
coord.join(threads)
注意:
1.我們在使用reshape去將二進位制資料重新變成圖片的時候,用的就是之前打包進去的width和height,否則程式會出錯;
image = tf.reshape(image, [height,width,channel])
2.在圖片儲存時的命名方式為:mun_Label_calss id
3.程式碼也可以實時show出當前的圖片:
完整程式碼也可以點選這裡下載。