K-近鄰演算法之手寫數字識別系統
阿新 • 發佈:2018-12-31
定義將影像轉換為向量的函式
# 匯入程式所需要的模組
import numpy as np
import operator
from os import listdir
讀取檔案
def img2vector(filename):
    """Read a 32x32 text-encoded image file and flatten it into a 1x1024 vector.

    The file must contain at least 32 lines, each beginning with 32
    single-digit characters. Row ``i`` of the image is stored at columns
    ``32*i`` .. ``32*i + 31`` of the returned vector.

    Args:
        filename: Path of the text image file.

    Returns:
        numpy.ndarray of shape (1, 1024) holding the pixel values as floats.

    Raises:
        OSError: if the file cannot be opened.
        IndexError: if a line (or the file) is shorter than expected.
        ValueError: if one of the first 32 characters of a line is not a digit.
    """
    returnVect = np.zeros((1, 1024))  # 32x32 image flattened row by row
    # `with` guarantees the handle is closed even if parsing raises.
    with open(filename) as fr:
        for i in range(32):
            lineStr = fr.readline()
            for j in range(32):
                returnVect[0, 32 * i + j] = int(lineStr[j])
    return returnVect
定義 k 近鄰演算法
def classify0(inX, dataSet, labels, k):
    """Classify ``inX`` by a majority vote of its ``k`` nearest training samples.

    Args:
        inX: The sample to classify, laid out like one row of ``dataSet``.
        dataSet: 2-D numpy array of training samples, one sample per row.
        labels: Sequence of class labels; ``labels[r]`` is the label of row ``r``.
        k: Number of nearest neighbours that get a vote.

    Returns:
        The label with the most votes among the ``k`` nearest rows, by
        Euclidean distance. If several labels tie, the winner is the one whose
        first vote came from the nearest neighbour.
    """
    rowCount = dataSet.shape[0]
    # Repeat inX once per training row, then take the row-wise Euclidean norm.
    offsets = np.tile(inX, (rowCount, 1)) - dataSet
    distances = (offsets ** 2).sum(axis=1) ** 0.5
    nearestFirst = distances.argsort()  # row indices, closest first

    votes = {}
    for rank in range(k):
        neighbourLabel = labels[nearestFirst[rank]]
        votes[neighbourLabel] = votes.get(neighbourLabel, 0) + 1

    # A stable sort on the negated count keeps first-seen order among ties.
    ranked = sorted(votes.items(), key=lambda item: -item[1])
    return ranked[0][0]
定義手寫數字識別系統函式
def _label_from_filename(fileNameStr):
    """Return the digit class encoded in a sample file name.

    File names look like ``"<digit>_<index>.txt"``, e.g. ``"0_13.txt"``:
    the extension is dropped (``"0_13"``) and the part before the first
    underscore is parsed as the integer label (``0``).

    Raises:
        ValueError: if that part is not an integer.
    """
    fileStr = fileNameStr.split('.')[0]  # "0_13.txt" -> "0_13"
    return int(fileStr.split('_')[0])    # "0_13" -> 0


def handwritingClassTest(trainingDir='./digits/trainingDigits',
                         testDir='./digits/testDigits', k=3):
    """Evaluate the k-NN digit classifier on a directory of test images.

    Every ``*.txt`` file in ``trainingDir`` is loaded with ``img2vector`` and
    labelled from its file name. Every ``*.txt`` file in ``testDir`` is then
    classified with ``classify0`` against that training set. One line is
    printed per test sample, followed by the total error count and error rate.

    Args:
        trainingDir: Directory holding the training images.
        testDir: Directory holding the test images.
        k: Number of neighbours used by ``classify0``.

    Returns:
        The error rate (errors / number of test samples) as a float, or
        ``None`` when ``testDir`` contains no ``*.txt`` files.

    Raises:
        ValueError: if there are test samples but no training samples, or if
            a sample file name does not start with an integer label.
    """
    import os  # module level only imports `listdir` from os

    # Only *.txt files are samples. Other entries (e.g. hidden files) would
    # break label parsing. Sorting makes the printed output reproducible,
    # because os.listdir returns entries in arbitrary order.
    trainingFileList = sorted(
        name for name in listdir(trainingDir) if name.endswith('.txt'))
    testFileList = sorted(
        name for name in listdir(testDir) if name.endswith('.txt'))

    if not testFileList:
        # Avoid dividing by zero when computing the error rate.
        print("\nno test samples found in %s" % testDir)
        return None
    if not trainingFileList:
        raise ValueError("no training samples found in %s" % trainingDir)

    # Training samples: one 1024-pixel row per image; label from the file name.
    hwLabels = [_label_from_filename(name) for name in trainingFileList]
    trainingMat = np.zeros((len(trainingFileList), 1024))
    for i, fileNameStr in enumerate(trainingFileList):
        trainingMat[i, :] = img2vector(os.path.join(trainingDir, fileNameStr))

    # Test samples: classify each one and count the mismatches.
    errorCount = 0
    for fileNameStr in testFileList:
        classNumStr = _label_from_filename(fileNameStr)
        vectorUnderTest = img2vector(os.path.join(testDir, fileNameStr))
        classifierResult = classify0(vectorUnderTest, trainingMat, hwLabels, k)
        print("the classifier came back with: %d, the real answer is: %d"
              % (classifierResult, classNumStr))
        if classifierResult != classNumStr:
            errorCount += 1

    errorRate = errorCount / float(len(testFileList))
    print("\nthe total number of errors is: %d" % errorCount)
    print("\nthe total error rate is: %f" % errorRate)
    return errorRate
執行範例函式
# Demo: vectorize one test image. The path is relative to the directory that
# contains the `digits` dataset, matching the layout handwritingClassTest
# reads from. Guarded so importing this module does not touch the filesystem.
if __name__ == '__main__':
    print(img2vector('./digits/testDigits/0_13.txt'))
結果為:
array([[0., 0., 0., ..., 0., 0., 0.]])
主函式為:
# Run the full evaluation only when executed as a script, not on import.
if __name__ == '__main__':
    handwritingClassTest()
結果如下: