Tensorflow object_detection API(一)
Tensorflow object detection API是基於tensorflow的開源框架,可以用於搭建、訓練、使用object detection服務。
object_detection隸屬於Tensorflow models下的research,在下載object_detection的同時,建議下載整個models,有些包並不在object_detection中,而是在同級目錄下。
安裝教程
object_detection API依賴於包protobuf、pillow、lxml、jupyter、matplotlib。
這些包在安裝的過程中有很大可能回報錯,其中最可能是linux系統沒有安裝gcc,或者gcc版本過低或過高。
Tensorflow Object Detection API使用Protobufs來配置模型和訓練引數。在使用框架之前,必須編譯Protobuf庫。這應該通過從下載解壓的models/目錄執行以下命令來完成:
protoc object_detection/protos/*.proto –python_out=.
當在本地執行時,models /和slim目錄應該附加到PYTHONPATH。在查閱了很多資料後,大概有以下幾種方法:
1. 在python的site-package中新增.pth檔案,將models和slim檔案路徑新增
2. 在python程式碼中新增
import sys
sys.path.append('models路徑')
sys.path.append('slim路徑')
以上安裝完畢
安裝測試
可以通過執行以下命令來測試是否正確安裝了Tensorflow Object Detection API:
python object_detection / builders / model_builder_test.py
MSCOCO模型測試
MSCOCO是Microsoft下的coco資料集。有多種物品及其標記,教程中給了SSDmobilenet的模型下載(據說ssd_mobilenet是最快的,但精度最低)
測試程式碼位於object_detection檔案中的object_detection_tutorial.ipynb
(.ipynb使用notebook開啟)。裡面有很詳細的教程。測試影象結果為:
視訊實現
安裝python-opencv(使用apt-get會很簡單)後,目前實現的是單執行緒的物體檢測,以下是全部程式碼:
import numpy as np
import os
import six.moves.urllib as urllib
import sys
import tarfile
import tensorflow as tf
import zipfile
from collections import defaultdict
from io import StringIO
from matplotlib import pyplot as plt
from PIL import Image
if tf.__version__ < '1.4.0':
raise ImportError('Please upgrade your tensorflow installation to v1.4.* or later!')
# This is needed to display the images.
%matplotlib inline
# This is needed since the notebook is stored in the object_detection folder.
sys.path.append("..")
from utils import label_map_util
from utils import visualization_utils as vis_util
# What model to download.
MODEL_NAME = 'ssd_mobilenet_v1_coco_2017_11_17'
MODEL_FILE = MODEL_NAME + '.tar.gz'
DOWNLOAD_BASE = 'http://download.tensorflow.org/models/object_detection/'
# Path to frozen detection graph. This is the actual model that is used for the object detection.
PATH_TO_CKPT = MODEL_NAME + '/frozen_inference_graph.pb'
# List of the strings that is used to add correct label for each box.
PATH_TO_LABELS = os.path.join('data', 'mscoco_label_map.pbtxt')
NUM_CLASSES = 90
#如果是已經下載好的模型,可以註釋掉這一段
opener = urllib.request.URLopener()
opener.retrieve(DOWNLOAD_BASE + MODEL_FILE, MODEL_FILE)
tar_file = tarfile.open(MODEL_FILE)
for file in tar_file.getmembers():
file_name = os.path.basename(file.name)
if 'frozen_inference_graph.pb' in file_name:
tar_file.extract(file, os.getcwd())
detection_graph = tf.Graph()
with detection_graph.as_default():
od_graph_def = tf.GraphDef()
with tf.gfile.GFile(PATH_TO_CKPT, 'rb') as fid:
serialized_graph = fid.read()
od_graph_def.ParseFromString(serialized_graph)
tf.import_graph_def(od_graph_def, name='')
label_map = label_map_util.load_labelmap(PATH_TO_LABELS)
categories = label_map_util.convert_label_map_to_categories(label_map, max_num_classes=NUM_CLASSES, use_display_name=True)
category_index = label_map_util.create_category_index(categories)
def load_image_into_numpy_array(image):
(im_width, im_height) = image.size
return np.array(image.getdata()).reshape(
(im_height, im_width, 3)).astype(np.uint8)
#以下是修改教程後的程式碼,和教程有所區別
import cv2
cap = cv2.VideoCapture(0) # 開啟0號攝像頭
success = True
font = cv2.FONT_HERSHEY_SIMPLEX
with detection_graph.as_default():
with tf.Session(graph=detection_graph) as sess:
# Definite input and output Tensors for detection_graph
image_tensor = detection_graph.get_tensor_by_name('image_tensor:0')
# Each box represents a part of the image where a particular object was detected.
detection_boxes = detection_graph.get_tensor_by_name('detection_boxes:0')
# Each score represent how level of confidence for each of the objects.
# Score is shown on the result image, together with the class label.
detection_scores = detection_graph.get_tensor_by_name('detection_scores:0')
detection_classes = detection_graph.get_tensor_by_name('detection_classes:0')
num_detections = detection_graph.get_tensor_by_name('num_detections:0')
def returnimage(image_np):
image_np_expanded = np.expand_dims(image_np, axis=0)
# Actual detection.
(boxes, scores, classes, num) = sess.run(
[detection_boxes, detection_scores, detection_classes, num_detections],
feed_dict={image_tensor: image_np_expanded})
# Visualization of the results of a detection.
vis_util.visualize_boxes_and_labels_on_image_array(
image_np,
np.squeeze(boxes),
np.squeeze(classes).astype(np.int32),
np.squeeze(scores),
category_index,
use_normalized_coordinates=True,
line_thickness=8)
return np.array(image)
while success:
success, image = cap.read()
image = returnimage(image)
cv2.imshow("test", image)
if cv2.waitKey(1) & 0xFF == ord('q'):
cv2.imwrite('test.jpg',image)
break
cap.release()
cv2.destroyAllWindows()
執行結果截圖: