TensorFlow 讀取 TFRecord 格式的資料
阿新 • 發佈:2018-11-14
1. 生成 train.tfrecords 資料的腳本:gen_data.py
import os

import tensorflow as tf
from PIL import Image
import matplotlib.pyplot as plt
import numpy as np

# Root directory of the CycleGAN apple2orange dataset.
path = r"D:\Deep_Learning_data\cyclegan\apple2orange"
# Sub-directories to encode: testA = apples, testB = oranges.
# A tuple (not a set) so the enumeration order -- and therefore the integer
# label given to each class -- is identical on every run; str hashing is
# randomized per process, so iterating a set of strings can swap the labels.
classes = ('testA', 'testB')

# Output file. The context manager closes (and flushes) the writer even if an
# image fails to load part-way through the loop.
with tf.python_io.TFRecordWriter('train.tfrecords') as writer:
    for index, name in enumerate(classes):
        class_path = os.path.join(path, name)
        for img_name in os.listdir(class_path):
            img_path = os.path.join(class_path, img_name)
            with Image.open(img_path) as src:
                # Force 3-channel RGB so every record holds exactly
                # 256 * 256 * 3 bytes, matching the reshape in
                # read_and_decode; grayscale/RGBA inputs would otherwise
                # produce records that cannot be decoded.
                img = src.convert('RGB').resize((256, 256))
            img_raw = img.tobytes()  # raw interleaved uint8 pixel bytes
            example = tf.train.Example(features=tf.train.Features(feature={
                'label': tf.train.Feature(int64_list=tf.train.Int64List(value=[index])),
                'img_raw': tf.train.Feature(bytes_list=tf.train.BytesList(value=[img_raw])),
            }))
            writer.write(example.SerializeToString())


def read_and_decode(filename):
    """Build graph ops that read one (image, label) pair from a TFRecord file.

    Args:
        filename: path to a TFRecord file in the format written above
            ('label' int64 feature, 'img_raw' bytes feature).

    Returns:
        A tuple ``(img, label)``: ``img`` is a float32 tensor of shape
        [256, 256, 3] holding unscaled 0-255 pixel values, and ``label``
        is an int32 scalar tensor.

    The returned tensors are fed by a filename queue, so queue runners must
    be started (``tf.train.start_queue_runners``) before they are evaluated.
    """
    filename_queue = tf.train.string_input_producer([filename])
    reader = tf.TFRecordReader()
    _, serialized_example = reader.read(filename_queue)
    features = tf.parse_single_example(serialized_example, features={
        'label': tf.FixedLenFeature([], tf.int64),
        'img_raw': tf.FixedLenFeature([], tf.string),
    })
    # Decode the raw bytes back into an image and pull out the label.
    img = tf.decode_raw(features['img_raw'], tf.uint8)
    img = tf.reshape(img, [256, 256, 3])
    img = tf.cast(img, tf.float32)
    label = tf.cast(features['label'], tf.int32)
    return img, label
2. 讀取 TFRecord 格式資料的腳本:read_data.py
import os

import tensorflow as tf
from PIL import Image
import matplotlib.pyplot as plt
import numpy as np

# Input pipeline: filename queue -> TFRecord reader -> parsed example.
file_queue = tf.train.string_input_producer(['train.tfrecords'])
reader = tf.TFRecordReader()
_, serialized_example = reader.read(file_queue)
features = tf.parse_single_example(serialized_example, features={
    'label': tf.FixedLenFeature([], tf.int64),
    'img_raw': tf.FixedLenFeature([], tf.string),
})
# Raw bytes -> uint8 image; the shape must match how gen_data.py encoded it.
image = tf.decode_raw(features['img_raw'], tf.uint8)
image = tf.reshape(image, [256, 256, 3])
label = tf.cast(features['label'], tf.int32)

with tf.Session() as sess:
    # Replaces the deprecated tf.initialize_all_variables().
    sess.run(tf.global_variables_initializer())
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)
    try:
        # Dump the first 10 decoded records as JPEGs for visual inspection.
        for i in range(10):
            print()
            example, lbl = sess.run([image, label])
            img = Image.fromarray(example, 'RGB')
            save_path = os.path.join(os.getcwd(), str(i) + '_Label_' + str(lbl) + ".jpg")
            print(save_path)
            img.save(save_path)
            print(example.shape, lbl.shape)
    finally:
        # string_input_producer cycles through its input indefinitely, so the
        # queue-runner threads only exit once a stop is requested; joining
        # without request_stop() blocks forever.
        coord.request_stop()
        coord.join(threads)