Python中使用caffe做目標檢測
阿新 • • 發佈:2019-01-30
轉載地址 http://blog.csdn.net/tostq/article/details/73611590
2. 使用Caffe完成影象目標檢測
本節將以一個快速的影象目標檢測網路SSD作為例子,通過Python Caffe來進行影象目標檢測。
必須安裝windows-ssd版本的Caffe,或者自行在caffe專案中新增SSD的新增相關原始碼.
影象目標檢測網路同影象分類網路的大體原理及結構很相似,不過原始影象再經過深度網路後,並不是得到一組反映不同分類種類下概率的向量,而得到若干組位置資訊,其反映不同目標在影象中的位置及相應分類等資訊。但與分類網路的總體實施結構是一致的。
關於SSD的原理,可以參見其論文:Liu W, Anguelov D, Erhan D, et al. SSD : Single shot multibox detector[C]
2.1 準備檔案
deploy.prototxt
: 網路結構配置檔案VGG_VOC0712_SSD_300x300_iter_60000.caffemodel
: 網路權重檔案labelmap_voc.prototxt
: 資料集分類名稱- 測試影象
本文的SSD是在
VOC0712
資料集下進行訓練的,labelmap_voc.prototxt
也是該資料庫下的各目標的名稱,該檔案對於目標檢測網路的訓練任務是必須的,在下節中,我們將重點介紹如何生成LMDB資料庫及Labelmap檔案。
2.2 載入網路
載入網路的方法,目標檢測網路同目標分類網路都是一致的。
caffe_root = '../../' # 網路引數(權重)檔案 caffemodel = caffe_root + 'models/SSD_300x300/VGG_VOC0712_SSD_300x300_iter_60000.caffemodel' # 網路實施結構配置檔案 deploy = caffe_root + 'models/SSD_300x300/deploy.prototxt' labels_file = caffe_root + 'data/VOC0712/labelmap_voc.prototxt' # 網路實施分類 net = caffe.Net(deploy, # 定義模型結構 caffemodel, # 包含了模型的訓練權值 caffe.TEST) # 使用測試模式(不執行dropout)
2.3 測試影象預處理
預處理主要包含兩個部分:
- 減去均值
- 調整大小
# 載入ImageNet影象均值 (隨著Caffe一起釋出的) mu = np.load(caffe_root + 'python/caffe/imagenet/ilsvrc_2012_mean.npy') mu = mu.mean(1).mean(1) # 對所有畫素值取平均以此獲取BGR的均值畫素值 # 影象預處理 transformer = caffe.io.Transformer({'data': net.blobs['data'].data.shape}) transformer.set_transpose('data', (2,0,1)) transformer.set_mean('data', mu) transformer.set_raw_scale('data', 255) transformer.set_channel_swap('data', (2,1,0))
2.4 執行網路
- 匯入輸入資料
- 通過forward()執行結果
# 載入影象
im = caffe.io.load_image(img)
# 匯入輸入影象
net.blobs['data'].data[...] = transformer.preprocess('data', im)
start = time.clock()
# 執行測試
net.forward()
end = time.clock()
print('detection time: %f s' % (end - start))
- 1
- 2
- 3
- 4
- 5
- 6
- 7
- 8
- 9
- 10
- 11
- 12
2.5 檢視目標檢測結果
SSD網路的最後一層名為'detection_out'
,該層輸出Blob結構'detection_out'
中包含了多組元組結構,每個元組結構包含7個引數,其中第2引數表示分類類別序號,第3個引數表示概率置信度,第4~7個引數分別表示目標區域左上及右下的座標,而元組的個數表明該影象中可能的目標個數。
當然可能不同網路模型的結構不一樣,可能會有不同的設定,但至少對於SSD是這樣設定的。
# 檢視目標檢測結果 # 開啟labelmap_voc.prototxt檔案 file = open(labels_file, 'r') labelmap = caffe_pb2.LabelMap() text_format.Merge(str(file.read()), labelmap) # 得到網路的最終輸出結果 loc = net.blobs['detection_out'].data[0][0] confidence_threshold = 0.5 for l in range(len(loc)): if loc[l][2] >= confidence_threshold: # 目標區域位置資訊 xmin = int(loc[l][3] * im.shape[1]) ymin = int(loc[l][4] * im.shape[0]) xmax = int(loc[l][5] * im.shape[1]) ymax = int(loc[l][6] * im.shape[0]) # 畫出目標區域 cv2.rectangle(im, (xmin, ymin), (xmax, ymax), (55 / 255.0, 255 / 255.0, 155 / 255.0), 2) # 確定分類類別 class_name = labelmap.item[int(loc[l][1])].display_name cv2.putText(im, class_name, (xmin, ymax), cv2.cv.CV_FONT_HERSHEY_SIMPLEX, 1, (55, 255, 155), 2)
2.6 目標檢測結果展示