matlab+BP神經網路實現手寫體數字識別
個人部落格文章連結:http://www.huqj.top/article?id=168
接著上一篇所說的 BP神經網路,現在用它來實現一個手寫體數字的識別程式,訓練素材來自吳恩達機器學習課程,我把打包好上傳到了網盤上:
1 2 |
|
訓練資料一共有5000條,10個數字(0~9,為了和matlab適配,0在這裡統一用10表示),每個數字各500個手寫體圖片,畫素統一處理為20*20,其中pics中是5000張圖片, data是一個.mat檔案,可以直接載入到matlab中,包含兩個變數X(5000x400 double矩陣)和y(5000x1 int矩陣)。
可以看到,訓練資料的輸入是400個畫素點的灰度值,雖然圖片是20x20的,但是為了處理方便將其轉換成1x400的輸入,可以用matlab中的reshape函式進行轉換。而對於輸出而言,這可以看作一個多元分類問題,一共有10種分類,所以輸出可以轉換成一個10維向量。定義好輸入輸出格式之後,再考慮下神經網路的架構,平衡效能和效率,最終選擇的架構是一個25元隱含層的BP網路。另外,為了衡量最終的模型效果,我們需要從5000個數據中抽取一部分作為測試集,這裡我每個數字選了10條資料作為測試資料集,不過理論上訓練集和測試集的比例可以達到 7:3
利用之前編寫好的BP網路訓練函式和一些附加函式(sigmoid,預測函式等),最終的手寫體識別訓練程式如下:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 |
|
因為一開始不知道要迭代多少次,所以設定成了一個迴圈的結構,可以根據訓練誤差決定繼續訓練或者結束訓練,然後將模型權重儲存下來,下次可以接著訓練。
如果想要在matlab中畫出圖片,可以將這一行的註釋去掉:
1 |
|
然後繪出所有測試集的圖片如下:
執行程式反覆迭代上萬次之後,在測試集上的準確率穩定在92%左右,這可能也是受模型和資料集的限制。而且這個模型只是用於黑底白字的圖片,用我自己的手寫數字測試效果並不太好(可能與我的圖片處理有關),最高只能達到 7/10 的準確率,後續會持續考慮改進模型。
完整程式碼下載地址: https://download.csdn.net/download/qq_32216775/10897369