1. 程式人生 > 其它 >4.訓練模型之準備訓練資料

4.訓練模型之準備訓練資料

終於要開始訓練識別熊貓的模型了, 第一步是準備好訓練資料,這裡有三件事情要做:

  • 收集一定數量的熊貓圖片。
  • 將圖片中的熊貓用矩形框標註出來。
  • 將原始圖片和標註檔案轉換為TFRecord格式的檔案。

資料標註

收集熊貓的圖片和標註熊貓位置的工作稱之為“Data Labeling”,這可能是整個機器學習領域內最低階、最機械枯燥的工作了,有時候大量的 Data Labeling 工作會外包給專門的 Data Labeling 公司做, 以加快速度和降低成本。 當然我們不會把這個工作外包給別人,要從最底層的工作開始!收集熊貓圖片倒不是太難,從谷歌和百度圖片上收集 200 張熊貓的圖片,應該足夠訓練一個可用的識別模型了。然後需要一些工具來做標註,我使用的是 Mac 版的 RectLabel,常用的還有 LabelImg 和 LabelMe 等。

RectLabel 標註時的介面大概是這樣的:

當我們標註完成的時候,它會在 annotations 目錄下生產和圖片檔名相同的字尾名為 .json 的標註檔案。

開啟一個標註檔案,其內容大概是這樣的:

    {
      "filename" : "61.jpg",
      "folder" : "panda_images",
      "image_w_h" : [
        453,
        340
      ],
      "objects" : [
        {
          "label" : "panda",
          "x_y_w_h" : [
            90,
            104,
            364,
            233
          ]
        }
      ]
    }
  • image_w_h:圖片的寬和高。
  • objects:圖片的中的物體資訊、陣列。
  • label:在標註的時候指定的物體名稱。
  • x_y_w_h:物體位置的矩形框:(xmin、ymin、width、height)。

接下來要做的是耐心的在這 200 張圖片上面標出熊貓的位置,這個稍微要花點時間,可以在 這裡 找已經標註好的圖片資料。

生成 TFRecord

接下來需要一點 Python 程式碼來將圖片和標註檔案生成為 TFRecord 檔案,TFRecord 檔案是由很多tf.train.Example物件序列化以後組成的,先寫由一個單獨的圖片檔案生成tf.train.Example物件的函式:

    def create_sample(image_filename, data_dir):
        image_path = os.path.join(data_dir, image_filename)
        annotation_path = os.path.join(data_dir, 'annotations', os.path.splitext(image_filename)[0] + ".json")
        with tf.gfile.GFile(image_path, 'rb') as fid:
            encoded_jpg = fid.read()
        encoded_jpg_io = io.BytesIO(encoded_jpg)
        with open(annotation_path) as fid:
            image_annotation = json.load(fid)
        width = image_annotation['image_w_h'][0]
        height = image_annotation['image_w_h'][1]
        xmins = []
        ymins = []
        xmaxs = []
        ymaxs = []
        classes = []
        classes_text = []
        for obj in image_annotation['objects']:
            classes.append(1)
            classes_text.append('panda')
            box = obj['x_y_w_h']
            xmins.append(float(box[0]) / width)
            ymins.append(float(box[1]) / height)
            xmaxs.append(float(box[0] + box[2] - 1) / width)
            ymaxs.append(float(box[1] + box[3] - 1) / height)
        filename = image_annotation['filename'].encode('utf8')
        tf_example = tf.train.Example(features=tf.train.Features(feature={
            'image/height': dataset_util.int64_feature(height),
            'image/width': dataset_util.int64_feature(width),
            'image/filename': dataset_util.bytes_feature(filename),
            'image/source_id': dataset_util.bytes_feature(filename),
            'image/encoded': dataset_util.bytes_feature(encoded_jpg),
            'image/format': dataset_util.bytes_feature('jpeg'.encode('utf8')),
            'image/object/bbox/xmin': dataset_util.float_list_feature(xmins),
            'image/object/bbox/xmax': dataset_util.float_list_feature(xmaxs),
            'image/object/bbox/ymin': dataset_util.float_list_feature(ymins),
            'image/object/bbox/ymax': dataset_util.float_list_feature(ymaxs),
            'image/object/class/text': dataset_util.bytes_list_feature(classes_text),
            'image/object/class/label': dataset_util.int64_list_feature(classes),
        }))
        return tf_example

