1. 程式人生 > >Tensorflow學習筆記-通過slim讀取TFRecord檔案

Tensorflow學習筆記-通過slim讀取TFRecord檔案

  TFRecord檔案格式的介紹:http://blog.csdn.net/lovelyaiq/article/details/78711944
  由於slim是tensorflow的高階API,使用起來比較方便,例如在卷積或全連線層的書寫時,可以大大減少程式碼量。使用slim讀取TFRecord檔案與tensorflow直接讀取還是有很大的卻別。
  本文就以slim中的例子的flowers來說明。tfrecord中的格式定義為:

image_data = image_data = tf.gfile.FastGFile('img_path', 'rb').read()
def image_to_tfexample
(image_data, image_format, height, width, class_id):
return tf.train.Example(features=tf.train.Features(feature={ 'image/encoded': bytes_feature(image_data), 'image/format': bytes_feature(image_format), 'image/class/label': int64_feature(class_id), 'image/height': int64_feature(height), 'image/width'
: int64_feature(width), }))

原始影象經過處理後,生成5個檔案。flowers_train_00000-of-00005.tfrecord到flowers_train_00004-of-00005.tfrecord。
訓練時,就要通過slim從這5個檔案中讀取資料,然後組合成batch。程式碼如下:

  # 第一步
  # 將example反序列化成儲存之前的格式。由tf完成
  keys_to_features = {
      'image/encoded': tf.FixedLenFeature((), tf.string, default_value=''),
      'image/format'
: tf.FixedLenFeature((), tf.string, default_value='png'), 'image/class/label': tf.FixedLenFeature( [], tf.int64, default_value=tf.zeros([], dtype=tf.int64)), } # 第一步 # 將反序列化的資料組裝成更高階的格式。由slim完成 items_to_handlers = { 'image': slim.tfexample_decoder.Image('image/encoded','image/format'), 'label': slim.tfexample_decoder.Tensor('image/class/label'), } # 解碼器,進行解碼 decoder = slim.tfexample_decoder.TFExampleDecoder( keys_to_features, items_to_handlers) # dataset物件定義了資料集的檔案位置,解碼方式等元資訊 dataset = slim.dataset.Dataset( data_sources=file_pattern, reader=tf.TFRecordReader, decoder=decoder, num_samples=SPLITS_TO_SIZES[split_name],#訓練資料的總數 items_to_descriptions=_ITEMS_TO_DESCRIPTIONS, num_classes=_NUM_CLASSES, labels_to_names=labels_to_names #字典形式,格式為:id:class_call, ) # provider物件根據dataset資訊讀取資料 provider = slim.dataset_data_provider.DatasetDataProvider( dataset, num_readers=FLAGS.num_readers, common_queue_capacity=20 * FLAGS.batch_size, common_queue_min=10 * FLAGS.batch_size) # 獲取資料,獲取到的資料是單個數據,還需要對資料進行預處理,組合資料 [image, label] = provider.get(['image', 'label']) # 影象預處理 image = image_preprocessing_fn(image, train_image_size, train_image_size) images, labels = tf.train.batch( [image, label], batch_size=FLAGS.batch_size, num_threads=FLAGS.num_preprocessing_threads, capacity=5 * FLAGS.batch_size) labels = slim.one_hot_encoding( labels, dataset.num_classes - FLAGS.labels_offset) batch_queue = slim.prefetch_queue.prefetch_queue( [images, labels], capacity=2 * deploy_config.num_clones) # 組好後的資料 images, labels = batch_queue.dequeue()

  至此,就可以使用images作為神經網路的輸入,使用labels計算損失函式等操作。