tensorflow 輕鬆實現自己的目標檢測
影象分類和目標檢測是計算機視覺兩大模組。相比於影象分類,目標檢測任務更復雜更困難。目標檢測不但要檢測到具體的目標,還要定位目標的具體位置。不過Tensorflow models上大神們的無私奉獻已經使得目標檢測模型平民化,只需要按照特定的格式準備好訓練資料,就可以輕鬆訓練出自己想要的目標檢測模型。本文通過一個例子介紹如何通過Tensorflow models快速構建目標檢測模型。
###準備工作:
-
下載訓練圖片和object標記資料並解壓:
wget http://www.robots.ox.ac.uk/~vgg/data/pets/data/images.tar.gz wget http://www.robots.ox.ac.uk/~vgg/data/pets/data/annotations.tar.gz tar -xvf images.tar.gz tar -xvf annotations.tar.gz
這個資料集包含37個種類的貓和狗。
-
在訓練的機器上安裝tensorflow 1.4版本(具體安裝方法參考 Tensorflow 官網),在tf1.4新增了目標檢測非極大值抑制,結果展示能模組,模型的訓練在tf1.2及以上版本都可以正常執行;
###目標檢測模型訓練:
目標檢測模型的訓練很簡單,只需要做一些配置工作告訴程式碼去哪裡讀取資料,結果儲存到哪裡就可以了。
1. 首先對資料做一遍處理,把訓練資料處理成tfrecords的格式:
python object_detection/dataset_tools/create_pet_tf_record.py \ --label_map_path=object_detection/data/pet_label_map.pbtxt \ #object_detection/data目錄下已經有這個檔案了,是訓練資料類別的描述檔案 --data_dir=`you_path/data` \ # 存放訓練資料images和annotations的檔案 --output_dir=`you_path/data` # tfrecords儲存路徑(資料夾)
這行程式碼是執行model/research/object_detection下的資料處理程式碼把你下載的images和annotations處理後的資料儲存到tfrecords檔案。執行結束後,你會發現you_path/data目錄下多了兩個檔案:pet_train_with_masks.record、pet_val_with_masks.record。這兩個檔案是用來進行模型訓練和驗證的。
2. 下載預訓練好的模型:
COCO-pretrained Faster R-CNN with Resnet-101 model:
wget http://storage.googleapis. com/download.tensorflow.org/models/object_detection/faster_rcnn_resnet101_coco_2017_11_08.tar.gz
新的模型不會從0開始訓練,而是在這個模型的基礎上進行調整,也就是一種遷移學習的方法。
3. 配置檔案路徑:
在object_detection/samples/configs目錄下編輯faster_rcnn_resnet101_pets.config,只要修改下面幾行,改成我們對應的檔案路徑即可:
(1)
fine_tune_checkpoint: "/data/pets/faster_rcnn_resnet101_coco_2017_11_08/model.ckpt" # 上面第二步下載的預訓練模型
(2)
train_input_reader: {
tf_record_input_reader {
input_path: "/data/pets/pet_train_with_masks.record" #預處理資料生成的tfrecords檔案
}
label_map_path: tensorflow/models/research/object_detection/data/pet_label_map.pbtxt" #這個檔案在預處理資料時也用到過,用來說明訓練樣本的類別資訊
}
(3)
eval_input_reader: {
tf_record_input_reader {
input_path: "/data/pets/pet_val_with_masks.record" #預處理資料生成的tfrecords檔案
}
label_map_path: "/home/recsys/hzwangjian1/tensorflow/models/research/object_detection/data/pet_label_map.pbtxt" #這個檔案和(2)中是同一個檔案,用來說明訓練樣本的類別資訊
shuffle: false
num_readers: 1
}
4. 開始訓練模型:
python3 models/research/object_detection/train.py --logtostderr --train_dir=data/ --pipeline_config_path=object_detection/samples/configs/faster_rcnn_resnet101_pets.config
–train_dir:模型的checkpoint和summary都會儲存在這個路徑下
–pipeline_config_path:上一步的配置檔案
5. 模型匯出和預測新的圖片:
首先把圖匯出到一個.pb檔案,然後可以直接載入.pb檔案恢復整個模型。
python3 models/research/object_detection/export_inference_graph.py \
--input_type image_tensor \
--pipeline_config_path object_detection/samples/configs/faster_rcnn_resnet101_pets.config \
--trained_checkpoint_prefix data/model.ckpt-15000 \
--output_directory object_detection_graph
–pipeline_config_path:上一步訓練時使用的配置檔案,也可以與訓練時使用的配置檔案不同。
–trained_checkpoint_prefix:模型訓練的checkpoint
–output_directory:圖匯出路徑
執行程式碼後會生成.pb檔案,修改執行models/research/object_detection/object_detection_tutorial.ipynb這個檔案就可以來進行預測測試了。