(TensorFlow)卷積神經網路下與電腦”猜拳“
想法來源於學習DeepLearning.ai的卷積神經網路的作業,想利用手勢識別完成一個可以和電腦進行“猜拳遊戲”。
參考資料:
【1】TensorFlow和樹莓派完成的猜拳遊戲*
【2】手勢識別模型
【3】攝像頭手寫數字識別
第一步:前期準備
1.安裝opencv和tensorflow,本系統實現是在Anaconda中安裝了opencv的工具包加以實現的。
2.學會呼叫opencv庫中的攝像頭進行圖片拍照
3.儲存圖片到對應的資料夾下,便於模型預測。
第二步:收集資料
具體方法參考[1]進行實現,需要注意的是要對拍攝的圖片進行分類儲存訓練。主要是在tensorflow下呼叫openCV的相關操作。
(1)呼叫本地電腦的攝像頭VideoCapture(0)
(2)定義儲存圖片的相關函式,將拍攝的相關手勢圖片儲存。
第三步:資料處理
收集的圖片需要經過處理,才可以變成TensorFlow可以處理的資料,可以通過相關API處理將資料變成可以在TensorFlow直接使用的資料,同時將資料分為訓練集和測試集(80%:20%)
(1)TensorFlow 資料讀取方式——DataSet API
https://blog.csdn.net/kwame211/article/details/78579035
第四步:網路架構
對於此問題,可以看作是一個影象分類問題,利用卷積神經網路進行分類。為了降低實時識別的反應時間,本系統輸入圖片的尺寸保持在28*28畫素(MNIST類似),然後利用小的卷積神經網路加以實現,隱藏單元層數為4:,具體卷積的相關設定如下:
相關分析:卷積過程中採用填充保持影象大小不變。dropout防止過擬合
第五步:訓練模型
(1)匯入訓練資料,對資料做歸一化和扁平化處理,對於y標記結果做one-hot編碼。
歸一化輸入:作為加速訓練神經網路的一種方法,包括兩個步驟:零均值,歸一化方差。這樣可以便於訓練集和測試集通過相同的方差和均值進行定義的資料進行轉換。除此之外,歸一化可以使代價函式平均起來更對稱
one-hot編碼:區域性表徵,多少個概念就需要多少個編碼;與二進位制編碼(分散式表徵)不同的是,二進位制編碼可以用長度為n 的編碼表示2的n次方的概念。通過輸入one-hot編碼的值,然後學習過程讓輸出重現輸入,最後發現當某一個編碼輸入時,如果把隱藏層的單元啟用值輸出,啟用值就可以學到不同的二進位制編碼。
(3)建立佔位符,宣告變數,便於後期傳入訓練資料集
(4) 引數初始化,採用Xavier初始化(np.sqrt(1/n^[l-1]))作為權重的初始值,這是tanh函式初始化值。引數初始化值設定不好會造成梯度消失或者梯度爆炸
(5)Tensorflow的前向傳播。
在該模型中採用ReLu啟用函式進行求解
(6)計算代價函式
其中交叉熵計算,Tensorflow交叉熵計算函式輸入中的logits都不是softmax或sigmoid的輸出,而是softmax或sigmoid函式的輸入,因為它在函式內部進行sigmoid或softmax操作
(7)反向傳播和引數更新
在計算完代價函式之後,我們建立一個 “optimizer” 物件。在tf.session執行的時候,我們需要呼叫該物件,使得”optimizer” 物件在選定的優化演算法和學習率的基礎上進行優化操作。例如,梯度下降演算法優化器(gradient descent the optimizer):
(8)構建模型
將構造好的各個函式應用構建識別模型,進行迭代訓練獲得訓練的模型
(9)攝像頭進行實時識別
主要是通過opencv直接拍攝進行呼叫。
`
cap = cv2.VideoCapture(1)
while(1):
ret, frame = cap.read()
cv2.rectangle(frame,(270,200),(340,270),(0,0,255),2)
cv2.imshow(“capture”, frame)
roiImg = frame[200:270,270:340]
img = cv2.resize(roiImg,(28,28))
img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
np_img = img.astype(np.float32)
netoutput = network(np_img)
predictions = sess.run(netoutput,feed_dict={keep_prob: 0.5})
predicts=predictions.tolist() #tensorflow output is numpy.ndarray like [[0 0 0 0]]
label=predicts[0]
result=label.index(max(label))
print('result num:')
print(result)
if cv2.waitKey(1) & 0xFF == ord('q'):
break
cap.release()
cv2.destroyAllWindows()
直接進行測試。
(10)互動猜拳的實現
主要邏輯是自己設定的,當識別為剪刀的時候,呼叫預先設定好的拳頭,以此類推。
總結
自己的初次嘗試,還存在許多問題,需要後續改進!