Faster R-CNN batch testing, with the detection results of all classes drawn on a single image.

This is based on endernewton's TensorFlow implementation of Faster R-CNN (tf-faster-rcnn).

The original demo.py detects a single image and shows the detection results of each class in a separate figure.

After the modification: the names of the images to test are read from a txt file, detection runs in batch, the detection results of all classes are drawn on one image, and each result is saved to data/result.
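For reference, test.txt is assumed to use the standard VOC ImageSets/Main format: one image id per line, without an extension; the script appends '.jpg' when building the file names and looks for each image under data/demo. With hypothetical ids, a test.txt containing

000001
000002
000005

produces im_names = ['000001.jpg', '000002.jpg', '000005.jpg']. The full modified demo.py follows.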

#!/usr/bin/env python

# --------------------------------------------------------
# Tensorflow Faster R-CNN
# Licensed under The MIT License [see LICENSE for details]
# Written by Xinlei Chen, based on code from Ross Girshick
# --------------------------------------------------------

"""
Demo script showing detections in sample images.

See README.md for installation instructions before running.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import _init_paths
from model.config import cfg
from model.test import im_detect
from model.nms_wrapper import nms

from utils.timer import Timer
import tensorflow as tf
import matplotlib.pyplot as plt
from PIL import Image
import numpy as np
import os, cv2
import argparse


from nets.vgg16 import vgg16
from nets.resnet_v1 import resnetv1

CLASSES = ('__background__',  # always index 0
           'normal bolt', 'normal bolt-2', 'normal bolt-3',
           'shim losing',
           'nut losing', 'nut losing-2', 'nut losing-3',
           'nut directly loosening', 'nut directly loosening-2',
           'nut directly loosening-3', 'nut directly loosening-4',
           'pin loosening', 'pin closing',
           'visible pin losing', 'visible pin losing-2',
           'invisible pin losing', 'invisible pin losing-2')

NETS = {'vgg16': ('vgg16_faster_rcnn_iter_70000.ckpt',),'res101': ('res101_faster_rcnn_iter_40000.ckpt',)}
DATASETS= {'pascal_voc': ('voc_2007_trainval',),'pascal_voc_0712': ('voc_2007_trainval+voc_2012_trainval',)}
fi2=open('/home/omnisky/q/tf-faster-rcnn-master/work.txt','w')  # opened for writing but not used further in this script
def vis_detections(image_name, im, class_name, dets, thresh=0.5):  # no longer used; kept for reference
    """Draw detected bounding boxes."""
    inds = np.where(dets[:, -1] >= thresh)[0]
    if len(inds) == 0:
        return
    im = im[:, :, (2, 1, 0)]
    fig, ax = plt.subplots(figsize=(12, 12))
    ax.imshow(im, aspect='equal')
    for i in inds:
        bbox = dets[i, :4]
        score = dets[i, -1]

        ax.add_patch(
            plt.Rectangle((bbox[0], bbox[1]),
                          bbox[2] - bbox[0],
                          bbox[3] - bbox[1], fill=False,
                          edgecolor='red', linewidth=1.5)
            )
        ax.text(bbox[0], bbox[1] - 2,
                '{:s} {:.3f}'.format(class_name, score),
                bbox=dict(facecolor='blue', alpha=0.5),
                fontsize=14, color='white')

    ax.set_title(('{} detections with '
                  'p({} | box) >= {:.1f}').format(class_name, class_name,
                                                  thresh),
                  fontsize=14)
#    plt.axis('off')
#    plt.tight_layout()
#    plt.draw()
#    image_name=image_name.replace('jpg','png')
#    plt.savefig('/home/omnisky/q/tf-faster-rcnn-master/data/result/'+image_name)
#    print("save image to /home/omnisky/q/tf-faster-rcnn-master/data/result/{}".format(image_name))

def demo(image_name, sess, net):
    """Detect object classes in an image using pre-computed object proposals."""

    # Load the demo image
    im_file = os.path.join(cfg.DATA_DIR, 'demo', image_name)
    im = cv2.imread(im_file)
    # Detect all object classes and regress object bounds
    timer = Timer()
    timer.tic()
    scores, boxes = im_detect(sess, net, im)
    timer.toc()
    print('Detection took {:.3f}s for {:d} object proposals'.format(timer.total_time, boxes.shape[0]))

    # Visualize detections for each class
    CONF_THRESH = 0.7
    thresh = CONF_THRESH  # score threshold used when drawing boxes below
    NMS_THRESH = 0.3
# open the image in a matplotlib figure (convert BGR to RGB first)
    im = im[:, :, (2, 1, 0)]
    fig, ax = plt.subplots(figsize=(12, 12))
    ax.imshow(im, aspect='equal', alpha=0.5)
# for every detection of every class, draw its box on the same image
    for cls_ind, cls in enumerate(CLASSES[1:]):
        cls_ind += 1 # because we skipped background
        cls_boxes = boxes[:, 4*cls_ind:4*(cls_ind + 1)]
        cls_scores = scores[:, cls_ind]
        dets = np.hstack((cls_boxes,
                          cls_scores[:, np.newaxis])).astype(np.float32)
        keep = nms(dets, NMS_THRESH)
        dets = dets[keep, :]
#        vis_detections(image_name, im, cls, dets, thresh=CONF_THRESH)  # this call is commented out; the inlined drawing code below is used instead

        inds = np.where(dets[:, -1] >= thresh)[0]
        if len(inds) == 0:
            continue
        for i in inds:
            bbox = dets[i, :4]
            score = dets[i, -1]
    
            ax.add_patch(
                plt.Rectangle((bbox[0], bbox[1]),
                              bbox[2] - bbox[0],
                              bbox[3] - bbox[1], fill=False,
                              edgecolor='red', linewidth=1.5)
                )
            ax.text(bbox[0], bbox[1] - 2,
                '{:s} {:.3f}'.format(cls, score),
                bbox=dict(facecolor='blue', alpha=0.5),
                fontsize=14, color='white')
    plt.axis('off')
    plt.tight_layout()
    plt.draw()
    image_name=image_name.replace('jpg','png')
    plt.savefig('/home/omnisky/q/tf-faster-rcnn-master/data/result/'+image_name)
    print("save image to /home/omnisky/q/tf-faster-rcnn-master/data/result/{}".format(image_name))

def parse_args():
    """Parse input arguments."""
    parser = argparse.ArgumentParser(description='Tensorflow Faster R-CNN demo')
    parser.add_argument('--net', dest='demo_net', help='Network to use [vgg16 res101]',
                        choices=NETS.keys(), default='res101')
    parser.add_argument('--dataset', dest='dataset', help='Trained dataset [pascal_voc pascal_voc_0712]',
                        choices=DATASETS.keys(), default='pascal_voc_0712')
    args = parser.parse_args()

    return args

if __name__ == '__main__':
    cfg.TEST.HAS_RPN = True  # Use RPN for proposals
    args = parse_args()

    # model path
    demonet = args.demo_net
    dataset = args.dataset
    # The checkpoint path is hard-coded here rather than assembled from NETS/DATASETS,
    # so change it to point at your own trained model.
    tfmodel = ('/home/omnisky/q/tf-faster-rcnn-master/output/res101/voc_2007_trainval/default/res101_faster_rcnn_iter_40000.ckpt')


    if not os.path.isfile(tfmodel + '.meta'):
        raise IOError(('{:s} not found.\nDid you download the proper networks from '
                       'our server and place them properly?').format(tfmodel + '.meta'))

    # set config
    tfconfig = tf.ConfigProto(allow_soft_placement=True)
    tfconfig.gpu_options.allow_growth=True

    # init session
    sess = tf.Session(config=tfconfig)
    # load network
    if demonet == 'vgg16':
        net = vgg16()
    elif demonet == 'res101':
        net = resnetv1(num_layers=101)
    else:
        raise NotImplementedError
    # 18 = len(CLASSES): the 17 object classes above plus '__background__'
    net.create_architecture("TEST", 18,
                          tag='default', anchor_scales=[8, 16, 32])
    saver = tf.train.Saver()
    saver.restore(sess, tfmodel)

    print('Loaded network {:s}'.format(tfmodel))
# read the image list from the txt file and run detection in a loop
    fi=open('/home/omnisky/q/tf-faster-rcnn-master/data/VOCdevkit2007/VOC2007/ImageSets/Main/test.txt')

    txt=fi.readlines()
    im_names = []
    for line in txt:
        line=line.strip('\n')
        line=line.replace('\r','')
        line=(line+'.jpg')
        im_names.append(line)
    print(im_names)
    fi.close()
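    # Added safeguard (not in the original script): make sure the output
    # directory exists, otherwise plt.savefig() in demo() would fail.
    result_dir = '/home/omnisky/q/tf-faster-rcnn-master/data/result'
    if not os.path.exists(result_dir):
        os.makedirs(result_dir)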
    for im_name in im_names:
        print('~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~')
        print('Demo for data/demo/{}'.format(im_name))
        demo(im_name, sess, net)
     
plt.show()  # better to comment this out: showing all the figures at once can easily freeze the machine
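To run the batch test, invoke the script the same way as the original demo (assuming it is saved as tools/demo.py as in the upstream repo), e.g. python tools/demo.py --net res101 --dataset pascal_voc_0712. Adjust the hard-coded checkpoint, test.txt, and result paths to your own setup; the annotated images are written to data/result as .png files, one per input image.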