1. 程式人生 > TensorFlow TFRecord 讀寫範例

TensorFlow TFRecord 讀寫範例

import tensorflow as tf
def _make_example(fea, label):
    """Build a tf.train.Example holding one feature vector and its label.

    Args:
        fea: sequence of numbers, stored as a float list under 'fea'.
        label: sequence of numbers, stored as a float list under 'label'.

    Returns:
        A tf.train.Example with the keys 'fea', 'fea_shape' and 'label'.
    """
    return tf.train.Example(features=tf.train.Features(feature={
        'fea': tf.train.Feature(float_list=tf.train.FloatList(value=fea)),
        # Store the real length of `fea` so that, on read,
        # tf.reshape(fea, fea_shape) is valid (the element counts match).
        'fea_shape': tf.train.Feature(
            int64_list=tf.train.Int64List(value=[len(fea)])),
        'label': tf.train.Feature(float_list=tf.train.FloatList(value=label)),
    }))


fea = [1, 2, 3]
label_vector = [4, 5, 6]
example = _make_example(fea, label_vector)

fea2 = [10, 20, 30]
label_vector2 = [40, 50, 60]
example2 = _make_example(fea2, label_vector2)

# Write both records to "tmp". The context manager closes the writer on exit;
# the writer buffers its output, so without close() the dataset that reads
# "tmp" later in this script could see an empty or truncated file.
with tf.python_io.TFRecordWriter("tmp") as tf_file_writer:
    tf_file_writer.write(example.SerializeToString())
    tf_file_writer.write(example2.SerializeToString())

def _parse_function(example_proto):
    """Parse one serialized tf.train.Example into dense tensors.

    Args:
        example_proto: a serialized Example (one record of the TFRecord file).

    Returns:
        dict with keys 'fea' (float32), 'fea_shape' (int64) and 'label'
        (float32), each converted from a SparseTensor to a dense tensor.
    """
    dics = {
        'fea': tf.VarLenFeature(dtype=tf.float32),
        'fea_shape': tf.VarLenFeature(dtype=tf.int64),
        'label': tf.VarLenFeature(dtype=tf.float32)}

    parsed_example = tf.parse_single_example(example_proto, dics)
    # VarLenFeature entries come back as SparseTensors. Densify every one of
    # them, including 'fea_shape', so callers receive uniform dense values and
    # 'fea_shape' can be passed directly to tf.reshape.
    for key in dics:
        parsed_example[key] = tf.sparse_tensor_to_dense(parsed_example[key])
    # To restore the stored shape (valid only if fea_shape matches the number
    # of stored 'fea' values):
    # parsed_example['fea'] = tf.reshape(parsed_example['fea'], parsed_example['fea_shape'])
    return parsed_example

# Input pipeline: read the records from "tmp", parse each one, shuffle,
# and prefetch one element ahead.
dataset = tf.data.TFRecordDataset(["tmp"])
new_dataset = dataset.map(_parse_function)
shuffle_dataset = new_dataset.shuffle(buffer_size=1024)
prefetch_dataset = shuffle_dataset.prefetch(1)
iterator = prefetch_dataset.make_one_shot_iterator()
next_element = iterator.get_next()

# No .batch() is applied, so despite the "batch_" names each sess.run below
# yields a single parsed example.
batch_fea = next_element['fea']
fea_shape = next_element['fea_shape']
batch_label = next_element['label']

# The writer above stores exactly two records; a third fetch from this
# one-shot iterator would raise OutOfRangeError. The context manager closes
# the session and releases its resources when done.
with tf.Session() as sess:
    for _ in range(2):
        print(sess.run([batch_fea, fea_shape, batch_label]))