1. 程式人生 > >tensorflow將圖片儲存為tfrecord和tfrecord的讀取

tensorflow將圖片儲存為tfrecord和tfrecord的讀取

tensorflow官方提供了3種方法來讀取資料:

  • 預載入資料(preloaded data):在TensorFlow圖中定義常量或變數來儲存所有的資料,適用於資料量不太大的情況。
  • 填充資料(feeding):通過Python產生資料,然後再把資料填充到後端。
  • 從檔案讀取資料(reading from file):從檔案中直接讀取,然後通過佇列管理器從檔案中讀取資料。

本文主要介紹第三種方法,通過tfrecord檔案來儲存和讀取資料,對於前兩種讀取資料的方式也會進行一個簡單的介紹。

專案下載github地址:https://github.com/steelOneself/tensorflow_learn/tree/master/tf_records_writer_read

一、預載入資料

    a = tf.constant([1,2,3])
    b = tf.constant([4,5,6])
    c = tf.add(a,b)
    with tf.Session() as sess:
        print(sess.run(c))#[5 7 9]

這種方式載入資料比較簡單,它是直接將資料嵌入在資料流圖中,當訓練資料較大時,比較消耗記憶體。

二、填充資料

通過先定義placeholder然後再通過feed_dict來餵養資料,這種方式在TensorFlow中使用的也是比較多的,但是也存在資料量大時比較消耗記憶體的缺點,下面介紹一種更高效的資料讀取方式,通過tfrecord檔案來讀取資料。

    x = tf.placeholder(tf.int16)
    y = tf.placeholder(tf.int16)
    z = tf.add(x,y)
    with tf.Session() as sess:
        print(sess.run(z,feed_dict={x:[1,2,3],y:[4,5,6]}))
        #[5 7 9]

三、從檔案讀取資料

通過slim來實現將圖片儲存為tfrecord檔案和tfrecord檔案的讀取,slim是基於TensorFlow的一個更高級別的封裝模型,通過slim來程式設計可以實現更高效率和更簡潔的程式碼。

在本次實驗中使用的資料集是kaggle的dog vs cat,資料集下載地址:https://www.kaggle.com/c/dogs-vs-cats/data

1、tfrecord檔案的儲存

a、引數設定

  • dataset_dir_path:訓練集圖片存放的上級目錄(train下還有一個train目錄用來存放圖片),在dog vs cat資料集中,dog和cat類的區別是依靠圖片的名稱,如果你的資料集通過資料夾的名稱來劃分圖片類標的,可能需要對程式碼進行部分修改。
  • label_name_to_num:字串類標與數字類標的對應關係,在將圖片儲存為tfrecord檔案的時候,需要將字串轉為整數類標0和1,方便後的訓練。
  • label_num_to_name:數字類標與字串類標的對應關係。
  • val_size:驗證集在訓練集中所佔的比例,訓練集一共有25000張圖片,用20000張來訓練,5000張來進行驗證。
  • batch_size:在讀取tfrecord檔案的時候,每次讀取圖片的數量。
#資料所在的目錄路徑
dataset_dir_path = "D:/dataset/kaggle/cat_or_dog/train"
#類標名稱和數字的對應關係
label_name_to_num = {"cat":0,"dog":1}
label_num_to_name = {value:key for key,value in label_name_to_num.items()}
#設定驗證集佔整個資料集的比例
val_size = 0.2
batch_size = 1

b、獲取訓練集所有的圖片路徑

獲取訓練目錄下所有的dog和cat的圖片路徑,將它們分開儲存,便於後面訓練集和驗證集資料的劃分,保證每類圖片在所佔的比例相同。


  #獲取檔案所在路徑
  dataset_dir = os.path.join(dataset_dir,split_name)
  #遍歷目錄下的所有圖片
  for filename in os.listdir(dataset_dir):
      #獲取檔案的路徑
      file_path = os.path.join(dataset_dir,filename)
      if file_path.endswith("jpg") and os.path.exists(file_path):
          #獲取類別的名稱
          label_name = filename.split(".")[0]
          if label_name == "cat":
              cat_img_paths.append(file_path)
          elif label_name == "dog":
              dog_img_paths.append(file_path)
  return cat_img_paths,dog_img_paths

c、設定需要儲存的圖片資訊

對於訓練集的圖片主要儲存圖片的位元組資料、圖片的格式、圖片的標籤、圖片的高和寬,測試集儲存為tfrecord檔案的時候需要儲存圖片的名稱,因為在提交資料的時候需要用到圖片的名稱資訊。在儲存圖片資訊的時候,需要先將這些資訊轉換為byte資料才能寫入到tfrecord檔案中。

def int64_feature(values):
  if not isinstance(values, (tuple, list)):
    values = [values]
  return tf.train.Feature(int64_list=tf.train.Int64List(value=values))

def bytes_feature(values):
  return tf.train.Feature(bytes_list=tf.train.BytesList(value=[values]))


#將圖片資訊轉換為tfrecords可以儲存的序列化資訊
def image_to_tfexample(split_name,image_data, image_format, height, width, img_info):
    '''
    :param split_name: train或val或test
    :param image_data: 圖片的二進位制資料
    :param image_format: 圖片的格式
    :param height: 圖片的高
    :param width: 圖片的寬
    :param img_info: 圖片的標籤或圖片的名稱,當split_name為test時,img_info為圖片的名稱否則為圖片標籤
    :return:
    '''
    if split_name == "test":
        return tf.train.Example(features=tf.train.Features(feature={
              'image/encoded': bytes_feature(image_data),
              'image/format': bytes_feature(image_format),
              'image/img_name': bytes_feature(img_info),
              'image/height': int64_feature(height),
              'image/width': int64_feature(width),
          }))
    else:
          return tf.train.Example(features=tf.train.Features(feature={
              'image/encoded': bytes_feature(image_data),
              'image/format': bytes_feature(image_format),
              'image/label': int64_feature(img_info),
              'image/height': int64_feature(height),
              'image/width': int64_feature(width),
          }))

d、儲存tfrecord檔案

主要是通過TFRecordWriter來儲存tfrecord檔案,在將圖片資訊儲存為tfrecord檔案的時候,需要先將圖片資訊序列化為字串才能進行寫入。ImageReader類可以將圖片位元組資料解碼為指定格式的圖片,獲取圖片的寬和高資訊。_get_dataset_filename函式是通過資料集的名稱和split_name的名稱來組合獲取tfrecord檔案的名稱,tfrecord名稱如下:

def _convert_tfrecord_dataset(split_name, filenames, label_name_to_id, 
dataset_dir, tfrecord_filename, _NUM_SHARDS):
    '''
    :param split_name:train或val或test
    :param filenames:圖片的路徑列表
    :param label_name_to_id:標籤名與數字標籤的對應關係
    :param dataset_dir:資料存放的目錄
    :param tfrecord_filename:檔案儲存的字首名
    :param _NUM_SHARDS:將整個資料集分為幾個檔案
    :return:
    '''
    assert split_name in ['train', 'val','test']
    #計算平均每一個tfrecords檔案儲存多少張圖片
    num_per_shard = int(math.ceil(len(filenames) / float(_NUM_SHARDS)))
    with tf.Graph().as_default():
        image_reader = ImageReader()
        with tf.Session('') as sess:
            for shard_id in range(_NUM_SHARDS):
                #獲取tfrecord檔案的名稱
                output_filename = _get_dataset_filename(
                       dataset_dir, split_name, shard_id,
 tfrecord_filename = tfrecord_filename, _NUM_SHARDS = _NUM_SHARDS)
                #寫tfrecords檔案
                with tf.python_io.TFRecordWriter(output_filename) as tfrecord_writer:
                    start_ndx = shard_id * num_per_shard
                    end_ndx = min((shard_id+1) * num_per_shard, len(filenames))
                    for i in range(start_ndx, end_ndx):
                        #更新控制檯中已經完成的圖片數量
                        sys.stdout.write('\r>> Converting image %d/%d shard %d' % (
                            i+1, len(filenames), shard_id))
                        sys.stdout.flush()
                        #讀取圖片,將圖片資料讀取為bytes
                        image_data = tf.gfile.FastGFile(filenames[i], 'rb').read()
                        #獲取圖片的高和寬
                        height, width = image_reader.read_image_dims(sess, image_data)
                        #獲取路徑中的圖片名稱
                        img_name = os.path.basename(filenames[i])
                        if split_name == "test":
                            #需要將圖片名稱轉換為二進位制
                            example = image_to_tfexample(
                                split_name,image_data, b'jpg', height, width, img_name.encode())
                            tfrecord_writer.write(example.SerializeToString())
                        else:
                            #獲取圖片的類別
                            class_name = img_name.split(".")[0]
                            label_id = label_name_to_id[class_name]
                            example = image_to_tfexample(
                                split_name,image_data, b'jpg', height, width, label_id)
                            tfrecord_writer.write(example.SerializeToString())
                sys.stdout.write('\n')
                sys.stdout.flush()

e、將資料集分為驗證集和訓練集儲存為tfrecord檔案

先獲取資料集中所有圖片的路徑和圖片的標籤資訊,將不同類別的圖片分為訓練集和驗證集,並保證訓練集和驗證集中不同類別的圖片數量保持相同,在儲存為tfrecord檔案之前,打亂所有圖片的路徑。將訓練集分為了2個tfrecord檔案,驗證集儲存為1個tfrecord檔案。

