Python實現KNN演算法手寫識別數字
阿新 • • 發佈:2019-02-15
本文實現用KNN演算法實現手寫識別數字功能。
語言:Python
訓練材料:手寫數字素材32*32畫素
from numpy import *
import os
from os import listdir
import operator
#將檔案32*32轉成1*1024
def img2vector(filename):
vect=zeros((1,1024))
f=open(filename)
for i in range(32):
line=f.readline()
for j in range(32):
vect[0 ,32*i+j]=int(line[j])
return vect
def dict2list(dic:dict):
#''' 將字典轉化為列表 '''
keys = dic.keys()
vals = dic.values()
lst = [(key, val) for key, val in zip(keys, vals)]#zip是一個可迭代物件
return lst
#inputvector:輸入的用於測試的向量
#trainDataSet:訓練的樣本集
#labels:標籤
#k:k鄰近的個數
def knntest(inputvector,trainDataSet,labels,k) :
datasetsize=trainDataSet.shape[0]
#tile(a,[2,3]) ([a a a],[a,a,a])用第一個引數來構造
#這裡用輸入向量來構造一個1024行 1列的矩陣,剛好和訓練矩陣同樣大小
diffmat=tile(inputvector,(datasetsize,1))-trainDataSet
#求平方和
#每個元素都平方
sqdiffmat=diffmat**2
#按行求和
sqdistance=sqdiffmat.sum(axis=1)
#平方根,得到的是一個一維的矩陣
distance=sqdistance**0.5
#按照從低到高排序
#argsort函式排列後得到的是按下標進行排列的矩陣,
#在原先distance中的下標按距離最近排列 argsort函式返回的是陣列值從小到大的索引值
sortdistance=distance.argsort()
classcout={}#用來儲存key(標籤)value(標籤出現的次數,選取次數最大的前幾個數,找到其標籤)
#依次取出最近的樣本資料
for i in range(k):
#記樣本的類別
votelabel=labels[sortdistance[i]]
#統計每個標籤的次數
classcout[votelabel]=classcout.get(votelabel,0)+1#獲取votelabel鍵對應的值,無返回預設
#print("*************")
#print(classcout)
#classcout.iteritems()在Python3中取消了,key=lambda x:x[0](按第0個元素排序)字典排序,按照value來排序,返回鍵
sortclasscount=sorted(dict2list(classcout),key=operator.itemgetter(1),reverse=True)
#返回出現頻次最高的類別
return sortclasscount[0][0]
#手寫識別
def handwritingClassTest():
print(os.getcwd())
#將訓練資料儲存到一個矩陣中1024維,並存儲對應的標籤
handlabel=[]
trainName=listdir(r'digits\trainingDigits')
trainNum=len(trainName)
trainNumpy = zeros((trainNum,1024))
#print("trainNum=%d"%trainNum)
#對檔名進行分析,訓練文字對應的標籤
for i in range(trainNum):
filename=trainName[i]#檔名
filestr=filename.split('.')[0]#不帶字尾的檔名
filelabel=int(filestr.split('_')[0])#檔案的標籤
#將標籤新增至handlabel中
handlabel.append(filelabel)
trainNumpy[i,:]=img2vector(r'digits\trainingDigits\%s'%filename)#轉成1024
#print(handlabel[:20])
testfilelist=listdir(r'digits\testDigits')
errornum=0
testnum=len(testfilelist)
errfile=[]
#將每一個測試樣本放入訓練集中使用KNN進行測試
for i in range(testnum):
testfilename=testfilelist[i]
testfilestr=testfilename.split('.')[0]
testfilelabel=int(testfilestr.split('_')[0])#實際的數字標籤
#將測試樣本1024
testvector=img2vector(r'digits\testDigits\%s'%testfilename)
#進行測試
#print("-----------")
result=knntest(testvector,trainNumpy,handlabel,3)
print("test value is %d, real value is %d"%(result,testfilelabel))
if(result!=testfilelabel):
errornum+=1
errfile.append(testfilename)
print("the num of error is %d"%errornum)
print("the right rate of test is %f "%(1-errornum/float(testnum)))
print("the error of file are ")
count=0
for i in range(len(errfile)):
if(count==9):
print()
print(errfile[i]+' ',end="")
count+=1
def main():
#path=os.getcwd()
handwritingClassTest()
if __name__=='__main__':
main();