TFRecord: write & read
阿新 • • 發佈:2018-11-26
概述
在訓練卷積神經網路時,將圖片提前處理好並快取在磁碟上,通過中間檔案隨機呼叫訪問可以明顯提高訓練速度,並且可以減少重複處理圖片的工作。
write
寫入時通過 tf.train.Example Protocol Buffer 格式來序列化資料。
下面程式碼源於本人寫的一個函式
def create_tfrecord(result, sess):
    """Write the 'validation' split of *result* into a TFRecord file.

    Each serialized tf.train.Example carries two bytes features:
      - 'label': raw bytes of the label array returned by get_labels_array
      - 'image': raw bytes of the decoded image array

    Args:
        result: dict of image file lists; only result['validation'] is
            written here (other splits are presumably handled elsewhere --
            TODO confirm against the rest of the file).
        sess: an active tf.Session used to run the JPEG-decoding ops.
    """
    path = FLAGS.tfrecord_dir
    if not tf.gfile.Exists(path):
        tf.gfile.MakeDirs(path)
    tf_filename = os.path.join(path, 'validation.tfrecord')
    jpeg_data_tensor, decoded_image_tensor = add_jpeg_decoding()
    # Context manager guarantees the writer is flushed and closed even if an
    # exception is raised mid-loop; the original only closed it on success.
    with tf.python_io.TFRecordWriter(tf_filename) as writer:
        for index_val, file in enumerate(result['validation']):
            tf.logging.info("write the %d in validation" % index_val)
            # Label file shares the image's base name with a .txt extension.
            name, _ = os.path.splitext(file)
            label = get_labels_array(name + '.txt')
            input_image_array = create_input_tensor(
                file, sess, jpeg_data_tensor, decoded_image_tensor)
            input_image_string = input_image_array.tostring()
            label_string = label.tostring()
            example = tf.train.Example(features=tf.train.Features(
                feature={
                    'label': tf.train.Feature(
                        bytes_list=tf.train.BytesList(value=[label_string])),
                    'image': tf.train.Feature(
                        bytes_list=tf.train.BytesList(value=[input_image_string])),
                }))
            writer.write(example.SerializeToString())
read
讀比較麻煩,還要建立執行緒什麼的
注意函式使用了多執行緒
def read_tfrecord(file_name, batch, image_shape=(1080, 1440, 3), label_len=26):
    """Build a TF1 queue-based input pipeline over a TFRecord file.

    The returned tensors are fed by queue runners: the caller must start
    them (tf.train.start_queue_runners) before evaluating the tensors.

    Args:
        file_name: path to the .tfrecord file to read.
        batch: batch size; when > 1 the examples are shuffle-batched,
            otherwise a single un-batched (image, label) pair is returned.
        image_shape: (height, width, channels) the raw image bytes decode
            to. Defaults to the 1080x1440x3 layout this file was written with.
        label_len: number of int64 values in each serialized label vector.

    Returns:
        A tuple (images, labels) of tensors.
    """
    filename_queue = tf.train.string_input_producer([file_name])
    reader = tf.TFRecordReader()
    _, serialized_example = reader.read(filename_queue)
    feature = tf.parse_single_example(
        serialized_example,
        features={
            'label': tf.FixedLenFeature([], tf.string),
            'image': tf.FixedLenFeature([], tf.string),
        })
    labels = tf.decode_raw(feature['label'], tf.int64)
    labels = tf.reshape(labels, [label_len])
    images = tf.decode_raw(feature['image'], tf.float32)
    images = tf.reshape(images, list(image_shape))
    # NOTE(review): convert_image_dtype assumes float input in [0, 1] and
    # scales to the target dtype's range; tf.int8 (rather than tf.uint8) is
    # unusual for image data -- confirm this matches how the file was written.
    images = tf.image.convert_image_dtype(images, tf.int8)
    if batch > 1:
        images, labels = tf.train.shuffle_batch(
            [images, labels],
            batch_size=batch,
            capacity=500,
            num_threads=2,
            min_after_dequeue=10)
    return images, labels
def main(_):
    """Smoke-test the reader: pull two batches from a TFRecord and print shapes."""
    tf.logging.set_verbosity(tf.logging.INFO)
    file_name = r'G:\GraduateStudy\Smoke Recognition\Newdata\Tfrecord\validation.tfrecord'
    image, label = read_tfrecord(file_name, 8)
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        # Queue runners feed the string_input_producer / shuffle_batch queues.
        # Pass the session explicitly rather than relying on the default one.
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)
        try:
            for i in range(2):
                img, labe = sess.run([image, label])
                print(img.shape, labe.shape)
        except tf.errors.OutOfRangeError:
            print('Done reading')
        finally:
            # Always stop and join the runner threads so the process can exit.
            coord.request_stop()
            coord.join(threads)