Tensorflow object detection API輸出識別物件的座標等資訊
阿新 • • 發佈:2019-01-07
Tensorflow object detection API輸出識別物件的座標等資訊
此程式基於Tensorflow object detection API。
如果你不想安裝API,但想使用此程式,請先完成
https://blog.csdn.net/weixin_43599336/article/details/83713790
視訊演示地址:https://www.bilibili.com/video/av32418677/?p=3
# By Bend_Function
# https://space.bilibili.com/275177832
# 視訊演示地址:[https://www.bilibili.com/video/av32418677/?p=3]
# 模型下載地址 https://github.com/tensorflow/models/blob/master/research/object_detection/g3doc/detection_model_zoo.md
# 可以放在任何資料夾下執行(前提正確配置API[環境變數])
import time
start = time.time()
import numpy as np
import os
import tensorflow as tf
from PIL import Image
from object_detection.utils import label_map_util
os.environ[ 'TF_CPP_MIN_LOG_LEVEL'] = '2'
# 要改的內容
###############################################
PATH_TO_CKPT = 'model\\ssd_mobilenet_v1_graph.pb' # 模型及標籤地址
PATH_TO_LABELS = 'model\\mscoco_label_map.pbtxt'
NUM_CLASSES = 90 # 檢測物件個數
PATH_TO_TEST_IMAGES_DIR = os.getcwd()+'\\test_images' # 測試圖片路徑
confident = 0.5 # 置信度,即scores>confident的目標才被輸出
###############################################
# Load a (frozen) Tensorflow model into memory.
detection_graph = tf.Graph()
with detection_graph.as_default():
od_graph_def = tf.GraphDef()
with tf.gfile.GFile(PATH_TO_CKPT, 'rb') as fid:
serialized_graph = fid.read()
od_graph_def.ParseFromString(serialized_graph)
tf.import_graph_def(od_graph_def, name='')
# Loading label map
label_map = label_map_util.load_labelmap(PATH_TO_LABELS)
categories = label_map_util.convert_label_map_to_categories(label_map, max_num_classes=NUM_CLASSES, use_display_name=True)
category_index = label_map_util.create_category_index(categories)
os.chdir(PATH_TO_TEST_IMAGES_DIR)
TEST_IMAGE_PATHS = os.listdir(PATH_TO_TEST_IMAGES_DIR)
config = tf.ConfigProto()
config.gpu_options.allow_growth = True
with detection_graph.as_default():
with tf.Session(graph=detection_graph, config=config) as sess:
image_tensor = detection_graph.get_tensor_by_name('image_tensor:0')
detection_boxes = detection_graph.get_tensor_by_name('detection_boxes:0')
detection_scores = detection_graph.get_tensor_by_name('detection_scores:0')
detection_classes = detection_graph.get_tensor_by_name('detection_classes:0')
num_detections = detection_graph.get_tensor_by_name('num_detections:0')
for image_path in TEST_IMAGE_PATHS: # 開始檢測
image = Image.open(image_path) # 讀圖片
width, height = image.size
image_np = np.array(image)
image_np_expanded = np.expand_dims(image_np, axis=0)
# Actual detection.
(boxes, scores, classes, num) = sess.run(
[detection_boxes, detection_scores, detection_classes, num_detections],
feed_dict={image_tensor: image_np_expanded})
s_boxes = boxes[scores > confident]
s_classes = classes[scores > confident]
s_scores = scores[scores > confident]
for i in range(len(s_classes)):
name = image_path.split("\\")[-1]
# name = image_path.split("\\")[-1].split('.')[0] # 不帶字尾
ymin = s_boxes[i][0] * height # ymin
xmin = s_boxes[i][1] * width # xmin
ymax = s_boxes[i][2] * height # ymax
xmax = s_boxes[i][3] * width # xmax
score = s_scores[i]
if s_classes[i] in category_index.keys():
class_name = category_index[s_classes[i]]['name'] # 得到英文class名稱
print("name:", name)
print("ymin:", ymin)
print("xmin:", xmin)
print("ymax:", ymax)
print("xmax:", xmax)
print("score:", score)
print("class:", class_name)
print("################")
end = time.time()
print("Execution Time: ", end - start)