機器學習實戰(4)—— kNN實戰手寫識別系統
文章目錄
我:終於到週末了,可以休息一下了!!!來幾把LOL!!!
(叮鈴…叮鈴…叮鈴…)
我:喂,老闆啊?怎麼啦
老闆:小韓啊,在家休息嗎?
我:是啊。
老闆:別休息啦,來加個班,用上次你寫的kNN,做一個手寫識別系統,訓練集和測試集我都發你郵箱了!週日晚上給我!
我:(What???大週末的,你讓我加班,老子不幹了!)行,保證寫出來!
行了行了,週末不休息了,開工!
這次我們要構建一個手寫識別系統,為了簡單,我們就只識別0-9。需要識別的數字已經用圖形處理軟體,處理成具有相同的色彩和大小:寬高是32畫素×32畫素的黑白影象。儘管採用文字格式儲存影象不能有效地利用記憶體空間,但是為了方便我們的理解,我們還是將影象轉換為文字格式。示例如下:
然後,我們來看一下,使用kNN構造手寫識別系統的步驟:
- 收集資料:提供文字檔案。
- 準備資料:編寫函式classify0(),將影象格式轉換為分類器使用的list格式。
- 分析資料:在Python命令提示符中檢查資料,確保它符合要求。
- 訓練演算法:此步驟不適用於k-近鄰演算法。
- 測試演算法:編寫函式使用提供的部分資料集作為測試樣本,測試樣本與非測試樣本的區別在於測試樣本是已經完成分類的資料,如果預測分類與實際類別不同,則標記為一個錯誤。
- 使用演算法:本例沒有完成此步驟,若你感興趣可以構建完整的應用程式,從影象中提取數字,並完成數字識別,美國的郵件分揀系統就是一個實際執行的類似系統。
2.3.1 準備資料:將影象轉換為測試向量
老闆給的訓練集在目錄trainingDigits中,其中包含了大約2000個例子,每個數字大概有200個樣本。測試集在目錄testDigits中,其中大約900個測試資料。截圖如下:
每個文字檔名稱下劃線前的數字代表這個文字檔案所代表數字。比如說0_8.txt代表的是數字0的第9個樣本(從0開始計數)。
為了使用我們先前編寫好的分類器,我們必須將影象格式化處理為一個向量。我們將一個32×32的二進位制影象矩陣轉換為1×1024的向量。
好了,程式碼走起來!我們繼續在kNN.py中編寫函式img2vector,程式碼如下:
def img2vector(filename):
returnVect = zeros((1, 1024))
fr = open(filename)
for i in range(32):
lineStr = fr.readline()
for j in range(32):
returnVect[0, 32 * i + j] = int(lineStr[j])
return returnVect
程式碼很簡單,就是將原來32×32轉換成1×1024,這裡我也就不多說什麼了。大家可以自己去測試一下效果。
2.3.2 使用k-近鄰演算法識別手寫數字
上一節我們已經把資料處理成我們想要的格式了,那麼接下來我們就可以將這些資料丟到分類器裡了。直接來看程式碼:
def handwritingClassTest():
# 1.初始化我們所需要的資料
hwLabels = []
trainingFileList = os.listdir('trainingDigits') # 這裡需要我們提前匯入os模組,listdir可以列出給定目錄下的檔名
m = len(trainingFileList) # 獲得訓練樣本數目
trainingMat = zeros((m, 1024)) # 構造m×1024的矩陣
# 2.迴圈遍歷訓練集中的每個檔案,生成每個數字的向量資訊,儲存在trainingMat中
for i in range(m):
fileNameStr = trainingFileList[i] # 獲得檔名
fileStr = fileNameStr.split('.')[0]
classNumStr = int(fileStr.split('_')[0]) # 獲得該檔案所代表的數字
hwLabels.append(classNumStr) # 將檔案所代表的數字其存放在類別標籤中
trainingMat[i, :] = img2vector('trainingDigits/%s' % fileNameStr) # 資料轉換
# 3.遍歷測試資料資料夾,使用kNN進行測試。
testFileList = os.listdir('testDigits')
errorCount = 0.0
mTest = len(testFileList) # 獲得測試樣本數目
for i in range(mTest):
fileNameStr = testFileList[i] # 獲得檔名
fileStr = fileNameStr.split('.')[0]
classNumStr = int(fileStr.split('_')[0]) # 獲得該檔案所代表的數字
vectorUnderTest = img2vector('testDigits/%s' % fileNameStr)
classifierResult = classify0(vectorUnderTest, trainingMat, hwLabels, 3) # 分類
print('the classifier came back with: %d, the real answer is: %d' % (classifierResult, classNumStr))
if classifierResult != classNumStr:
errorCount += 1.0
print('\nthe total number of errors is: %d' % errorCount)
print('\nthe total error rate is: %f' % (errorCount / float(mTest)))
上面程式碼也不難,每一步的具體含義我都給大家寫在註釋中了,所以我也就不多說了。
依賴於機器速度,載入資料集可能要花費很長時間,然後函式開始依次測試每個檔案,我們直接來看輸出的結果:
我們使用k-近鄰演算法識別手寫數字資料集,錯誤率為1.2%。
改變變數k的值、修改函式handwritingClassTest隨機選取訓練樣本、改變訓練樣本的數目,都會對k-近鄰演算法的錯誤率產生影響,感興趣的話可以改變這些變數值,觀察錯誤率的變化。
但是,我們需要注意的是,實際使用這個演算法時,演算法的執行效率並不高。原因如下:
- 演算法需要為每個測試向量做2000次距離計算,每個距離計算包括了1024個維度浮點運算,總計要執行900次,
- 此外,我們還需要為測試向量準備2MB的儲存空間。
2.4 小結
kNN的理論、實戰,我們就講到這裡了,下面我們來總結一下:
- k-近鄰演算法是分類資料最簡單最有效的演算法,我們通過兩次實戰講述瞭如何使用k-近鄰演算法構造分類器。
- k-近鄰演算法是基於例項的學習,使用演算法時我們必須有接近實際資料的訓練樣本資料。
- k-近鄰演算法必須儲存全部資料集,如果訓練資料集的很大,必須使用大量的儲存空間。此外, 由於必須對資料集中的每個資料計算距離值,實際使用時可能非常耗時。
- k-近鄰演算法的另一個缺陷是它無法給出任何資料的基礎結構資訊,因此我們也無法知曉平均例項樣本和典型例項樣本具有什麼特徵。
好了,k-近鄰演算法我們就講到這裡,因為是最基礎的,所以用了比較多的篇幅,希望大家能夠慢慢看完,對機器學習先有一個感性的認識。
機器學習的路還很長,加油,沖沖衝!!!
最後,還是熟悉的配方!
歡迎大家關注我的公眾號,有什麼問題也可以給我留言哦!