mean shift演算法
阿新 • • 發佈:2018-12-05
import cv2
import numpy as np
import matplotlib.pyplot as plt
import time
xs,ys,ws,hs = 0,0,0,0 #selection.x selection.y
xo,yo=0,0 #origin.x origin.y
selectObject = False
trackObject = 0 #追蹤目標,0代表沒有,1代表有,-1代表需要更新目標
#滑鼠移動時會進入這個函式
def onMouse(event, x, y, flags, prams):
global xs,ys,ws,hs, selectObject,xo,yo,trackObject
#實時更新滑鼠捕捉矩形大小,設定左上角為起點座標
if selectObject == True:
xs = min(x, xo)
ys = min(y, yo)
ws = abs(x-xo)
hs = abs(y-yo)
if event == cv2.EVENT_LBUTTONDOWN: #按下
xo,yo = x, y #按下時的origin座標
xs,ys,ws,hs= x, y, 0, 0 #矩形大小初始化
selectObject = True #標記開始選中
elif event == cv2.EVENT_LBUTTONUP: #鬆開
selectObject = False
if((ws>0) and (hs>0)):
trackObject = -1 #需要追蹤目標
# 設定初始化的視窗位置
# r,h,c,w = 0,100,0,100 # 設定初試視窗位置和大小
# track_window = (c,r,w,h)
cap = cv2.VideoCapture(0)
cv2.namedWindow('origin') #新建一個視窗
cv2.setMouseCallback('origin',onMouse)
ret, frame= cap.read()
# 設定追蹤的區域
# roi = frame[r:r+h, c:c+w]
# roi區域的hsv影象
hsv_roi = cv2.cvtColor(frame, cv2.COLOR_BGR2HSV)
# 取值hsv值在(0,60,32)到(180,255,255)之間的部分
mask = cv2.inRange(hsv_roi, np.array((0., 60.,32.)), np.array((180.,255.,255.)))
# 計算直方圖,引數為 圖片(可多),通道數,蒙板區域,直方圖長度,範圍
roi_hist = cv2.calcHist([hsv_roi],[0],mask,[180],[0,180])
# 歸一化
cv2.normalize(roi_hist,roi_hist,0,255,cv2.NORM_MINMAX)
# 設定終止條件,迭代10次或者至少移動1次
term_crit = ( cv2.TERM_CRITERIA_EPS | cv2.TERM_CRITERIA_COUNT, 10, 1 )
while(1):
ret, frame = cap.read()
if trackObject != 0:
hsv = cv2.cvtColor(frame, cv2.COLOR_BGR2HSV)
#cv2.imshow('hsv',hsv)
mask = cv2.inRange(hsv, np.array((0., 30.,10.)), np.array((180.,256.,255.))) #閾值分割
#cv2.imshow('mask',mask)
if trackObject == -1:
hsv_roi = cv2.cvtColor(frame, cv2.COLOR_BGR2HSV)
track_window = (xs,ys,ws,hs)
maskroi = mask[ys:ys+hs, xs:xs+ws] #在mask影象中擷取track_window
hsv_roi = hsv[ys:ys+hs, xs:xs+ws] #在hsv影象中擷取track_window
roi_hist = cv2.calcHist([hsv_roi],
[0], #計算直方圖的通道,這裡使用顏色計算直方圖,所以就直接使用第一個通道;
maskroi,
[180], #直方圖分成180份
[0.,180.])#表示直方圖中需要統計的各個畫素的值,[0.0, 180.0]表示直方圖能表示RGB所有顏色。
cv2.normalize(roi_hist,roi_hist,0,255,cv2.NORM_MINMAX) #線性歸一化
roi_hist_f=roi_hist.flatten() #value轉成list
plt.bar(range(0,len(roi_hist_f)*2,2), roi_hist_f) #顯示方便,這裡x軸*2,間隔也*2
plt.show()
trackObject=1
dst = cv2.calcBackProject([hsv],[0],roi_hist,[0,180],1)
cv2.imshow('dst',dst)
ret, track_window = cv2.meanShift(dst, track_window, term_crit)
x,y,w,h = track_window
frame = cv2.rectangle(frame, (x,y), (x+w,y+h), 255,2)
cv2.imshow('origin',frame)
if cv2.waitKey(10)==27:
break
cap.release()
cv2.destroyAllWindows()