[Neural Networks] Step-by-Step Visual Recognition with Mobile-Net (5)
阿新 · Published: 2018-10-29
1. Environment setup
2. Acquiring the dataset
3. Building the training set
4. Training
5. Running and testing the trained model
6. Code walkthrough
This is the fifth article in the series, covering how to run the trained model and test the results.
In the previous article we exported the trained model; in this one we load that exported model and use it to carry out the testing.
Create test.py under the object_detection directory with the following contents:
import os
import sys

import cv2
import numpy as np
import tensorflow as tf

sys.path.append("..")
from utils import label_map_util
from utils import visualization_utils as vis_util

ENERMY = 2      # 1 = blue team, 2 = red team; detections of this class ID are treated as the enemy
DEBUG = False   # True: visualize all boxes; False: draw only the best enemy box
THRE_VAL = 0.2  # minimum score for a detection to be accepted in pic_test()

PATH_TO_CKPT = '/home/xueaoru/models/research/inference_graph_v2/frozen_inference_graph.pb'
PATH_TO_LABELS = '/home/xueaoru/models/research/object_detection/car_label_map.pbtxt'
NUM_CLASSES = 2

# Build the category index that maps class IDs to display names.
label_map = label_map_util.load_labelmap(PATH_TO_LABELS)
categories = label_map_util.convert_label_map_to_categories(
    label_map, max_num_classes=NUM_CLASSES, use_display_name=True)
category_index = label_map_util.create_category_index(categories)

# Load the frozen inference graph exported in the previous article.
detection_graph = tf.Graph()
with detection_graph.as_default():
    od_graph_def = tf.GraphDef()
    with tf.gfile.GFile(PATH_TO_CKPT, 'rb') as fid:
        serialized_graph = fid.read()
        od_graph_def.ParseFromString(serialized_graph)
        tf.import_graph_def(od_graph_def, name='')

sess = tf.Session(graph=detection_graph)

# Input and output tensors of the detector.
image_tensor = detection_graph.get_tensor_by_name('image_tensor:0')
detection_boxes = detection_graph.get_tensor_by_name('detection_boxes:0')
detection_scores = detection_graph.get_tensor_by_name('detection_scores:0')
detection_classes = detection_graph.get_tensor_by_name('detection_classes:0')
num_detections = detection_graph.get_tensor_by_name('num_detections:0')


def video_test():
    # cap = cv2.VideoCapture(1)  # uncomment to read from a connected camera instead
    cap = cv2.VideoCapture("/home/xueaoru/下載/RoboMaster2.mp4")
    while True:
        time = cv2.getTickCount()
        ret, image = cap.read()
        if not ret:
            break
        image_expanded = np.expand_dims(image, axis=0)  # [1, h, w, 3]
        (boxes, scores, classes, num) = sess.run(
            [detection_boxes, detection_scores, detection_classes, num_detections],
            feed_dict={image_tensor: image_expanded})
        # print(np.squeeze(classes).astype(np.int32))
        # print(np.squeeze(scores))
        # print(np.squeeze(boxes))
        vis_util.visualize_boxes_and_labels_on_image_array(
            image,
            np.squeeze(boxes),
            np.squeeze(classes).astype(np.int32),
            np.squeeze(scores),
            category_index,
            use_normalized_coordinates=True,
            line_thickness=8,
            min_score_thresh=0.4)
        cv2.imshow('Object detector', image)
        key = cv2.waitKey(1) & 0xff
        time = cv2.getTickCount() - time
        print("Processing time (ms): " + str(time * 1000 / cv2.getTickFrequency()))
        if key == 27:  # ESC quits
            break
    cv2.destroyAllWindows()


def pic_test():
    image = cv2.imread("/home/xueaoru/models/research/images/image12.jpg")
    image_expanded = np.expand_dims(image, axis=0)  # [1, h, w, 3]
    (boxes, scores, classes, num) = sess.run(
        [detection_boxes, detection_scores, detection_classes, num_detections],
        feed_dict={image_tensor: image_expanded})
    if DEBUG:
        # Draw every detection above the visualization threshold.
        vis_util.visualize_boxes_and_labels_on_image_array(
            image,
            np.squeeze(boxes),
            np.squeeze(classes).astype(np.int32),
            np.squeeze(scores),
            category_index,
            use_normalized_coordinates=True,
            line_thickness=8,
            min_score_thresh=0.80)
    else:
        # Keep only the highest-scoring detection, and draw it only if it
        # is confident enough and belongs to the enemy class.
        score = np.squeeze(scores)
        max_index = np.argmax(score)
        score = score[max_index]
        detected_class = np.squeeze(classes).astype(np.int32)[max_index]
        if score > THRE_VAL and detected_class == ENERMY:
            box = np.squeeze(boxes)[max_index]  # normalized (ymin, xmin, ymax, xmax)
            h, w, _ = image.shape
            min_point = (int(box[1] * w), int(box[0] * h))
            max_point = (int(box[3] * w), int(box[2] * h))
            cv2.rectangle(image, min_point, max_point, (0, 255, 255), 2)
    cv2.imshow('Object detector', image)
    cv2.waitKey(0)
    cv2.destroyAllWindows()


video_test()
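The script ends by calling video_test(), which runs the detector frame by frame on a video file. To test a single image instead, swap the entry point at the bottom:

pic_test()  # run the single-image test instead of the video test

Note the two modes of pic_test(): with DEBUG = True it visualizes every box above min_score_thresh, while with DEBUG = False it keeps only the highest-scoring detection and draws it only when its score exceeds THRE_VAL and its class matches ENERMY.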
That's it for now. The final article will give a detailed walkthrough, including the code that goes from these detected boxes all the way to computing the turret's deflection angles; the explanation of the code above is also deferred to that article.
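As a small preview, the geometry is straightforward: once we have the enemy box in pixel coordinates, the offset of its center from the image center can be converted into yaw and pitch angles through a pinhole camera model. Below is a minimal sketch of that idea; the turret_angles helper and the field-of-view constants are hypothetical stand-ins, not the final code from the next article:

import numpy as np

HFOV_DEG = 60.0  # assumed horizontal field of view of the camera
VFOV_DEG = 45.0  # assumed vertical field of view of the camera

def turret_angles(box, image_w, image_h):
    """box is (ymin, xmin, ymax, xmax) in normalized coordinates,
    as returned by the detector above."""
    ymin, xmin, ymax, xmax = box
    # Box center in pixel coordinates.
    cx = (xmin + xmax) / 2.0 * image_w
    cy = (ymin + ymax) / 2.0 * image_h
    # Focal lengths in pixels, derived from the assumed FOVs.
    fx = (image_w / 2.0) / np.tan(np.radians(HFOV_DEG) / 2.0)
    fy = (image_h / 2.0) / np.tan(np.radians(VFOV_DEG) / 2.0)
    # Signed angular offsets of the box center from the optical axis.
    yaw = np.degrees(np.arctan((cx - image_w / 2.0) / fx))
    pitch = np.degrees(np.arctan((cy - image_h / 2.0) / fy))
    return yaw, pitch

# Example: yaw, pitch = turret_angles(np.squeeze(boxes)[max_index], 1280, 720)

With the assumed 60° horizontal FOV, a box centered at the right edge of the frame works out to exactly +30° of yaw.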