1. 程式人生 > >tensorflow資料讀取之tfrecords

tensorflow資料讀取之tfrecords

掌握一個深度學習框架的用法,從訓練一個模型的流程來看,需要掌握以下幾個步驟:
1. 資料的處理,包括訓練資料轉成網路的輸入,模型引數的儲存與讀取
2. 網路結構的定義,包括網路主體的搭建以及loss的定義
3. solver的定義,也就是如何對網路進行優化
4. 模型評估的定義,也就是對模型訓練結果進行評測

這篇博文主要介紹第一部分,資料處理中的訓練資料讀取部分
tensorflow當中讀取資料的方式一共有三種:
1. 供給資料(feeding):在定義graph的時候,可以用placeholder來表示其中的某些節點,在執行的時候,通過向這些節點填入資料來執行整個graph
2. 預載入資料:僅僅適用於可以完全載入到儲存器中的小的資料集
3. 從檔案中載入資料:在圖執行的時候,通過一個pipeline從檔案中讀取資料,作為網路的輸入

前面兩種方法並不適用於平時的大資料集處理的情況,因此這裡我只介紹第三種方法:從檔案中載入資料,重點介紹通過tensorflow特有的tfrecords進行儲存和讀取的方法

1.tfrecords

tfrecords是一種二進位制檔案,我們可以實現把image資料夾裡面的圖片和label資料夾裡面的標籤讀進來製作成tfrecords檔案,之後就都使用tfrecords進行資料的讀取,能夠更好的利用記憶體,更方便複製和移動,類似caffe裡面的hdf5,下面我們將詳細介紹如果製作tfrecords檔案,以及如何從tfrecords檔案裡面讀取資料

2.tfrecords檔案的製作

每一個訓練樣例在tfrecords裡面叫一個example,tensorflow使用名為tf.train.Example的協議來儲存訓練樣例,一個example就類似一個詞典(dict),它的key名為feature,每一個feature對應的值是tensorflow預定義好的Feature message(必須要是ByteList,FloatList以及Int64List中的一種資料型別),這樣的話一個訓練樣例就可以表示為很多鍵-值對的組合。

example通過SerializeToString()方法將樣例序列化成字串儲存,tensorflow通過TFRecordWriter將這些序列化之後的字串存成tfrecord形式

下面表示一個製作tfrecords的例子

    #image_set可以選擇'訓練'或者'測試',images表示所有的圖片資料,labels表示所有的標籤,包括訓練和測試
    def convert(self, image_set, images, labels):
        filename = os.path.join(self._image_set_path, 'dianli_{:s}.tfrecords'.format(image_set))
        print('writing ', filename)

    #建立一個tfrecords的writer
        writer = tf.python_io.TFRecordWriter(filename)
        if image_set == 'train':
            names = self.train_list
        elif image_set == 'test':
            names = self.test_list
        else:
            raise ValueError

    #max_object表示一張圖最多有多少個object,設定這個引數是為了保證寫入時所有label項的形狀都是固定的
        max_object_num = cfg.TRAIN.MAX_OBJECT

        for ix, name in enumerate(names):
            image_raw = images[ix].tostring()
            single_label = labels[ix]
            if single_label.shape[0] > max_object_num:
                raise ValueError('number of gt object is {:d}, more than max object num!'.format(single_label.shape[0]))

    #加0確保label維度固定,這樣後面就可以使用tf.train.shuffle_batch
            non_obj_label = np.zeros((max_object_num - single_label.shape[0], 6))
            comp_label = np.vstack((single_label, non_obj_label))

    #這裡用(cls,xc,yc,xita,w,h)表示一個斜框標籤,cls表示框裡面目標的類別
            cls = comp_label[:, 0].tolist()
            xc = comp_label[:, 1].tolist()
            yc = comp_label[:, 2].tolist()
            xita = comp_label[:, 3].tolist()
            w = comp_label[:, 4].tolist()
            h = comp_label[:, 5].tolist()

    #image_name和image序列化成bytelist,其他的按照本來的資料型別選擇序列化成float或者int64
            example = tf.train.Example(features=tf.train.Features(feature={
                'image_name': self._bytes_feature(names[ix].encode('utf8')),
                'height': self._int64_feature(images[ix].shape[0]),
                'width': self._int64_feature(images[ix].shape[1]),
                'depth': self._int64_feature(images[ix].shape[2]),
                'xc': self._float_feature(xc),
                'yc': self._float_feature(yc),
                'xita': self._float_feature(xita),
                'w': self._float_feature(w),
                'h': self._float_feature(h),
                'cls': self._int64_feature([int(c) for c in cls]),
                'image_raw': self._bytes_feature(image_raw)
            }))

    #儲存一個example
            writer.write(example.SerializeToString())

        writer.close()
        print('finish writing ', filename)

為了驗證我們的tfrecords檔案是否寫入正確,我們可以寫一個檢測的測試程式,讀取裡面的example看看

    #使用tf_record_iterator來遍歷每一個example
    with tf.Session() as sess:
        example = tf.train.Example()

    #train_record表示訓練的tfrecords檔案的路徑
        record_iterator = tf.python_io.tf_record_iterator(path=train_record)
        for record in record_iterator:
            example.ParseFromString(record)
            f = example.features.feature

    #解析一個example
            image_name = f['image_name'].bytes_list.value[0]
            image_raw = f['image_raw'].bytes_list.value[0]
            xc = np.array(f['xc'].float_list.value)[:, np.newaxis]
            yc = np.array(f['yc'].float_list.value)[:, np.newaxis]
            xita = np.array(f['xita'].float_list.value)[:, np.newaxis]
            w = np.array(f['w'].float_list.value)[:, np.newaxis]
            h = np.array(f['h'].float_list.value)[:, np.newaxis]
            label = np.hstack((xc, yc, xita, w, h))

    #將label畫在image上,同時列印image_name,檢視三者是不是對應上的
            print(image_name.encode('utf-8'))
            img_1d = np.fromstring(image_raw, dtype=np.uint8)
            img_2d = img_1d.reshape((480, 480, -1))
            draw_bboxes(img_2d, label)

3.tfrecords檔案的解析和讀取

我們其實可以直接用上面測試的程式來一個一個讀取tfrecords檔案裡面的example,作為網路的輸入,但是這樣一來,網路訓練一個batch的時間就包括從tfrecords裡面讀取一個batch資料的時間,以及這個batch做前向反向的時間,顯然資料讀取的時間是可以避免的。

最好的情況是,網路每次做完一次前向反向之後,不需要等待資料的讀取,也就是說我們可以開兩個執行緒,一個執行緒在進行網路的前向反向計算,另外一個執行緒專門進行資料的讀取,這樣的話兩個執行緒都不會存在所謂的等待狀態。

下面將分別介紹tfrecords檔案的解析以及如何高效得進行資料的讀取

解析部分程式碼如下所示:

    #filename_queue表示一個檔名佇列,後面會講到,比如我需要解析train.tfrecords檔案的話,傳入的就應該是訓練的檔名佇列
    def read_and_decode(self, filename_queue):
        """
        read and decode
        """

    #建立一個tfrecords檔案讀取器並讀取檔案
        reader = tf.TFRecordReader()
        _, serialized_example = reader.read(filename_queue)

    #tensorflow使用parse_single_example來解析example
        features = tf.parse_single_example(
            serialized_example,
            features={

    #對於單個元素的變數,我們使用FixlenFeature來讀取,需要指明變數儲存的資料型別
                # meta data
                'image_name': tf.FixedLenFeature([], tf.string),
                'height': tf.FixedLenFeature([], tf.int64),
                'width': tf.FixedLenFeature([], tf.int64),
                'depth': tf.FixedLenFeature([], tf.int64),

    #對於list型別的變數,我們使用VarLenFeature來讀取,同樣需要指明讀取變數的型別
                # label part
                'xc': tf.VarLenFeature(tf.float32),
                'yc': tf.VarLenFeature(tf.float32),
                'xita': tf.VarLenFeature(tf.float32),
                'w': tf.VarLenFeature(tf.float32),
                'h': tf.VarLenFeature(tf.float32),
                'cls': tf.VarLenFeature(tf.int64),

                # dense data
                'image_raw': tf.FixedLenFeature([], tf.string),
            }
        )

        # meta data
        image_name = tf.cast(features['image_name'], tf.string)
        height = tf.cast(features['height'], tf.int32)
        width = tf.cast(features['width'], tf.int32)
        depth = tf.cast(features['depth'], tf.int32)

    #VarLenFeature得到的是sparse tensor需要轉換一下
        # label part
        xc = tf.sparse_tensor_to_dense(features['xc'])
        yc = tf.sparse_tensor_to_dense(features['yc'])
        xita = tf.sparse_tensor_to_dense(features['xita'])
        w = tf.sparse_tensor_to_dense(features['w'])
        h = tf.sparse_tensor_to_dense(features['h'])
        cls = tf.sparse_tensor_to_dense(features['cls'])

        xc = tf.reshape(xc, shape=[cfg.TRAIN.MAX_OBJECT, 1])
        yc = tf.reshape(yc, shape=[cfg.TRAIN.MAX_OBJECT, 1])
        xita = tf.reshape(xita, shape=[cfg.TRAIN.MAX_OBJECT, 1])
        w = tf.reshape(w, shape=[cfg.TRAIN.MAX_OBJECT, 1])
        h = tf.reshape(h, shape=[cfg.TRAIN.MAX_OBJECT, 1])
        cls = tf.reshape(cls, shape=[cfg.TRAIN.MAX_OBJECT, 1])

        cls = tf.cast(cls, tf.float32)
        label = tf.concat(values=[cls, xc, yc, xita, w, h], axis=1)

    #set_shape是為了固定label的維度,方便後面使用tf.train.shuffle_batch
        label.set_shape([cfg.TRAIN.MAX_OBJECT, 6])

    #對於影象資料需要使用decode_raw解碼,同樣也需要set_shape
        # dense data
        image = tf.decode_raw(features['image_raw'], tf.uint8)
        image_shape = tf.stack([height, width, 3])
        image = tf.reshape(image, image_shape)
        image.set_shape([cfg.TRAIN.IMG_SIZE[0], cfg.TRAIN.IMG_SIZE[1], 3])

        return image, label

下面再詳細介紹一下高效讀取的方法
前面說了,最高效的讀取方法應該是一個執行緒專門讀取資料,一個執行緒專門做訓練(前向反向傳播),讀取資料的執行緒應該維護一個佇列(queue),不斷得讀取資料,壓入佇列,tensorflow裡面常用的是FIFO queue。訓練的執行緒每次從這個佇列裡面讀取一個batch的訓練樣例用來訓練。

tensorflow裡面用的方法比上面的方法要稍微複雜一點,區別在於其維護了兩個佇列,第一個佇列存的是訓練樣例的檔名,第二個佇列存的才是真正的訓練樣例。整個讀取的流程如下圖所示:

1

我們可以使用tf.string_input_produce建立上述的檔名佇列(filename queue),在建立的時候,我們可以指定引數num_epoch的值,這個值顯然是控制檔名佇列裡的檔名一共用來訓練幾個epoch,然後我們可以通過上面read_and_decode函式讀取檔名佇列的檔名,進行解析,將解析得到的訓練樣例壓入訓練樣例佇列(example queue)。最後我們進行訓練的時候每次可以從訓練樣例佇列裡面讀取一個batch,可以使用tf.train.shuffle_batch來獲取一個隨機打亂順序的batch,程式碼樣例如下:

def inputs(dataset_, image_set, batch_size, num_epochs):
    """
    read batch of data fro tf-record files for epoch
    :param dataset_: constructed dataset
    :param image_set:train or test
    :param batch_size:
    :param num_epochs:
    :return:images and labels for a single batch
    """
    train_record = os.path.join(dataset_.image_set_path, 'dianli_train.tfrecords')
    test_record = os.path.join(dataset_.image_set_path, 'dianli_test.tfrecords')
    if not (os.path.exists(train_record) and os.path.exists(test_record)):
        dataset_.create_tfrecord()

    filename = train_record if image_set == 'train' else test_record
    filename_queue = tf.train.string_input_producer([filename], num_epochs=num_epochs)

    image, label = dataset_.read_and_decode(filename_queue)

    #num_thread可以選擇用幾個執行緒同時讀取example queue,min_after_dequeue表示讀取一次之後佇列至少需要剩下的樣例數目,capacity表示佇列的容量
    images, sparse_labels = tf.train.shuffle_batch([image, label], batch_size=batch_size, num_threads=1, capacity=1000 + 3 * batch_size, min_after_dequeue=1000)
    return images, sparse_labels

if __name__ == '__main__':
    image_path = '/home/user/myDocuments/tf-yolo/data/image/'
    label_path = '/home/user/myDocuments/tf-yolo/data/annotation/'
    image_set_path = '/home/user/myDocuments/tf-yolo/data/'

    d = Dianli(image_path, label_path, image_set_path)
    image, label = inputs(d, 'train', batch_size=1, num_epochs=2)
    # de_file = inputs(d, 'train', batch_size=1, num_epochs=1)

    init_op = tf.group(tf.global_variables_initializer(), tf.local_variables_initializer())

    sess = tf.Session()
    sess.run(init_op)

    #要讓程式執行起來,一定要先start_queue_runners向佇列裡面填入資料
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)

    try:
        step = 0
        while not coord.should_stop():

            images, labels = sess.run([image, label])
            images = np.squeeze(images, axis=0)
            labels = np.squeeze(labels, axis=0)

            ind = np.where(np.sum(labels, axis=1) != 0)[0]
            labels = labels[ind, 1:]

            draw_bboxes(images, labels)

    except tf.errors.OutOfRangeError:
        print('done')
    finally:
        coord.request_stop()

    #join表示等待各個執行緒關閉
    coord.join(threads)
    sess.close()

在上面的例子當中,我們使用了tf.Coordinator來協調各個執行緒,各個執行緒正常情況下完成各自的任務,如果出現異常則向coordinator報告異常(should_stop==1),這個時候coordinator會將所有的執行緒關閉(request_stop==1)

注意在訓練開始之前,一定要呼叫start_queue_runners來開啟各個佇列的執行緒,否則佇列的內容一直為空,訓練的程序會一直掛著無法執行

以上就是整個從檔案資料製作tfrecords,再到tfrecords讀取以及訓練樣例的輸入的整個過程