1. 程式人生 > >機器學習實戰(4)—— kNN實戰手寫識別系統

機器學習實戰(4)—— kNN實戰手寫識別系統

文章目錄

我:終於到週末了,可以休息一下了!!!來幾把LOL!!!

(叮鈴…叮鈴…叮鈴…)

我:喂,老闆啊?怎麼啦

老闆:小韓啊,在家休息嗎?

我:是啊。

老闆:別休息啦,來加個班,用上次你寫的kNN,做一個手寫識別系統,訓練集和測試集我都發你郵箱了!週日晚上給我!

我:(What???大週末的,你讓我加班,老子不幹了!)行,保證寫出來!

行了行了,週末不休息了,開工!

這次我們要構建一個手寫識別系統,為了簡單,我們就只識別0-9。需要識別的數字已經用圖形處理軟體,處理成具有相同的色彩和大小:寬高是32畫素×32畫素的黑白影象。儘管採用文字格式儲存影象不能有效地利用記憶體空間,但是為了方便我們的理解,我們還是將影象轉換為文字格式。示例如下:

然後,我們來看一下,使用kNN構造手寫識別系統的步驟:

  1. 收集資料:提供文字檔案。
  2. 準備資料:編寫函式classify0(),將影象格式轉換為分類器使用的list格式。
  3. 分析資料:在Python命令提示符中檢查資料,確保它符合要求。
  4. 訓練演算法:此步驟不適用於k-近鄰演算法。
  5. 測試演算法:編寫函式使用提供的部分資料集作為測試樣本,測試樣本與非測試樣本的區別在於測試樣本是已經完成分類的資料,如果預測分類與實際類別不同,則標記為一個錯誤。
  6. 使用演算法:本例沒有完成此步驟,若你感興趣可以構建完整的應用程式,從影象中提取數字,並完成數字識別,美國的郵件分揀系統就是一個實際執行的類似系統。

2.3.1 準備資料:將影象轉換為測試向量

老闆給的訓練集在目錄trainingDigits中,其中包含了大約2000個例子,每個數字大概有200個樣本。測試集在目錄testDigits中,其中大約900個測試資料。截圖如下:

每個文字檔名稱下劃線前的數字代表這個文字檔案所代表數字。比如說0_8.txt代表的是數字0的第9個樣本(從0開始計數)。

為了使用我們先前編寫好的分類器,我們必須將影象格式化處理為一個向量。我們將一個32×32的二進位制影象矩陣轉換為1×1024的向量。

好了,程式碼走起來!我們繼續在kNN.py中編寫函式img2vector,程式碼如下:

def
img2vector(filename): returnVect = zeros((1, 1024)) fr = open(filename) for i in range(32): lineStr = fr.readline() for j in range(32): returnVect[0, 32 * i + j] = int(lineStr[j]) return returnVect

程式碼很簡單,就是將原來32×32轉換成1×1024,這裡我也就不多說什麼了。大家可以自己去測試一下效果。

2.3.2 使用k-近鄰演算法識別手寫數字

上一節我們已經把資料處理成我們想要的格式了,那麼接下來我們就可以將這些資料丟到分類器裡了。直接來看程式碼:

def handwritingClassTest():
    # 1.初始化我們所需要的資料
    hwLabels = []
    trainingFileList = os.listdir('trainingDigits')  # 這裡需要我們提前匯入os模組,listdir可以列出給定目錄下的檔名
    m = len(trainingFileList)  # 獲得訓練樣本數目
    trainingMat = zeros((m, 1024))  # 構造m×1024的矩陣
    
    # 2.迴圈遍歷訓練集中的每個檔案,生成每個數字的向量資訊,儲存在trainingMat中
    for i in range(m):
        fileNameStr = trainingFileList[i]  # 獲得檔名
        fileStr = fileNameStr.split('.')[0]
        classNumStr = int(fileStr.split('_')[0])  # 獲得該檔案所代表的數字
        hwLabels.append(classNumStr)  # 將檔案所代表的數字其存放在類別標籤中
        trainingMat[i, :] = img2vector('trainingDigits/%s' % fileNameStr)  # 資料轉換
    
    # 3.遍歷測試資料資料夾,使用kNN進行測試。
    testFileList = os.listdir('testDigits')
    errorCount = 0.0
    mTest = len(testFileList)  # 獲得測試樣本數目
    for i in range(mTest):
        fileNameStr = testFileList[i]  # 獲得檔名
        fileStr = fileNameStr.split('.')[0]
        classNumStr = int(fileStr.split('_')[0])  # 獲得該檔案所代表的數字
        vectorUnderTest = img2vector('testDigits/%s' % fileNameStr)
        classifierResult = classify0(vectorUnderTest, trainingMat, hwLabels, 3)  # 分類
        print('the classifier came back with: %d, the real answer is: %d' % (classifierResult, classNumStr))
        if classifierResult != classNumStr:
            errorCount += 1.0

    print('\nthe total number of errors is: %d' % errorCount)
    print('\nthe total error rate is: %f' % (errorCount / float(mTest)))

上面程式碼也不難,每一步的具體含義我都給大家寫在註釋中了,所以我也就不多說了。

依賴於機器速度,載入資料集可能要花費很長時間,然後函式開始依次測試每個檔案,我們直接來看輸出的結果:

我們使用k-近鄰演算法識別手寫數字資料集,錯誤率為1.2%。

改變變數k的值、修改函式handwritingClassTest隨機選取訓練樣本、改變訓練樣本的數目,都會對k-近鄰演算法的錯誤率產生影響,感興趣的話可以改變這些變數值,觀察錯誤率的變化。

但是,我們需要注意的是,實際使用這個演算法時,演算法的執行效率並不高。原因如下:

  1. 演算法需要為每個測試向量做2000次距離計算,每個距離計算包括了1024個維度浮點運算,總計要執行900次,
  2. 此外,我們還需要為測試向量準備2MB的儲存空間。

2.4 小結

kNN的理論、實戰,我們就講到這裡了,下面我們來總結一下:

  1. k-近鄰演算法是分類資料最簡單最有效的演算法,我們通過兩次實戰講述瞭如何使用k-近鄰演算法構造分類器。
  2. k-近鄰演算法是基於例項的學習,使用演算法時我們必須有接近實際資料的訓練樣本資料。
  3. k-近鄰演算法必須儲存全部資料集,如果訓練資料集的很大,必須使用大量的儲存空間。此外, 由於必須對資料集中的每個資料計算距離值,實際使用時可能非常耗時。
  4. k-近鄰演算法的另一個缺陷是它無法給出任何資料的基礎結構資訊,因此我們也無法知曉平均例項樣本和典型例項樣本具有什麼特徵。

好了,k-近鄰演算法我們就講到這裡,因為是最基礎的,所以用了比較多的篇幅,希望大家能夠慢慢看完,對機器學習先有一個感性的認識。

機器學習的路還很長,加油,沖沖衝!!!


最後,還是熟悉的配方!

歡迎大家關注我的公眾號,有什麼問題也可以給我留言哦!