在這裡簡單說明一下:

  • 通過圖片檔名找到對應的標註檔案,並讀入標註資訊。
  • 因為圖片中標註的物體都是熊貓,用數字 1 來代表,所以 class 數組裡的元素值都為 1,class_text陣列的裡的元素值都為‘panda’。
  • Object Detection API 裡面接受的矩形框輸入格式為 (xmin, ymin, xmax, ymax) 和標註檔案的 (xmin, ymin, width, height) 不一樣,所以要做一下轉換。同時需要將這些值歸一化:將數值投影到 (0, 1] 的區間內。
  • 將特徵組成{特徵名:特徵值}的 dict 作為引數來建立tf.train.Example。

接下來將tf.train.Example物件序列化,我們寫一個可以由圖片檔案列表生成對應 TFRecord 檔案的的函式:

 def create_tf_record(example_file_list, data_dir, output_file_path):
        writer = tf.python_io.TFRecordWriter(output_file_path)
        for filename in example_file_list:
            tf_example = create_sample(filename, data_dir)
            writer.write(tf_example.SerializeToString())
        writer.close()

依次呼叫create_sample函式然後將生成的tf.train.Example物件依次序列化即可。

最後需要將資料集切分為訓練集合測試集,將圖片檔案打亂,然後按照 7:3 的比例進行切分:

    random.seed(42)
    random.shuffle(all_examples)
    num_examples = len(all_examples)
    num_train = int(0.7 * num_examples)
    train_examples = all_examples[:num_train]
    val_examples = all_examples[num_train:]
    create_tf_record(train_examples, data_dir, os.path.join(output_dir, 'train.record'))
    create_tf_record(val_examples, data_dir, os.path.join(output_dir, 'val.record'))

寫完這個指令碼以後,最好再寫一個測試用例來驗證這個指令碼,因為我們將會花很長的時間來訓練,到時候再發現指令碼有 bug 就太浪費時間了,我們主要測試create_sample方法有沒有根據輸入資料生成正確的tf.train.Example物件:

    def test_dict_to_tf_example(self):
        image_file = '61.jpg'
        data_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'test_data')
        example = create_sample(image_file, data_dir)
        self._assertProtoEqual(
            example.features.feature['image/height'].int64_list.value, [340])
        self._assertProtoEqual(
            example.features.feature['image/width'].int64_list.value, [453])
        self._assertProtoEqual(
            example.features.feature['image/filename'].bytes_list.value,
            [image_file])
        self._assertProtoEqual(
            example.features.feature['image/source_id'].bytes_list.value,
            [image_file])
        self._assertProtoEqual(
            example.features.feature['image/format'].bytes_list.value, ['jpeg'])
        self._assertProtoEqual(
            example.features.feature['image/object/bbox/xmin'].float_list.value,
            [90.0 / 453])
        self._assertProtoEqual(
            example.features.feature['image/object/bbox/ymin'].float_list.value,
            [104.0/340])
        self._assertProtoEqual(
            example.features.feature['image/object/bbox/xmax'].float_list.value,
            [1.0])
        self._assertProtoEqual(
            example.features.feature['image/object/bbox/ymax'].float_list.value,
            [336.0/340])
        self._assertProtoEqual(
            example.features.feature['image/object/class/text'].bytes_list.value,
            ['panda'])
        self._assertProtoEqual(
            example.features.feature['image/object/class/label'].int64_list.value,
            [1])

後臺回覆“準備訓練資料”關鍵字可以獲取全部原始碼。

完成之後執行指令碼,傳入圖片和標註的資料夾路徑和輸出檔案路徑:

python create_tf_record.py --image_dir=PATH_OF_IMAGE_SET --output_dir=OUTPUT_DIR

執行完成後會在由output_dir引數指定的目錄生成train.record和val.record檔案, 分別為訓練集和測試集。

生成 label map 檔案

最後還需要一個 label map 檔案,很簡單,因為我們只有一種物體:熊貓

label_map.pbtxt:
    item {
      id: 1
      name: 'panda'
    }

訓練一個熊貓識別模型所需要的訓練資料就準備完了,接下來開始在 GPU 主機上面開始訓練。