#生成tfrecord檔案
def generate_tfreocrd():
    #獲取目錄下所有的貓和狗圖片的路徑
    cat_img_paths,dog_img_paths = _get_dateset_imgPaths(dataset_dir_path,"train")
    #打亂路徑列表的順序
    np.random.shuffle(cat_img_paths)
    np.random.shuffle(dog_img_paths)
    #計算不同類別驗證集所佔的圖片數量
    cat_val_num = int(len(cat_img_paths) * val_size)
    dog_val_num = int(len(dog_img_paths) * val_size)
    #將所有的圖片路徑分為訓練集和驗證集
    train_img_paths = cat_img_paths[cat_val_num:]
    val_img_paths = cat_img_paths[:cat_val_num]
    train_img_paths.extend(dog_img_paths[dog_val_num:])
    val_img_paths.extend(dog_img_paths[:dog_val_num])
    #打亂訓練集和驗證集的順序
    np.random.shuffle(train_img_paths)
    np.random.shuffle(val_img_paths)
    #將訓練集儲存為tfrecord檔案
    _convert_tfrecord_dataset("train",train_img_paths,label_name_to_num,dataset_dir_path,"catVSdog",2)
    #將驗證集儲存為tfrecord檔案
    _convert_tfrecord_dataset("val",val_img_paths,label_name_to_num,dataset_dir_path,"catVSdog",1)

通過控制檯你能夠看到tfrecord檔案的儲存進度

2、從tfrecord檔案中讀取資料

a、讀取tfrecord檔案,將資料轉換為dataset

通過TFRecordReader來讀取tfrecord檔案,在讀取tfrecord檔案時需要通過tf.FixedLenFeature來反序列化儲存的圖片資訊,這裡我們只讀取圖片資料和圖片的標籤,再通過slim模組將圖片資料和標籤資訊儲存為一個dataset。

    #建立一個tfrecord讀檔案物件
    reader = tf.TFRecordReader
        keys_to_feature = {
            "image/encoded":tf.FixedLenFeature((),tf.string,default_value=""),
            "image/format":tf.FixedLenFeature((),tf.string,default_value="jpg"),
         "image/label":tf.FixedLenFeature([],tf.int64,default_value=tf.zeros([],tf.int64))
        }
        items_to_handles = {
            "image":slim.tfexample_decoder.Image(),
            "label":slim.tfexample_decoder.Tensor("image/label")
        }
        items_to_descriptions = {
            "image":"a 3-channel RGB image",
            "img_name":"a image label"
        }
        #建立一個tfrecoder解析物件
        decoder = slim.tfexample_decoder.TFExampleDecoder(keys_to_feature,items_to_handles)
        #讀取所有的tfrecord檔案,建立資料集
        dataset = slim.dataset.Dataset(
            data_sources = tfrecord_paths,
            decoder = decoder,
            reader = reader,
            num_readers = 4,
            num_samples = num_imgs,
            num_classes = num_classes,
            labels_to_name = labels_to_name,
            items_to_descriptions = items_to_descriptions
        )

b、獲取batch資料

preprocessing_image對圖片進行預處理,對圖片進行資料增強,輸出後的圖片尺寸由height和width引數決定,固定圖片的尺寸方便CNN的模型訓練。

def load_batch(split_name,dataset,batch_size,height,width):
    data_provider = slim.dataset_data_provider.DatasetDataProvider(
        dataset,
        common_queue_capacity = 24 + 3 * batch_size,
        common_queue_min = 24
    )
        raw_image,img_label = data_provider.get(["image","label"])
        #Perform the correct preprocessing for this image depending if it is training or evaluating
        image = preprocess_image(raw_image, height, width,True)
        #As for the raw images, we just do a simple reshape to batch it up
        raw_image = tf.expand_dims(raw_image, 0)
        raw_image = tf.image.resize_nearest_neighbor(raw_image, [height, width])
        raw_image = tf.squeeze(raw_image)
        #獲取一個batch資料
        images,raw_image,labels = tf.train.batch(
            [image,raw_image,img_label],
            batch_size=batch_size,
            num_threads=4,
            capacity=4*batch_size,
            allow_smaller_final_batch=True
        )
        return images,raw_image,labels

c、讀取tfrecord檔案

#讀取tfrecord檔案
def read_tfrecord():
    #從tfreocrd檔案中讀取資料
    train_dataset = get_dataset_by_tfrecords("train",dataset_dir_path,"catVSdog",2,label_num_to_name)
    images,raw_images,labels = load_batch("train",train_dataset,batch_size,227,227)
    with tf.Session() as sess:
        threads = tf.train.start_queue_runners(sess)
        for i in range(6):
            train_img,train_label = sess.run([raw_images,labels])
            plt.subplot(2,3,i+1)
            plt.imshow(np.array(train_img[0]))
            plt.title("image label:%s"%str(label_num_to_name[train_label[0]]))
        plt.show()

讀取訓練集的tfrecord檔案,只從tfrecord檔案中獲取了圖片資料和圖片的標籤,images表示的是預處理後的圖片,raw_images表示的是沒有經過預處理的圖片。