目標檢測(Google object_detection) API 上訓練自己的資料集
應公司要求,利用谷歌最近開源的Google object_detection API對公司收集的資料集進行訓練,並檢測訓練效果。通過一兩天的研究以及維持四天的訓練(GTX 1060 6GB),終於成功的在自己資料集上訓練的任務。測試效果感覺還行,雖沒有達到谷歌官方公佈的資料集上跑的識別效果,但是識別率也還過得去,這主要是因為資料集沒有官方做的那麼規範。下圖為本人挑選的一張識別率較好的圖片(識別哈爾濱啤酒):
下面把本人如何一步步在自己的資料集上訓練的詳細步驟做個總結,一是方便自己以後操作起來更快的再次上手訓練,二是方便大家能好的實現該API的一些需求。
需要說明的:
2:資料集格式需要為轉換成tensorflow要求的tfrecord的形式。
3:本文在GTX 1060 6GB的顯示卡上訓練了四天
4:如何安裝tensorflow等一些依賴庫,本文不再贅述,請參考:安裝依賴庫教程連結
過程:
1:下載Google object_detection API:
下載地址
2:資料集準備:
資料集需要符合API所需的TFRecord格式,官方提供的資料集格式為PASCAL VOC格式,API已經為我們提供了將此格式轉為TFRecord的程式碼. 但是這裡我們需要注意一個細節:create_pascal_tf_record.py中的
examples_path = os.path.join(data_dir, year,去掉'aeroplane_'。'ImageSets', 'Main', 'aeroplane_' + FLAGS.set + '.txt')
同時,將檔案中的
flags.DEFINE_string('label_map_path', 'data/pascal_label_map.pbtxt', 'Path to label map proto')
data/pascal_labe_map.pbtxt改為自己的資料集label
然後在“tensorflow/models/object_detection/”目錄下運行以下命令
- #生成訓練集record
- python create_pascal_tf_record.py --data_dir=`自己的訓練資料集路徑` \
- --year=VOC2007 --set=train --output_path=`你想儲存的訓練集的record路徑`
- #生成驗證集record
- python create_pascal_tf_record.py --data_dir=`自己的驗證資料集路徑` \
- --year=VOC2007 --set=val --output_path=`你想儲存的驗證集的record路徑`
注意,在data目錄下選擇一個.pbtxt檔案,將該檔案改為自己資料集的label。
執行上述兩個命令後會在data資料夾下生成兩個record檔案。
3:下載預訓練模型
按照上文“需要說明的”第一條下載預訓練模型,將下載好的模型進行解壓,並將.ckpt的三個檔案拷貝到models目錄下。將object_detection/samples/configs/faster_rcnn_resnet101_voc07.config複製到models目錄下並做如下修改:
1)num_classes:修改為之前修的的.pbtxt檔案中的類別數目
2)將所有'PATH_TO_BE_CONFIGURED'修改為自己之前設定的路徑
4:開始訓練
執行上述三步之後我們可以開始訓練了,此處需要注意兩點,不然會出現模組匯出錯誤,在tensorflow/models
分別執行:
- protoc object_detection/protos/*.proto --python_out=.
- export PYTHONPATH=$PYTHONPATH:`pwd`:`pwd`/slim
然後進入到obeject_detection目錄下,執行一下命令:
- python train.py --train_dir='想要儲存訓練模型的路徑' --pipeline_config_path='你採用的.config檔案路徑'
5:模型視覺化
執行上述四步之後您基本上只需等著模型執行完成即可,如果您想要視覺化您的模型,可以執行:
- tensorboard --logdir=’上面第4點提到的train_dir路徑‘
然後在你的瀏覽器輸入0.0.0.0:6006就能看到模型一些相關引數的視覺化結果了。
訓練完成後會生成三個.cpkt的檔案,將這三個檔案複製到tensorflow/models下,可利用這三個檔案生成一個.pb檔案,生成程式碼如下:
- python object_detection/export_inference_graph.py \
- --input_type image_tensor \
- --pipeline_config_path ’你的.config檔案路徑’ \
- --checkpoint_path model.ckpt-‘CHECKPOINT_NUMBER’ \
- --inference_graph_path output_inference_graph.pb