1. 程式人生 > >tensorflow 輕鬆實現自己的目標檢測

tensorflow 輕鬆實現自己的目標檢測

影象分類和目標檢測是計算機視覺兩大模組。相比於影象分類,目標檢測任務更復雜更困難。目標檢測不但要檢測到具體的目標,還要定位目標的具體位置。不過Tensorflow models上大神們的無私奉獻已經使得目標檢測模型平民化,只需要按照特定的格式準備好訓練資料,就可以輕鬆訓練出自己想要的目標檢測模型。本文通過一個例子介紹如何通過Tensorflow models快速構建目標檢測模型。

###準備工作

  1. 下載訓練圖片和object標記資料並解壓:

    wget http://www.robots.ox.ac.uk/~vgg/data/pets/data/images.tar.gz
    wget http://www.robots.ox.ac.uk/~vgg/data/pets/data/annotations.tar.gz
    tar -xvf images.tar.gz
    tar -xvf annotations.tar.gz
    

    這個資料集包含37個種類的貓和狗。

  2. 在訓練的機器上安裝tensorflow 1.4版本(具體安裝方法參考 Tensorflow 官網),在tf1.4新增了目標檢測非極大值抑制,結果展示能模組,模型的訓練在tf1.2及以上版本都可以正常執行;

###目標檢測模型訓練

      目標檢測模型的訓練很簡單,只需要做一些配置工作告訴程式碼去哪裡讀取資料,結果儲存到哪裡就可以了。

1. 首先對資料做一遍處理,把訓練資料處理成tfrecords的格式:

python object_detection/dataset_tools/create_pet_tf_record.py \
    --label_map_path=object_detection/data/pet_label_map.pbtxt \  #object_detection/data目錄下已經有這個檔案了,是訓練資料類別的描述檔案
    --data_dir=`you_path/data` \   # 存放訓練資料images和annotations的檔案
    --output_dir=`you_path/data`   # tfrecords儲存路徑(資料夾)

      這行程式碼是執行model/research/object_detection下的資料處理程式碼把你下載的images和annotations處理後的資料儲存到tfrecords檔案。執行結束後,你會發現you_path/data目錄下多了兩個檔案:pet_train_with_masks.record、pet_val_with_masks.record。這兩個檔案是用來進行模型訓練和驗證的。

2. 下載預訓練好的模型:

      COCO-pretrained Faster R-CNN with Resnet-101 model:

wget http://storage.googleapis.
com/download.tensorflow.org/models/object_detection/faster_rcnn_resnet101_coco_2017_11_08.tar.gz

新的模型不會從0開始訓練,而是在這個模型的基礎上進行調整,也就是一種遷移學習的方法。

3. 配置檔案路徑:

      在object_detection/samples/configs目錄下編輯faster_rcnn_resnet101_pets.config,只要修改下面幾行,改成我們對應的檔案路徑即可:

(1)

fine_tune_checkpoint: "/data/pets/faster_rcnn_resnet101_coco_2017_11_08/model.ckpt" # 上面第二步下載的預訓練模型

(2)

train_input_reader: {
  tf_record_input_reader {
    input_path: "/data/pets/pet_train_with_masks.record"   #預處理資料生成的tfrecords檔案
  }
  label_map_path: tensorflow/models/research/object_detection/data/pet_label_map.pbtxt"  #這個檔案在預處理資料時也用到過,用來說明訓練樣本的類別資訊
}

(3)

eval_input_reader: {
  tf_record_input_reader {
    input_path: "/data/pets/pet_val_with_masks.record"  #預處理資料生成的tfrecords檔案
  }
 label_map_path: "/home/recsys/hzwangjian1/tensorflow/models/research/object_detection/data/pet_label_map.pbtxt" #這個檔案和(2)中是同一個檔案,用來說明訓練樣本的類別資訊
  shuffle: false
  num_readers: 1
}

4. 開始訓練模型:

python3 models/research/object_detection/train.py --logtostderr --train_dir=data/ --pipeline_config_path=object_detection/samples/configs/faster_rcnn_resnet101_pets.config

–train_dir:模型的checkpoint和summary都會儲存在這個路徑下
–pipeline_config_path:上一步的配置檔案

5. 模型匯出和預測新的圖片:

      首先把圖匯出到一個.pb檔案,然後可以直接載入.pb檔案恢復整個模型。

python3 models/research/object_detection/export_inference_graph.py \
    --input_type image_tensor \
    --pipeline_config_path object_detection/samples/configs/faster_rcnn_resnet101_pets.config \
    --trained_checkpoint_prefix data/model.ckpt-15000 \
    --output_directory object_detection_graph

–pipeline_config_path:上一步訓練時使用的配置檔案,也可以與訓練時使用的配置檔案不同。
–trained_checkpoint_prefix:模型訓練的checkpoint
–output_directory:圖匯出路徑

      執行程式碼後會生成.pb檔案,修改執行models/research/object_detection/object_detection_tutorial.ipynb這個檔案就可以來進行預測測試了。

![圖片識別結果](https://img-blog.csdn.net/20180121230819705?watermark/2/text/aHR0cDovL2Jsb2cuY3Nkbi5uZXQvd2FuZ2ppYW4xMjA0/font/5a6L5L2T/fontsize/400/fill/I0JBQkFCMA==/dissolve/70/gravity/SouthEast)

參考資料

備註: tensorflow官網的github做了一些更新,把train.py重新命名成了model_main.py,引數也相應的做了修改。本博文對應的較早版本程式碼我fork了一份:https://github.com/AlexWang90/models