程式人生 > TFrecord: write & read

TFrecord:write&read

參考了這位仁兄的部落格

概述

在訓練卷積神經網路時,將圖片提前處理好並快取在磁碟上,通過中間檔案隨機呼叫訪問可以明顯提高訓練速度,並且可以減少重複處理圖片的工作。

write

通過 tf.train.Example Protocol Buffer 將每個樣本序列化後寫入 TFRecord 檔案。
下面程式碼源於本人寫的一個函式

def create_tfrecord(result, sess):
    """Serialize the validation images and labels into one TFRecord file.

    For every file listed under ``result['validation']`` the image is decoded
    through the JPEG-decoding sub-graph, the matching ``<name>.txt`` label is
    loaded, and both are written as raw-byte features (keys ``'image'`` and
    ``'label'``) of one ``tf.train.Example`` per sample.

    Args:
        result: dictionary of image lists; only the 'validation' key is used.
        sess: an active tf.Session used to run the decoding sub-graph.
    """
    path = FLAGS.tfrecord_dir
    if not tf.gfile.Exists(path):
        tf.gfile.MakeDirs(path)
    tf_filename = os.path.join(path, 'validation.tfrecord')
    # Build the JPEG decode/resize sub-graph once; reuse it for every image.
    jpeg_data_tensor, decoded_image_tensor = add_jpeg_decoding()
    writer = tf.python_io.TFRecordWriter(tf_filename)
    try:
        for index_val, file in enumerate(result['validation']):
            tf.logging.info("write the %d in validation" % index_val)
            # Label file shares the image's base name with a .txt extension.
            name, _ = os.path.splitext(file)
            label = get_labels_array(name + '.txt')
            input_image_array = create_input_tensor(
                file, sess, jpeg_data_tensor, decoded_image_tensor)
            # Store both tensors as raw bytes; the reader must know the
            # dtype/shape to reverse this (see read_tfrecord).
            input_image_string = input_image_array.tostring()
            label_string = label.tostring()
            example = tf.train.Example(features=tf.train.Features(feature={
                'label': tf.train.Feature(
                    bytes_list=tf.train.BytesList(value=[label_string])),
                'image': tf.train.Feature(
                    bytes_list=tf.train.BytesList(value=[input_image_string])),
            }))
            writer.write(example.SerializeToString())
    finally:
        # Always close the writer so a partially written file is flushed
        # even if an image fails to decode mid-loop.
        writer.close()

read

讀比較麻煩,還要建立執行緒什麼的
注意函式使用了多執行緒

def read_tfrecord(file_name, batch):
    """Build the input pipeline that reads examples back from a TFRecord file.

    Each record is parsed into a [26] int64 label vector and a
    [1080, 1440, 3] float32 image (the inverse of create_tfrecord's raw-byte
    encoding), then converted with tf.image.convert_image_dtype to tf.int8.
    When ``batch`` is greater than 1 the tensors are grouped into shuffled
    batches; otherwise the single-example tensors are returned as-is.

    Args:
        file_name: path of the TFRecord file to read.
        batch: desired batch size; batching only happens when batch > 1.

    Returns:
        A (images, labels) pair of tensors.
    """
    queue = tf.train.string_input_producer([file_name],)
    record_reader = tf.TFRecordReader()
    _, serialized = record_reader.read(queue)
    parsed = tf.parse_single_example(
        serialized,
        features={
            'label': tf.FixedLenFeature([], tf.string),
            'image': tf.FixedLenFeature([], tf.string),
        })
    # Undo the tostring() encoding: restore dtype, then the known shapes.
    labels = tf.reshape(tf.decode_raw(parsed['label'], tf.int64), [26])
    images = tf.reshape(tf.decode_raw(parsed['image'], tf.float32),
                        [1080, 1440, 3])
    images = tf.image.convert_image_dtype(images, tf.int8)
    if batch > 1:
        images, labels = tf.train.shuffle_batch(
            [images, labels],
            batch_size=batch,
            capacity=500,
            num_threads=2,
            min_after_dequeue=10)
    return images, labels

def main(_):
    """Smoke-test the TFRecord reader: pull two batches and print shapes."""
    tf.logging.set_verbosity(tf.logging.INFO)
    file_name = r'G:\GraduateStudy\Smoke Recognition\Newdata\Tfrecord\validation.tfrecord'
    image, label = read_tfrecord(file_name, 8)
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        # The shuffle_batch/string_input_producer queues are only filled
        # once queue-runner threads are started under a coordinator.
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(coord=coord)
        try:
            for i in range(2):
                img, labe = sess.run([image, label])
                print(img.shape, labe.shape)
        except tf.errors.OutOfRangeError:
            print('Done reading')
        finally:
            coord.request_stop()

        coord.join(threads)