TensorFlow TFRecord 檔案存取
阿新 • 發佈：2018-11-10
import tensorflow as tf
import numpy as np
import skimage
from skimage import data, io, color

# Write one tf.train.Example containing several feature types to a TFRecord
# file, then read it back with the TensorFlow 1.x queue-based input pipeline
# and print the decoded values.
# NOTE: uses TF 1.x graph-mode APIs (tf.python_io, tf.train queue runners,
# tf.Session); under TF 2.x these are available via tf.compat.v1.

path = "1.tfrecords"

# ---------------------------------------------------------------- Write
with tf.python_io.TFRecordWriter(path) as writer:
    a = 1024                # stored in an int64_list
    b = 10.24               # stored in a float_list (float32 precision)
    c = [0.1, 0.2, 0.3]     # serialized as raw float32 bytes
    d = [[1, 2], [3, 4]]    # serialized as raw int8 bytes (2x2)
    e = "Python"            # serialized as UTF-8 bytes

    img = io.imread('/data/test/img/a.jpg')
    img_shape = img.shape

    c = np.array(c).astype(np.float32).tobytes()
    d = np.array(d).astype(np.int8).tobytes()
    e = bytes(e, encoding='utf-8')
    # The reader must decode with the same dtype (int16) used here.
    img = img.astype(np.int16).tobytes()
    # tobytes() discards the array shape, so store it in the record too; the
    # reader can then restore the image without relying on `img_shape` from
    # this writing step.
    img_shape_bytes = np.array(img_shape, dtype=np.int64).tobytes()

    example = tf.train.Example(features=tf.train.Features(feature={
        'a': tf.train.Feature(int64_list=tf.train.Int64List(value=[a])),
        'b': tf.train.Feature(float_list=tf.train.FloatList(value=[b])),
        'c': tf.train.Feature(bytes_list=tf.train.BytesList(value=[c])),
        'd': tf.train.Feature(bytes_list=tf.train.BytesList(value=[d])),
        'e': tf.train.Feature(bytes_list=tf.train.BytesList(value=[e])),
        'img': tf.train.Feature(bytes_list=tf.train.BytesList(value=[img])),
        'img_shape': tf.train.Feature(
            bytes_list=tf.train.BytesList(value=[img_shape_bytes])),
    }))
    writer.write(example.SerializeToString())

# ---------------------------------------------------------------- Read
filename_queue = tf.train.string_input_producer([path])
_, serialized_example = tf.TFRecordReader().read(filename_queue)
features = tf.parse_single_example(serialized_example, features={
    'a': tf.FixedLenFeature([], tf.int64),
    'b': tf.FixedLenFeature([], tf.float32),
    'c': tf.FixedLenFeature([], tf.string),
    'd': tf.FixedLenFeature([], tf.string),
    'e': tf.FixedLenFeature([], tf.string),
    'img': tf.FixedLenFeature([], tf.string),
    'img_shape': tf.FixedLenFeature([], tf.string),
})

a = features['a']  # parsed values are tensors; they are evaluated in sess.run
b = features['b']
c = tf.decode_raw(features['c'], tf.float32)
d = tf.reshape(tf.decode_raw(features['d'], tf.int8), [2, 2])
e = features['e']
img = tf.decode_raw(features['img'], tf.int16)
# Restore the original image shape from the int64 vector stored in the record.
img = tf.reshape(img, tf.decode_raw(features['img_shape'], tf.int64))

with tf.Session() as sess:
    # tf.initialize_all_variables() is deprecated. Local variables are included
    # because string_input_producer creates them when num_epochs is set.
    sess.run(tf.group(tf.global_variables_initializer(),
                      tf.local_variables_initializer()))
    # A Coordinator lets the queue-runner threads be stopped and joined before
    # the session closes; otherwise they error out once the session is gone.
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)
    try:
        # Fetch everything in ONE run: every separate sess.run call dequeues
        # and parses another record, so separate calls could mix records.
        a_val, b_val, c_val, d_val, e_val, img_val = sess.run(
            [a, b, c, d, e, img])
        print([a_val, b_val, c_val, d_val, e_val])
        print(type(e_val), e_val.decode('utf-8'))
        print(img_val)
    finally:
        coord.request_stop()
        coord.join(threads)