Tensorflow物體檢測(Object Detection)
最近工作需要折騰 tensorflow, 學會圖片分類後開始折騰物體檢測。 感謝前人的工作,現在把物體檢跑起來還是比較容易的,但優化就不容易了。
Tensorflow在更新1.2版本之後多了很多新功能,其中放出了很多用tf框架寫的深度網路slim模型,大大降低了開發難度,無論是fine-tuning還是修改網路結構都方便了不少。這裡講的的是物體檢測(object detection)API,這個庫的說明文件很詳細。
這個物體檢測API提供了5種網路結構的預訓練的weights,全部是用COCO資料集進行訓練,可以在這裡下載:分別是SSD+mobilenet, SSD+inception_v2, R-FCN+resnet101, faster RCNN+resnet101, faster RCNN+inception+resnet101。各個模型的精度和計算所需時間如下,具體測評細節可以看
依賴包
Protobuf 2.6
Pillow 1.0
lxml
tf Slim
Jupyter notebook
Matplotlib # 用這個畫圖會比較慢,記憶體佔用高,可以用cv2來代替
Tensorflow
API安裝
$ pip install tensorflow-gpu
$ sudo apt-get install protobuf-compiler python-pil python-lxml
$ sudo pip install jupyter
$ sudo pip install matplotlib
因為使用protobuf來配置模型和訓練引數,所以API正常使用必須先編譯protobuf庫
$ cd tensorflow/models
$ protoc object_detection/protos/*.proto --python_out=.
然後將models和slim(tf高階框架)加入python環境變數:
export PYTHONPATH=$PYTHONPATH:/your/path/to/tensorflow/models:/your/path/to/tensorflow/models/slim
最後測試安裝:
python object_detection/builders/model_builder_test.py
fine-tuning
準備資料集
以Pascal VOC資料集的格式為例:object_detection/create_pascal_tf_record.py
.record
格式python object_detection/create_pascal_tf_record.py \ --label_map_path=object_detection/data/pascal_label_map.pbtxt \ # 訓練物品的品類和id --data_dir=VOCdevkit --year=VOC2012 --set=train \ --output_path=pascal_train.record python object_detection/create_pascal_tf_record.py \ --label_map_path=object_detection/data/pascal_label_map.pbtxt \ --data_dir=VOCdevkit --year=VOC2012 --set=val \ --output_path=pascal_val.record
其中
--data_dir
為訓練集的目錄。結構同Pascal VOC,如下:+ VOCdevkit # +為資料夾 + JPEGImages - 001.jpg # - 為檔案 + Annotations - 001.xml
訓練
train和eval輸入輸出資料儲存結構為:+ input - label_map.pbtxt file # 可以在object_detection/data/*.pbtxt找到樣例 - train TFRecord file - eval TFRecord file + models + modelA - pipeline config file # 可以在object_detection/samples/configs/*.config下找到樣例,定義訓練引數和輸入資料 + train # 儲存訓練產生的checkpoint檔案 + eval
準備好上述檔案後就可以直接呼叫train檔案進行訓練
python object_detection/train.py \ --logtostderr \ --pipeline_config_path=/your/path/to/models/modelA/pipeline config file \ --train_dir=/your/path/to/models/modelA/train
評估
在訓練開始以後,就可以執行eval來評估模型的效果。不過實際情況是eval模型也需要載入ckpt檔案,因此也需要佔用不小的視訊記憶體,而一般訓練的時候都會調整batch儘量利用顯示卡效能,所以想要實時執行train和eval的話需要調整好兩者所需的記憶體。python object_detection/eval.py \ --logtostderr \ --pipeline_config_path=/your/path/to/models/modelA/pipeline config file \ --checkpoint_dir=/your/path/to/models/modelA/train \ --eval_dir=/your/path/to/models/modelA/eval
監控
通過tensorboard命令可以在瀏覽器很輕鬆的監控訓練程序,在瀏覽器輸入localhost:6006
(預設)即可tensorboard --logdir=/your/path/to/models/modelA # 需要包含eval和train目錄(.ckpt, .index, .meta, checkpoint, graph.pbtxt檔案)
freeze model
在訓練完成後需要將訓練產生的最後一組.meta, .index, .ckpt, checkpoint檔案。其中meta儲存了graph和metadata,ckpt儲存了網路的weights。而在生產環境中進行預測的時候是隻需要模型和權重,不需要metadata,所以需要將其提出進行freeze操作,將所需的部分放到一個檔案,方便之後的呼叫,也減少模型載入所需的記憶體。(在下載的預訓練模型解壓後可以找到4個檔案,其中名為frozen_inference_graph.pb的檔案就是freeze後產生的模型檔案,比weights檔案大,但是比weights和meta檔案加起來要小不少。)
本來,tensorflow/python/tools/freeze_graph.py
提供了freeze model的api,但是需要提供輸出的final node names(一般是softmax之類的最後一層的啟用函式命名),而object detection api提供提供了預訓練好的網路,final node name並不好找,所以object_detection
目錄下還提供了export_inference_graph.py
。
python export_inference_graph.py \
--input_type image_tensor \
--pipeline_config_path /your/path/to/models/modelA/pipeline config file \
--checkpoint_path /your/path/to/models/modelA/train/model.ckpt-* \
--inference_graph_path /your/path/to/models/modelA/train/frozen_inference_graph.pb # 輸出的檔名
模型呼叫
目錄下提供了一個樣例。這裡只是稍作調整用cv2來顯示影象。
也可以直接使用官方提供的https://github.com/tensorflow/models/blob/master/research/object_detection/object_detection_tutorial.ipynb ,使用jupyter notebook測試。
import numpy as np
import os, sys
import tensorflow as tf
import cv2
MODEL_ROOT = "/home/arkenstone/tensorflow/workspace/models"
sys.path.append(MODEL_ROOT) # 應用和訓練的目錄在不同的地方
from object_detection.utils import label_map_util
from object_detection.utils import visualization_utils as vis_util
MODEL_PATH = "/home/arkenstone/tensorflow/workspace/models/objectdetection/models/faster_rcnn_inception_resnet_v2_atrous_coco_11_06_2017"
PATH_TO_CKPT = MODEL_PATH + '/frozen_inference_graph.pb' # frozen model path
PATH_TO_LABELS = os.path.join(MODEL_ROOT, 'object_detection/data', 'mscoco_label_map.pbtxt')
NUM_CLASSES = 90
label_map = label_map_util.load_labelmap(PATH_TO_LABELS)
categories = label_map_util.convert_label_map_to_categories(label_map, max_num_classes=NUM_CLASSES, use_display_name=True)
category_index = label_map_util.create_category_index(categories) # 格式為{1:{'id': 1, 'name': 'person'}, 2: {'id': 2, 'name': 'bicycle'}, ...}
# 模型載入:test.py
detection_graph = tf.Graph()
with detection_graph.as_default():
od_graph_def = tf.GraphDef()
with tf.gfile.GFile(PATH_TO_CKPT, 'rb') as fid:
serialized_graph = fid.read()
od_graph_def.ParseFromString(serialized_graph)
tf.import_graph_def(od_graph_def, name='')
gpu_memory_fraction = 0.4
gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=gpu_memory_fraction)
config = tf.ConfigProto(gpu_options=gpu_options, log_device_placement=False, allow_soft_placement=True)
config.gpu_options.allow_growth = False
def detect(image_path):
with detection_graph.as_default(): # 需要手動close sess
with tf.Session(graph=detection_graph, config=config) as sess:
image = cv2.imread(image_path)
image_np = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
image_np_expanded = np.expand_dims(image_np, axis=0)
image_tensor = detection_graph.get_tensor_by_name('image_tensor:0')
boxes = detection_graph.get_tensor_by_name('detection_boxes:0')
scores = detection_graph.get_tensor_by_name('detection_scores:0')
classes = detection_graph.get_tensor_by_name('detection_classes:0')
num_detections = detection_graph.get_tensor_by_name('num_detections:0')
(boxes, scores, classes, num_detections) = sess.run(
[boxes, scores, classes, num_detections],
feed_dict={image_tensor: image_np_expanded})
vis_util.visualize_boxes_and_labels_on_image_array(
image_np,
np.squeeze(boxes),
np.squeeze(classes).astype(np.int32),
np.squeeze(scores),
category_index,
use_normalized_coordinates=True,
line_thickness=4)
new_img = cv2.cvtColor(image_np, cv2.COLOR_RGB2BGR)
cv2.imshow("test", new_img)
cv2.waitKey(0)
if __name__ == '__main__':
detect(/your/test/image)
參考
https://github.com/tensorflow/models/tree/master/research/object_detection