1. 程式人生 > TensorFlow 讀取 TFRecord 格式的資料

tensorflow讀取tfrecord格式的資料

1. 生成 train.tfrecords 資料:gen_data.py(檔案末尾另外定義了讀取用的函式 read_and_decode)

import os
import tensorflow as tf
from PIL import Image
import matplotlib.pyplot as plt
import numpy as np

# Root of the CycleGAN apple2orange dataset: testA = apple images, testB = orange images.
path = r"D:\Deep_Learning_data\cyclegan\apple2orange"
# An ordered list, not a set: set iteration order depends on string hash
# randomization, so the label given to each class could differ between runs.
classes = ['testA', 'testB']  # label 0 = testA (apple), label 1 = testB (orange)

# Serialize every image as one tf.train.Example record in train.tfrecords.
# The `with` block guarantees the file is closed even if an image fails to load.
with tf.python_io.TFRecordWriter('train.tfrecords') as writer:  # output file
    for index, name in enumerate(classes):
        class_path = os.path.join(path, name)
        # sorted() makes the record order reproducible across platforms/runs.
        for img_name in sorted(os.listdir(class_path)):
            img_path = os.path.join(class_path, img_name)
            with Image.open(img_path) as src:
                # convert('RGB') forces 3 channels so the raw buffer is exactly
                # 256*256*3 bytes, matching the reader's reshape to [256, 256, 3].
                img = src.convert('RGB').resize((256, 256))
            img_raw = img.tobytes()  # raw uint8 pixel bytes
            example = tf.train.Example(features=tf.train.Features(feature={
                'label': tf.train.Feature(int64_list=tf.train.Int64List(value=[index])),
                'img_raw': tf.train.Feature(bytes_list=tf.train.BytesList(value=[img_raw])),
            }))
            # Written inside the inner loop: one record per image, not one per class.
            writer.write(example.SerializeToString())

def read_and_decode(filename):
    """Build graph ops that read one (image, label) pair from a TFRecord file.

    Uses the TF1 queue-based input pipeline. The returned tensors produce a new
    record each time they are evaluated, but only after the queue runners have
    been started (e.g. with ``tf.train.start_queue_runners``).

    Args:
        filename: Path of the TFRecord file to read.

    Returns:
        A tuple ``(img, label)``. ``img`` is a float32 tensor of shape
        [256, 256, 3] holding the raw pixel values (not normalized). ``label``
        is an int32 scalar tensor.
    """
    filename_queue = tf.train.string_input_producer([filename])
    reader = tf.TFRecordReader()
    _, serialized_example = reader.read(filename_queue)
    # Extract the two features stored per record: the label and the raw image bytes.
    features = tf.parse_single_example(serialized_example, features={
        'label': tf.FixedLenFeature([], tf.int64),
        'img_raw': tf.FixedLenFeature([], tf.string),
    })
    img = tf.decode_raw(features['img_raw'], tf.uint8)
    # Assumes each record holds exactly 256*256*3 uint8 bytes (a 256x256 RGB image).
    img = tf.reshape(img, [256, 256, 3])
    img = tf.cast(img, tf.float32)
    # Look up the 'label' feature first, then pass the target dtype to tf.cast.
    label = tf.cast(features['label'], tf.int32)
    return img, label

2. 讀取 TFRecord 格式的資料:read_data.py

import os
import tensorflow as tf
from PIL import Image
import matplotlib.pyplot as plt
import numpy as np

# Input pipeline: filename queue -> TFRecord reader -> parsed (image, label) features.
file_queue = tf.train.string_input_producer(['train.tfrecords'])
reader = tf.TFRecordReader()
_, serialized_example = reader.read(file_queue)
features = tf.parse_single_example(serialized_example,
        features={
            'label': tf.FixedLenFeature([], tf.int64),
            'img_raw': tf.FixedLenFeature([], tf.string),
        })
image = tf.decode_raw(features['img_raw'], tf.uint8)
# Assumes each record holds a 256x256 RGB image (256*256*3 uint8 bytes).
image = tf.reshape(image, [256, 256, 3])
label = tf.cast(features['label'], tf.int32)

with tf.Session() as sess:
    # global_variables_initializer replaces the deprecated initialize_all_variables.
    sess.run(tf.global_variables_initializer())

    # string_input_producer fills its queue from background queue-runner threads.
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)
    try:
        # Decode the first 10 records and save each as a JPEG in the working directory.
        for i in range(10):
            print()
            example, l = sess.run([image, label])
            img = Image.fromarray(example, 'RGB')
            save_path = os.path.join(os.getcwd(), '{}_Label_{}.jpg'.format(i, l))
            print(save_path)
            img.save(save_path)
            print(example.shape, l.shape)
    finally:
        # Stop and join the threads once, after the loop. join() only returns
        # after request_stop(); calling it inside the loop hung on the first pass.
        coord.request_stop()
        coord.join(threads)