Cropping Pigs out of Video with the TensorFlow Object Detection API
阿新 • Published: 2019-01-09
The pig-detection code, together with the follow-up pig-classification program, is open-sourced on GitHub.
The main modifications made on top of the official demo code:
- Expand the detected boxes so they enclose the target more fully, clamping each crop so it never exceeds the image boundary [expand_ratio]
- When both pig and animal are detected, both may well be correct; tune the acceptance rule based on the run results to suppress high-scoring but irrelevant classes and improve robustness [class_keep]
- Run inference in mini-batches to make full use of the GPU and speed up execution.
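The first two modifications can be sketched as follows. This is a minimal illustration, not the repo's actual code: it assumes boxes in the normalized [ymin, xmin, ymax, xmax] order the API returns, and the parameter names (`expand_ratio`, `class_keep`, `min_score`) and thresholds are chosen for the example.

```python
def expand_and_clip(box, expand_ratio=0.1):
    """Expand a normalized [ymin, xmin, ymax, xmax] box by expand_ratio
    on each side, clamping the result to the [0, 1] image boundary."""
    ymin, xmin, ymax, xmax = box
    dy = (ymax - ymin) * expand_ratio
    dx = (xmax - xmin) * expand_ratio
    return [max(0.0, ymin - dy), max(0.0, xmin - dx),
            min(1.0, ymax + dy), min(1.0, xmax + dx)]


def filter_detections(classes, scores, class_keep=frozenset({'pig', 'animal'}),
                      category_index=None, min_score=0.5):
    """Keep only detections whose label is in class_keep and whose score
    passes the threshold; returns the surviving indices."""
    keep = []
    for i, (cls, score) in enumerate(zip(classes, scores)):
        # Look up the class name if a category_index is given,
        # otherwise assume classes already holds names.
        name = category_index[cls]['name'] if category_index else cls
        if name in class_keep and score >= min_score:
            keep.append(i)
    return keep
```

A box touching the image edge stays clamped at 0 or 1 after expansion, so the subsequent crop never indexes outside the frame.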
Now let's focus on the core code that touches the object detection API:
# Load a (frozen) TensorFlow model into memory
'''
tf.GraphDef():
    The GraphDef class is an object created by ProtoBuf.
    See https://www.tensorflow.org/extend/tool_developers/
graph_def:
    A GraphDef proto containing operations to be imported into the default graph
'''
detection_graph = tf.Graph()
with detection_graph.as_default():
    od_graph_def = tf.GraphDef()
    with tf.gfile.GFile(PATH_TO_CKPT, 'rb') as fid:
        serialized_graph = fid.read()
        od_graph_def.ParseFromString(serialized_graph)
        tf.import_graph_def(od_graph_def, name='')
'''A few label-map util functions are used here.'''
label_map = label_map_util.load_labelmap(PATH_TO_LABELS)
categories = label_map_util.convert_label_map_to_categories(
    label_map, max_num_classes=NUM_CLASSES, use_display_name=True)
category_index = label_map_util.create_category_index(categories)
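For context, `create_category_index` simply turns the category list into a dict keyed by class id, so a class name can be looked up from the integer ids the detector returns. A minimal re-implementation of that step (the ids and names below are made up for illustration):

```python
def create_category_index(categories):
    """Index a list of {'id': int, 'name': str} category dicts by class id,
    mirroring what label_map_util.create_category_index produces."""
    return {cat['id']: cat for cat in categories}


# Illustrative label map; real ids come from the .pbtxt label file.
categories = [{'id': 1, 'name': 'person'}, {'id': 17, 'name': 'pig'}]
category_index = create_category_index(categories)
print(category_index[17]['name'])  # pig
```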
'''
Now the key part: defining the computation graph.
In this script images are passed to the graph via feed_dict={image_tensor: image_np_expanded}. An earlier post covered feeding self-generated tfrecords; the Dataset API introduced in TF 1.4 is yet another option.
As for get_tensor_by_name, it simply retrieves a tensor by its name; see the small test snippet below.
It is still not obvious why this graph works: it looks as though we merely fetch a few tensors. Presumably the detection-box tensors and friends depend on image_tensor, so let's confirm in the source. In object_detection/inference/detection_inference.py, the build_inference_graph function loads the inference graph and connects it to the input image, essentially:
    tf.import_graph_def(
        graph_def, name='', input_map={'image_tensor': image_tensor})
From the official documentation: input_map: A dictionary mapping input names (as strings) in graph_def to Tensor objects. The values of the named input tensors in the imported graph will be re-mapped to the respective Tensor values.
So where is build_inference_graph called? It is indeed called under the inference folder, but our feed-based script never goes through it. The guess, then, is that the name image_tensor was fixed when the network was exported; sure enough, in object_detection/exporter.py image_tensor is a placeholder, just as expected. The concrete wiring of the graph is the model definition itself, which we will look at when analyzing the training code.
'''
with detection_graph.as_default():
    with tf.Session(graph=detection_graph) as sess:
        # Define the input and output tensors for detection_graph
        image_tensor = detection_graph.get_tensor_by_name('image_tensor:0')
        # Each box represents a part of the image where a particular object was detected.
        detection_boxes = detection_graph.get_tensor_by_name('detection_boxes:0')
        # Each score represents the level of confidence for each of the objects.
        # The score is shown on the result image, together with the class label.
        detection_scores = detection_graph.get_tensor_by_name('detection_scores:0')
        detection_classes = detection_graph.get_tensor_by_name('detection_classes:0')
        num_detections = detection_graph.get_tensor_by_name('num_detections:0')
        '''Note: some code is omitted here'''
        (boxes, scores, classes, num) = sess.run(
            [detection_boxes, detection_scores, detection_classes, num_detections],
            feed_dict={image_tensor: image_np_expanded})
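The mini-batch change from the list above fits in right here: instead of expanding a single frame to shape (1, H, W, 3), stack several frames into an (N, H, W, 3) array and feed that as image_tensor; the exported placeholder has no fixed batch dimension, and sess.run then returns batched results. A sketch with numpy (the batch size and shapes are illustrative, not from the repo):

```python
import numpy as np

BATCH_SIZE = 8  # illustrative; tune to fit GPU memory


def batch_frames(frames, batch_size=BATCH_SIZE):
    """Group decoded video frames ((H, W, 3) uint8 arrays) into
    (N, H, W, 3) batches; the last batch may be smaller."""
    for i in range(0, len(frames), batch_size):
        yield np.stack(frames[i:i + batch_size], axis=0)


# Each batch would then be fed in place of a single expanded frame, e.g.:
# (boxes, scores, classes, num) = sess.run(
#     [detection_boxes, detection_scores, detection_classes, num_detections],
#     feed_dict={image_tensor: batch})
```

Feeding batches amortizes the per-run overhead and keeps the GPU busy, which is where the speedup in the author's third modification comes from.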
import tensorflow as tf

c = tf.constant([[1.0, 2.0], [3.0, 4.0]])
d = tf.constant([[1.0, 1.0], [0.0, 1.0]])
e = tf.matmul(c, d, name='example')
with tf.Session() as sess:
    test = sess.run(e)
    print(e.name)  # example:0
    print(test)
    test = tf.get_default_graph().get_tensor_by_name("example:0")
    print(test)  # Tensor("example:0", shape=(2, 2), dtype=float32)
    print(test.eval())
'''
The output is:
example_2:0
[[ 1.  3.]
 [ 3.  7.]]
Tensor("example:0", shape=(2, 2), dtype=float32)
[[ 1.  3.]
 [ 3.  7.]]
'''
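Note that the recorded output starts with `example_2:0` even though the inline comment expects `example:0`: the snippet was evidently run more than once against the same default graph, and TensorFlow keeps op names unique by appending `_1`, `_2`, … to repeated names. A pure-Python sketch of that naming scheme (an illustration, not TF's actual implementation):

```python
class NameScope:
    """Mimics how a TF graph uniquifies repeated op names:
    'example', then 'example_1', 'example_2', ..."""

    def __init__(self):
        self._counts = {}

    def unique_name(self, name):
        n = self._counts.get(name, 0)
        self._counts[name] = n + 1
        return name if n == 0 else '%s_%d' % (name, n)


g = NameScope()
print(g.unique_name('example'))  # example
print(g.unique_name('example'))  # example_1
print(g.unique_name('example'))  # example_2
```

So the third definition of the matmul op in the same graph gets the name `example_2`, while `get_tensor_by_name("example:0")` still resolves to the tensor from the first definition.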