1. 程式人生 > >KNN演算法——實現手寫數字識別(Sklearn實現)

KNN演算法——實現手寫數字識別(Sklearn實現)

KNN專案實戰——手寫數字識別

1、資料集介紹

需要識別的數字已經使用圖形處理軟體,處理成具有相同的色彩和大小:寬高是32畫素x32畫素的黑白影象。儘管採用本文格式儲存影象不能有效地利用記憶體空間,但是為了方便理解,我們將圖片轉換為文字格式。

數字的文字格式如下:

  

資料集下載:

這些文字格式儲存的數字的檔案命名也很有特點,格式為:數字的值_該數字的樣本序號,如下:

2、準備資料:將影象轉換為測試向量

將每個數字檔案中32*32的二進位制影象矩陣轉換為1*1024的向量,作為一個樣本輸入。

3、程式碼實現

import numpy as np
from os import listdir
from sklearn.neighbors import KNeighborsClassifier as KNN


"""
函式說明:將32x32的二進位制影象轉換為1x1024向量
"""
def img2vector(filename):
    #建立1x1024零向量
    returnVect = np.zeros((1, 1024))
    #開啟檔案
    fr = open(filename)
    #按行讀取
    for i in range(32):
        #讀一行資料
        lineStr = fr.readline()
        #每一行的前32個元素依次新增到returnVect中
        for j in range(32):
            returnVect[0, 32*i+j] = int(lineStr[j])
    #返回轉換後的1x1024向量
    return returnVect


"""
函式說明:手寫數字分類測試
"""
def handwritingClassTest():
    #訓練集的Labels
    hwLabels = []
    #返回trainingDigits目錄下的檔名
    trainingFileList = listdir('trainingDigits')
    #返回資料夾下檔案的個數
    m = len(trainingFileList)
    #初始化訓練的Mat矩陣,訓練集
    trainingMat = np.zeros((m, 1024))
    #從檔名中解析出訓練集的類別
    for i in range(m):
        #獲得檔案的名字
        fileNameStr = trainingFileList[i]
        #獲得分類的數字
        classNumber = int(fileNameStr.split('_')[0])
        #將獲得的類別新增到hwLabels中
        hwLabels.append(classNumber)
        #將每一個檔案的1x1024資料儲存到trainingMat矩陣中
        trainingMat[i,:] = img2vector('trainingDigits/%s' % (fileNameStr))
    #構建kNN分類器
    neigh =KNN(n_neighbors = 3, algorithm = 'auto')
    #擬合模型, trainingMat為訓練矩陣,hwLabels為對應的標籤
    neigh.fit(trainingMat, hwLabels)
    #返回testDigits目錄下的檔案列表
    testFileList = listdir('testDigits')
    #錯誤檢測計數
    errorCount = 0.0
    #測試資料的數量
    mTest = len(testFileList)
    #從檔案中解析出測試集的類別並進行 分類測試
    for i in range(mTest):
        #獲得檔案的名字
        fileNameStr = testFileList[i]
        #獲得分類的數字
        classNumber = int(fileNameStr.split('_')[0])
        #獲得測試集的1x1024向量,用於訓練
        vectorUnderTest = img2vector('testDigits/%s' % (fileNameStr))
        #獲得預測結果
        classifierResult = neigh.predict(vectorUnderTest)
        print("分類返回結果為%d\t真實結果為%d" % (classifierResult, classNumber))
        if(classifierResult != classNumber):
            errorCount += 1.0
    print("總共錯了%d個數據\n錯誤率為%f%%" % (errorCount, errorCount/mTest * 100))


"""
函式說明:main函式
"""
if __name__=='__main__':
    handwritingClassTest()

結果輸出: