# 機器學習實戰:knn手寫數字 (Machine Learning in Action: kNN handwritten digit recognition)
# 阿新 • 發佈:2019-01-11 (posted by 阿新 on 2019-01-11)
"""
@author: lishihang
@software: PyCharm
@file: handwritten.py
@time: 2018/11/26 16:18
"""
import numpy as np
import matplotlib.pyplot as plt
import os
import glob
from tqdm import tqdm
def img2vec(filename):
    """
    Read a text-encoded digit image and flatten it into a 1-D feature list.

    Every non-blank line of the file is treated as one image row and every
    character on that line as one integer pixel value.

    :param filename: path of the text file to read
    :return: flat list of ints, rows concatenated top to bottom
    :raises ValueError: if a character cannot be converted to an int
    """
    # The context manager closes the handle even if parsing fails.
    with open(filename) as f:
        # Blank lines (e.g. a trailing empty line) carry no pixels; skipping
        # them keeps the row matrix rectangular.
        rows = [list(map(int, line.strip())) for line in f if line.strip()]
    data = np.array(rows)
    return data.reshape(-1).tolist()
def readDir(dir):
    """
    Load every ``*.txt`` sample found directly inside a directory.

    The label of a sample is the part of its file name before the first
    underscore (e.g. ``8_80.txt`` -> ``"8"``); its features are produced by
    ``img2vec``.

    :param dir: directory to scan; a trailing path separator is optional
    :return: tuple ``(xs, labels)`` of numpy arrays where row ``i`` of ``xs``
             belongs to ``labels[i]``; both are empty if no file matches
    """
    # os.path.join copes with or without a trailing separator, so no assert
    # (which would vanish under ``python -O``) is needed to validate the path.
    # Sorting makes the sample order reproducible across platforms.
    fs = sorted(glob.glob(os.path.join(dir, "*.txt")))
    labels = []
    xs = []
    for f in fs:
        labels.append(os.path.basename(f).split('_')[0])
        xs.append(img2vec(f))
    return np.array(xs), np.array(labels)
def knnClaffify(testItem, trainX, trainY, k):
    """
    Classify a single sample with the k-nearest-neighbours rule.

    The Euclidean distance from ``testItem`` to every row of ``trainX`` is
    computed, the ``k`` closest rows vote with their labels, and the label
    with the most votes wins.

    :param testItem: feature vector of the sample to classify
    :param trainX: training features, one row per sample
    :param trainY: training labels, aligned with the rows of ``trainX``
    :param k: number of neighbours that vote; values larger than the
              training set simply use every training sample
    :return: the predicted label (an element of ``trainY``)
    :raises ValueError: if ``k`` is smaller than 1 or there is no training data
    """
    if k < 1:
        raise ValueError("k must be a positive integer")
    distances = np.sqrt(np.sum((trainX - testItem) ** 2, axis=1))
    # A stable sort keeps equally distant samples in training order, so the
    # result is deterministic.
    ind = np.argsort(distances, kind="stable")
    classCount = {}
    for i in ind[:k]:
        vote = trainY[i]
        classCount[vote] = classCount.get(vote, 0) + 1
    if not classCount:
        raise ValueError("training set is empty")
    # Pick the label with the highest vote COUNT (not the smallest label).
    # max() returns the first maximum in insertion order, so a tie goes to the
    # label whose neighbour was seen first, i.e. the closest one.
    return max(classCount, key=classCount.get)
def knnTest(trainDir="trainingDigits/", testDir="testDigits/", k=5):
    """
    Evaluate the kNN classifier and print its accuracy.

    Training and test samples are loaded with ``readDir``; every test sample
    is classified with ``knnClaffify`` and compared with its true label.
    The set sizes and the final accuracy are printed to stdout.

    :param trainDir: directory holding the training samples
    :param testDir: directory holding the test samples
    :param k: number of neighbours used for each classification
    :return: None
    """
    x, y = readDir(trainDir)
    x_test, y_test = readDir(testDir)
    total = len(y_test)
    print("訓練集:{},測試集:{}".format(len(y), total))
    trueCount = 0
    # Materialised into a list so tqdm knows the length and can show progress.
    for x_item, y_item in tqdm(list(zip(x_test, y_test))):
        result = knnClaffify(x_item, x, y, k=k)
        trueCount += (y_item == result)
    # An empty test set would otherwise raise ZeroDivisionError.
    accuracy = trueCount / total if total else 0.0
    print("正確率:{}({}/{})".format(accuracy, trueCount, total))
if __name__ == "__main__":
    # Script entry point: run the full train/test evaluation.
    # For a quick single-file check, call img2vec("testDigits/8_80.txt").
    knnTest()