TFRecords的建立和讀取——自定義圖片資料的製作
阿新 • • 發佈:2019-01-29
TFRecords檔案的建立和讀取
首先是檔案結構,這是我的檔案結構,大家可以自定義檔案位置,但是結構應該如下:首先是主資料夾tensorflow_application/jpg,該資料夾下有兩個次資料夾001和002,001資料夾的下面是一類圖片;002資料夾的下面是另一類圖片。本文以資料夾的名稱作為每一類圖片的名稱,這在影象識別的影象預處理中是比較常用的,比較方便。
由於這不是專門的程式碼編輯器,所以其中的有些縮排可能不規範,所以儘量不要直接複製貼上。可以自己按照程式碼在編輯器中編寫。
---tensorflow_application/jpg
---001
---0012.jpg
---...
---002
---0000.jpg
---...
然後是TFRecords檔案的建立。
# 首先是模組的匯入
"""
os模組是處理資料夾用的
PIL模組是用來處理圖片的
"""
import tensorflow as tf
import os
from PIL import Image
path = "tensorflow_application/jpg" # 這是上述檔案結構的主資料夾路徑
filename = os.listdir(path) # 作用是遍歷path資料夾下的檔案,返回的是001和002資料夾構成的一個列表
writer = tf.python_io.TFRecordWriter("tensorflow_application/train.tfrecords" ) # 將TFRecordWriter例項化,用於檔案的寫操作。其中的路徑是tfrecords檔案的存放路徑,這個路徑並不需要實現建立,程式碼會自動生成
for name in filename:
class_path = path + os.sep + name # 得到每一類的路徑,即001資料夾和002資料夾的路徑,其中的os.sep返回的是一個符號,即'//',這是路徑中的一個符號而已,起到連線作用,構成此資料夾的完整路徑
for img_name in os.listdir(class_path):
img_path = class_path + os.sep + img_name # 同上,得到此資料夾下的每一張圖片的完整路徑,用於後續的圖片提取並處理
img = Image.open(img_path) # 取出圖片
img = img.resize((500, 500)) # 改變圖片大小,大小視具體的網路要求而定,不同的網路對輸入圖片的大小並不完全相同。這裡我暫且將圖片變為500*500的大小
img_raw = img.tobytes() # 這裡將圖片矩陣變為字串形式進行儲存,因為TFRecords能夠儲存的只能是二進位制資料,因此需要將陣列轉換為二進位制形式
# 下面是關鍵的步驟,將資料填入到Example協議記憶體塊中,最終生成TFRecords檔案。TFRecords檔案就是通過一個包含著二進位制檔案的資料檔案,將特徵和標籤進行儲存便於TensorFlow讀取
"""
一個tf.train.Example,即Example協議記憶體塊,包含著若干資料特徵(Features),而Features
中又包含著Feature字典。任何一個Feature中又包含著FloatList, Int64List或BytesList,本例
中使用到了其中兩種資料格式,即Int64List和BytesList,需要注意的是value後跟的值需要為
列表形式,所以加上了方括號
"""
example = tf.train.Example(
features = tf.train.Features(
feature={
"label": tf.train.Feature(int64_list=tf.train.Int64List(value=[name])),
"image": tf.train.Feature(bytes_list=tf.train.BytesList(value=[img_raw]))))
}
serialized = example.SerializeToString() # 先將樣本進行序列化操作
writer.write(serialized) # 對序列化操作後的變數進行寫操作,即生成最終的tfrecords檔案
接下來需要做的便是讀取生成的tfrecords檔案,在神經網路中,需要將tfrecords檔案中的image和label讀取出來,然後將其傳遞給圖。
# 使用的模組還是tensorflow
import tensorflow as tf
filename = "tensorflow_application/train.tfrecords" # 這是上面生成的tfrecords檔案
filename_queue = tf.train.string_input_producer([filenname]) # 建立一個佇列,其中的引數為tfrecords檔案的路徑
reader = tf.TFRecordReader() # 例項化讀操作,建立讀取器
_, serialized_example = reader.read(filename_queue) # 返回檔名和檔案
"""
通過parse_single_example解析器解析,將Example協議記憶體塊解析為張量(Tensor),然後使用
解碼器tf.decode_raw解碼
"""
features = tf.parse_single_example(serialized_example,
features={
"label": tf.FixedLenFeature([], tf.int64),
"image": tf.FixedLenFeature([], tf.string)
})
img = tf.decode_raw(features["image"], tf.uint8) # 使用tf.decode_raw解碼
img = tf.reshape(img, [500, 500, 3]) # 重構圖片的大小為500*500*3
img = tf.cast(img, tf.float32) * (1. / 128) - 0.5
label = tf.cast(features["label"], tf.int32)
"""
上面將img和label從tfrecords檔案中讀取了出來,但是如果需要將資料取出供
圖使用,還需要使用tf.train.shuffle_batch
shuffle_batch的主要引數為:
1. tensor: 入隊佇列,即上面得到的img和label,[img, label]
2. batch_size: batch的大小
3. capacity: 佇列的最大容量
4. num_threads: 執行緒數
5. min_after_dequeue: 限制出隊時佇列中元素的最小個數
"""
img_batch, label_batch = tf.train.shuffle_batch([img, label], batch_size=1,
capacity=24, min_after_dequeue=1) # 將得到的img_batch, label_batch傳遞給需要進行遞迴的資料即可