Faster R-CNN 批量測試,並將所有類別的檢測結果顯示在同一張圖上。
阿新 • • 發佈:2018-12-16
本文基於 endernewton 版本、以 TensorFlow 實現的 Faster R-CNN(tf-faster-rcnn)。
原來demo.py:實現的是檢測一張圖片,然後對該圖片的每一類檢測結果,單獨顯示。
修改之後:從txt中讀取要檢測的圖片名稱,進行批量檢測,並把所有類的檢測結果都放到一張圖上,然後儲存到data/result裡。
#!/usr/bin/env python # -------------------------------------------------------- # Tensorflow Faster R-CNN # Licensed under The MIT License [see LICENSE for details] # Written by Xinlei Chen, based on code from Ross Girshick # -------------------------------------------------------- """ Demo script showing detections in sample images. See README.md for installation instructions before running. """ from __future__ import absolute_import from __future__ import division from __future__ import print_function import _init_paths from model.config import cfg from model.test import im_detect from model.nms_wrapper import nms from utils.timer import Timer import tensorflow as tf import matplotlib.pyplot as plt from PIL import Image import numpy as np import os, cv2 import argparse from nets.vgg16 import vgg16 from nets.resnet_v1 import resnetv1 CLASSES = ('__background__', # always index 0 'normal bolt','normal bolt-2','normal bolt-3','shim losing','nut losing','nut losing-2','nut losing-3','nut directly loosening','nut directly loosening-2','nut directly loosening-3','nut directly loosening-4','pin loosening','pin closing','visible pin losing','visible pin losing-2','invisible pin losing','invisible pin losing-2') NETS = {'vgg16': ('vgg16_faster_rcnn_iter_70000.ckpt',),'res101': ('res101_faster_rcnn_iter_40000.ckpt',)} DATASETS= {'pascal_voc': ('voc_2007_trainval',),'pascal_voc_0712': ('voc_2007_trainval+voc_2012_trainval',)} fi2=open('/home/omnisky/q/tf-faster-rcnn-master/work.txt','w') def vis_detections(image_name, im, class_name, dets, thresh=0.5):#不用這個函數了。 """Draw detected bounding boxes.""" inds = np.where(dets[:, -1] >= thresh)[0] if len(inds) == 0: return im = im[:, :, (2, 1, 0)] fig, ax = plt.subplots(figsize=(12, 12)) ax.imshow(im, aspect='equal') for i in inds: bbox = dets[i, :4] score = dets[i, -1] ax.add_patch( plt.Rectangle((bbox[0], bbox[1]), bbox[2] - bbox[0], bbox[3] - bbox[1], fill=False, edgecolor='red', linewidth=1.5) ) ax.text(bbox[0], bbox[1] - 2, 
'{:s} {:.3f}'.format(class_name, score), bbox=dict(facecolor='blue', alpha=0.5), fontsize=14, color='white') ax.set_title(('{} detections with ' 'p({} | box) >= {:.1f}').format(class_name, class_name, thresh), fontsize=14) # plt.axis('off') # plt.tight_layout() # plt.draw() # image_name=image_name.replace('jpg','png') # plt.savefig('/home/omnisky/q/tf-faster-rcnn-master/data/result/'+image_name) # print("save image to /home/omnisky/q/tf-faster-rcnn-master/data/result/{}".format(image_name)) def demo(image_name, sess, net): """Detect object classes in an image using pre-computed object proposals.""" # Load the demo image im_file = os.path.join(cfg.DATA_DIR, 'demo', image_name) im = cv2.imread(im_file) # Detect all object classes and regress object bounds timer = Timer() timer.tic() scores, boxes = im_detect(sess, net, im) timer.toc() print('Detection took {:.3f}s for {:d} object proposals'.format(timer.total_time, boxes.shape[0])) # Visualize detections for each class CONF_THRESH = 0.7 thresh=0.7 NMS_THRESH = 0.3 #開啟圖片 im = im[:, :, (2, 1, 0)] fig, ax = plt.subplots(figsize=(12, 12)) ax.imshow(im, aspect='equal', alpha=0.5) #對每一類的每一個目標,在圖片上生成框 for cls_ind, cls in enumerate(CLASSES[1:]): cls_ind += 1 # because we skipped background cls_boxes = boxes[:, 4*cls_ind:4*(cls_ind + 1)] cls_scores = scores[:, cls_ind] dets = np.hstack((cls_boxes, cls_scores[:, np.newaxis])).astype(np.float32) keep = nms(dets, NMS_THRESH) dets = dets[keep, :] # vis_detections(image_name, im, cls, dets, thresh=CONF_THRESH)#這個函式註釋掉了,用下面的。 inds = np.where(dets[:, -1] >= thresh)[0] if len(inds) == 0: continue for i in inds: bbox = dets[i, :4] score = dets[i, -1] ax.add_patch( plt.Rectangle((bbox[0], bbox[1]), bbox[2] - bbox[0], bbox[3] - bbox[1], fill=False, edgecolor='red', linewidth=1.5) ) ax.text(bbox[0], bbox[1] - 2, '{:s} {:.3f}'.format(cls, score), bbox=dict(facecolor='blue', alpha=0.5), fontsize=14, color='white') plt.axis('off') plt.tight_layout() plt.draw() 
image_name=image_name.replace('jpg','png') plt.savefig('/home/omnisky/q/tf-faster-rcnn-master/data/result/'+image_name) print("save image to /home/omnisky/q/tf-faster-rcnn-master/data/result/{}".format(image_name)) def parse_args(): """Parse input arguments.""" parser = argparse.ArgumentParser(description='Tensorflow Faster R-CNN demo') parser.add_argument('--net', dest='demo_net', help='Network to use [vgg16 res101]', choices=NETS.keys(), default='res101') parser.add_argument('--dataset', dest='dataset', help='Trained dataset [pascal_voc pascal_voc_0712]', choices=DATASETS.keys(), default='pascal_voc_0712') args = parser.parse_args() return args if __name__ == '__main__': cfg.TEST.HAS_RPN = True # Use RPN for proposals args = parse_args() # model path demonet = args.demo_net dataset = args.dataset tfmodel = ('/home/omnisky/q/tf-faster-rcnn-master/output/res101/voc_2007_trainval/default/res101_faster_rcnn_iter_40000.ckpt') if not os.path.isfile(tfmodel + '.meta'): raise IOError(('{:s} not found.\nDid you download the proper networks from ' 'our server and place them properly?').format(tfmodel + '.meta')) # set config tfconfig = tf.ConfigProto(allow_soft_placement=True) tfconfig.gpu_options.allow_growth=True # init session sess = tf.Session(config=tfconfig) # load network if demonet == 'vgg16': net = vgg16() elif demonet == 'res101': net = resnetv1(num_layers=101) else: raise NotImplementedError net.create_architecture("TEST", 18, tag='default', anchor_scales=[8, 16, 32]) saver = tf.train.Saver() saver.restore(sess, tfmodel) print('Loaded network {:s}'.format(tfmodel)) #讀取txt,迴圈檢測。 fi=open('/home/omnisky/q/tf-faster-rcnn-master/data/VOCdevkit2007/VOC2007/ImageSets/Main/test.txt') txt=fi.readlines() im_names = [] for line in txt: line=line.strip('\n') line=line.replace('\r','') line=(line+'.jpg') im_names.append(line) print(im_names) fi.close() for im_name in im_names: print('~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~') print('Demo for data/demo/{}'.format(im_name)) 
demo(im_name, sess, net) plt.show()#建議註釋掉,不然一次圖片全部顯示容易宕機。