TensorFlow 物件檢測 API 教程5
阿新 • • 發佈:2019-01-06
TensorFlow 物件檢測 API 教程 - 第5部分:儲存和部署模型
在本教程的這一步,認為已經選擇了預先訓練的物件檢測模型,調整現有的資料集或建立自己的資料集,並將其轉換為 TFRecord
檔案,修改模型配置檔案並開始訓練。但是,現在需要儲存模型並將其部署到專案中。
一. 將檢查點模型 (.ckpt)
儲存為 .pb
檔案
回到 TensorFlow
物件檢測資料夾,並將 export_inference_graph.py
檔案複製到包含模型配置檔案的資料夾中。
python export_inference_graph.py --input_type image_tensor --pipeline_config_path ./rfcn_resnet101_coco.config --trained_checkpoint_prefix ./models/train/model.ckpt-5000 --output_directory ./fine_tuned_model
這將建立一個新的目錄 fine_tuned_model
,其中模型名為 frozen_inference_graph.pb
。
二.在專案中使用模型
在本指南中一直在研究的專案是建立一個交通燈分類器。在 Python 中,可以將這個分類器作為一個類來實現。在類的初始化部分中,可以建立一個 TensorFlow
class TrafficLightClassifier(object):
def __init__(self):
PATH_TO_MODEL = 'frozen_inference_graph.pb'
self.detection_graph = tf.Graph()
with self.detection_graph.as_default():
od_graph_def = tf.GraphDef()
# Works up to here.
with tf.gfile.GFile(PATH_TO_MODEL, 'rb') as fid:
serialized_graph = fid.read()
od_graph_def.ParseFromString(serialized_graph)
tf.import_graph_def(od_graph_def, name='')
self.image_tensor = self.detection_graph.get_tensor_by_name('image_tensor:0')
self.d_boxes = self.detection_graph.get_tensor_by_name('detection_boxes:0')
self.d_scores = self.detection_graph.get_tensor_by_name('detection_scores:0')
self.d_classes = self.detection_graph.get_tensor_by_name('detection_classes:0')
self.num_d = self.detection_graph.get_tensor_by_name('num_detections:0')
self.sess = tf.Session(graph=self.detection_graph)
在這個類中,建立了一個函式,在影象上執行分類,並返回影象中分類的邊界框,分數和類。
def get_classification(self, img):
# Bounding Box Detection.
with self.detection_graph.as_default():
# Expand dimension since the model expects image to have shape [1, None, None, 3].
img_expanded = np.expand_dims(img, axis=0)
(boxes, scores, classes, num) = self.sess.run(
[self.d_boxes, self.d_scores, self.d_classes, self.num_d],
feed_dict={self.image_tensor: img_expanded})
return boxes, scores, classes, num
此時,需要過濾低於指定分數閾值的結果。結果自動從最高分到最低分,所以這相當容易。用上面的函式返回分類結果,做完以上這些就完成了!
下面可以看到交通燈分類器在行動