TensorFlow TFRecord 讀寫範例
阿新 • 發佈:2018-11-23
# TensorFlow 1.x demo: write two tf.train.Example records to a TFRecord file,
# then read them back through a tf.data pipeline and print them.
import tensorflow as tf

# TFRecord file written and then re-read by this demo (relative to the CWD).
TFRECORD_PATH = "tmp"


def _make_example(fea, fea_shape, label):
    """Build a tf.train.Example with 'fea', 'fea_shape' and 'label' features.

    Args:
        fea: sequence of numbers, stored as a FloatList.
        fea_shape: sequence of ints, stored as an Int64List.
        label: sequence of numbers, stored as a FloatList.

    Returns:
        The assembled tf.train.Example protocol buffer.
    """
    return tf.train.Example(features=tf.train.Features(feature={
        'fea': tf.train.Feature(float_list=tf.train.FloatList(value=fea)),
        'fea_shape': tf.train.Feature(int64_list=tf.train.Int64List(value=fea_shape)),
        'label': tf.train.Feature(float_list=tf.train.FloatList(value=label)),
    }))


fea = [1, 2, 3]
label_vector = [4, 5, 6]
# NOTE(review): 'fea_shape' is stored but never applied to 'fea'. [1, 1, 1]
# describes a single element while fea has 3, so tf.reshape(fea, fea_shape)
# would fail -- confirm the intended shape before using it.
example = _make_example(fea, [1, 1, 1], label_vector)

fea2 = [10, 20, 30]
label_vector2 = [40, 50, 60]
example2 = _make_example(fea2, [1, 1, 1], label_vector2)

# The writer buffers records; it must be closed (done here by the `with`
# block) before the file is read back below, otherwise the reader may see an
# empty or truncated file.
with tf.python_io.TFRecordWriter(TFRECORD_PATH) as tf_file_writer:
    tf_file_writer.write(example.SerializeToString())
    tf_file_writer.write(example2.SerializeToString())


def _parse_function(example_proto):
    """Parse one serialized tf.train.Example produced by `_make_example`.

    Args:
        example_proto: string tensor holding one serialized Example, as
            yielded by tf.data.TFRecordDataset.

    Returns:
        dict with keys:
            'fea'       -- dense float32 tensor,
            'label'     -- dense float32 tensor,
            'fea_shape' -- int64 SparseTensor (left sparse, as VarLenFeature
                           parses it).
    """
    dics = {
        'fea': tf.VarLenFeature(dtype=tf.float32),
        'fea_shape': tf.VarLenFeature(dtype=tf.int64),
        'label': tf.VarLenFeature(dtype=tf.float32),
    }
    parsed_example = tf.parse_single_example(example_proto, dics)
    # VarLenFeature parses to SparseTensors; densify the value fields.
    for key in ('fea', 'label'):
        parsed_example[key] = tf.sparse_tensor_to_dense(parsed_example[key])
    return parsed_example


dataset = tf.data.TFRecordDataset([TFRECORD_PATH])
new_dataset = dataset.map(_parse_function)
# buffer_size exceeds the record count, so the two records come back in a
# random order.
shuffle_dataset = new_dataset.shuffle(buffer_size=1024)
prefetch_dataset = shuffle_dataset.prefetch(1)
iterator = prefetch_dataset.make_one_shot_iterator()
next_element = iterator.get_next()

# No .batch() is applied: each get_next() yields a single example.
batch_fea = next_element['fea']
fea_shape = next_element['fea_shape']
batch_label = next_element['label']

# Exactly two records were written, so two evaluations succeed; a third would
# raise tf.errors.OutOfRangeError.
with tf.Session() as sess:
    print(sess.run([batch_fea, fea_shape, batch_label]))
    print(sess.run([batch_fea, fea_shape, batch_label]))