1. 程式人生 > >基於k近鄰(KNN)的手寫數字識別

基於k近鄰(KNN)的手寫數字識別

作者:faaronzheng 轉載請註明出處!

最近再看Machine Learning in Action. k近鄰演算法這一章節提供了不少例子,本著Talk is cheap的原則,我們用手寫數字識別來實際測試一下。 簡單的介紹一下k近鄰演算法(KNN):給定測試樣本,基於某種距離度量找出訓練集中與其最靠近的k個訓練樣本,然後基於這k個“鄰居”的資訊來進行預測。如下圖所示:


x為測試樣本,小黑點是一類樣本,小紅點是另一類樣本。在測試樣本x的周圍畫一個圈,這個圈就是依據某種距離度量畫出的,可以看到我們選擇的是5近鄰。現在我們要做出一個預測,就是這個測試樣本x是屬於小黑點那一類還是小紅點那一類呢?很簡單,我們只要看看選中的近鄰中哪一類樣本多就把這類樣本的標籤賦給測試樣本就可以了。圖中自然就是小黑點,所以我們預測x是小黑點。

正文:

第一步:準備實驗資料。Machine Learning in Action書中的資料使用的是“手寫數字資料集的光學識別”一文中的資料。具體可以參考書中的相關介紹。所有的資料是以Txt形式儲存的,由32行32列的0/1元素組成。下圖就是一個手寫數字0的儲存資料。可以看出,數字所在的位置用1表示,空白的用0表示。

除此之外,為了能識別自己手寫的數字,我們在原來實驗的基礎上新增畫板的功能,使其能採集自己手寫的數字並按照相同的格式儲存下來。如下圖所示,當點選CustomizeTestData後會出現一個畫板,當我們在畫板上寫上數字後,按下ESC鍵儲存圖片並退出,接下來將儲存的圖片處理成我們想要的格式,就可以用演算法對其進行預測了。畫板的實現使用了pygame。


下面是畫板功能的具體實現:

import pygame
from pygame.locals import *
import math
from sys import exit
#向sys模組借一個exit函式用來退出程式
pygame.init()
#初始化pygame,為使用硬體做準備
 
class Brush():
 def __init__(self, screen):
  self.screen = screen
  self.color = (0, 0, 0)
  self.size = 4
  self.drawing = False
  self.last_pos = None
  self.space = 1
  # if style is True, normal solid brush
  # if style is False, png brush
  self.style = False
  # load brush style png
  self.brush = pygame.image.load("brush.png").convert_alpha()
  # set the current brush depends on size
  self.brush_now = self.brush.subsurface((0,0), (1, 1))
 
 def start_draw(self, pos):
  self.drawing = True
  self.last_pos = pos
 def end_draw(self):
  self.drawing = False
 
 def set_brush_style(self, style):
  print "* set brush style to", style
  self.style = style
 def get_brush_style(self):
  return self.style
 
 def get_current_brush(self):
  return self.brush_now
 
 def set_size(self, size):
  if size < 0.5: size = 0.5
  elif size > 32: size = 32
  print "* set brush size to", size
  self.size = size
  self.brush_now = self.brush.subsurface((0,0), (size*2, size*2))
 def get_size(self):
  return self.size
 
 def set_color(self, color):
  self.color = color
  for i in xrange(self.brush.get_width()):
   for j in xrange(self.brush.get_height()):
    self.brush.set_at((i, j),
      color + (self.brush.get_at((i, j)).a,))
 def get_color(self):
  return self.color
 
 def draw(self, pos):
  if self.drawing:
   for p in self._get_points(pos):
    # draw eveypoint between them
    if self.style == False:
     pygame.draw.circle(self.screen, self.color, p, self.size)
    else:
     self.screen.blit(self.brush_now, p)
 
   self.last_pos = pos
 
 def _get_points(self, pos):
  """ Get all points between last_point ~ now_point. """
  points = [ (self.last_pos[0], self.last_pos[1]) ]
  len_x = pos[0] - self.last_pos[0]
  len_y = pos[1] - self.last_pos[1]
  length = math.sqrt(len_x ** 2 + len_y ** 2)
  step_x = len_x / length
  step_y = len_y / length
  for i in xrange(int(length)):
   points.append(
     (points[-1][0] + step_x, points[-1][1] + step_y))
  points = map(lambda x:(int(0.5+x[0]), int(0.5+x[1])), points)
  # return light-weight, uniq integer point list
  return list(set(points))
 
class Menu():
 def __init__(self, screen):
  self.screen = screen
  self.brush = None

 def set_brush(self, brush):
  self.brush = brush

 
class Painter():
 def __init__(self):
  self.screen = pygame.display.set_mode((100, 100))
 # self.menu = pygame.display.set_mode((80, 600))
  pygame.display.set_caption("Painter")
  self.clock = pygame.time.Clock()
  self.brush = Brush(self.screen)
  self.menu = Menu(self.screen)
  self.menu.set_brush(self.brush)
 
 def run(self):
  self.screen.fill((255, 255, 255))
  while True:
   # max fps limit
   self.clock.tick(30)
   for event in pygame.event.get():
    if event.type == QUIT:
        pygame.quit()
     #   break
    elif event.type == KEYDOWN:
     # press esc to clear screen
     if event.key == K_ESCAPE:
      fname = "test.png"
      pygame.image.save(self.screen, fname)    
      pygame.quit()
      #break
    elif event.type == MOUSEBUTTONDOWN:
     # <= 74, coarse judge here can save much time
     if ((event.pos)[0] <= 74 and
       self.menu.click_button(event.pos)):
      # if not click on a functional button, do drawing
      pass
     else:
      self.brush.start_draw(event.pos)
    elif event.type == MOUSEMOTION:
     self.brush.draw(event.pos)
    elif event.type == MOUSEBUTTONUP:
     self.brush.end_draw()
    self.menu.draw()
    pygame.display.update()

KNN演算法--KNN的關鍵在我看來是距離度量的選擇。不同的距離度量會對最終的結果產生比較大的影響。首先將手寫數字變化為一個一維的向量,通過計算測試樣例(向量)和每個訓練樣本(向量)之間的距離然後進行排序。最後選最近的k個進行投票產生對測試樣例的預測。

import pygame
from numpy import *
import operator
from os import listdir
from Board import *
import Tkinter
import tkFileDialog
import tkMessageBox
import Image  
from KNN import dot
pygame.init()


def classify0(inX, dataSet, labels, k):           #k控制選取最近的k個近鄰然後投票
    dataSetSize = dataSet.shape[0]
    #計算歐式距離(其實比較的是兩個向量之間的距離)
    diffMat = tile(inX, (dataSetSize,1)) - dataSet
    sqDiffMat = diffMat**2
    sqDistances = sqDiffMat.sum(axis=1)
    distances = sqDistances**0.5
    sortedDistIndicies = distances.argsort()     
    classCount={}          
    #投票
    for i in range(k):
        voteIlabel = labels[sortedDistIndicies[i]]
        classCount[voteIlabel] = classCount.get(voteIlabel,0) + 1
    sortedClassCount = sorted(classCount.iteritems(), key=operator.itemgetter(1), reverse=True)
    return sortedClassCount[0][0]
def classify1(inX,dataSet,labels, k):
 dataSetSize = dataSet.shape[0]
 diffMat = tile(inX, (dataSetSize,1)) - dataSet
 diffMatT=(diffMat.T)
 sqDiffMat = dot(diffMat,diffMat.T)
 distances = sqrt(sqDiffMat)   
 sortedDistIndicies=distances.argsort() 
 classCount={}         
#投票
 for i in range(k):
    voteIlabel = labels[sortedDistIndicies[i]]
    classCount[voteIlabel] = classCount.get(voteIlabel,0) + 1
 sortedClassCount = sorted(classCount.iteritems(), key=operator.itemgetter(1), reverse=True)
 return sortedClassCount[0][0]
# 將檔案轉化為向量
def img2vector(filename):
    returnVect = zeros((1,1024))
    fr = open(filename)
    for i in range(32):
        lineStr = fr.readline()
        for j in range(32):
            returnVect[0,32*i+j] = int(lineStr[j])
    return returnVect

def handwritingClassTest(TrainDataPath):
    hwLabels = []
    trainingFileList = listdir(TrainDataPath)           #load the training set
    m = len(trainingFileList)
    trainingMat = zeros((m,1024))
    for i in range(m):
        fileNameStr = trainingFileList[i]
        fileStr = fileNameStr.split('.')[0]     #take off .txt
        classNumStr = int(fileStr.split('_')[0])
        hwLabels.append(classNumStr)
        trainingMat[i,:] = img2vector(TrainDataPath+'/%s' % fileNameStr)
    testFileList = listdir('C:/Users/HP/Desktop/MLiA_SourceCode/machinelearninginaction/Ch02/testDigits')        #iterate through the test set
    errorCount = 0.0
    mTest = len(testFileList)
    for i in range(mTest):
        fileNameStr = testFileList[i]
        fileStr = fileNameStr.split('.')[0]     #take off .txt
        classNumStr = int(fileStr.split('_')[0])
        vectorUnderTest = img2vector('C:/Users/HP/Desktop/MLiA_SourceCode/machinelearninginaction/Ch02/testDigits/%s' % fileNameStr)
        classifierResult = classify0(vectorUnderTest, trainingMat, hwLabels, 3)
        print "the classifier came back with: %d, the real answer is: %d" % (classifierResult, classNumStr)
        if (classifierResult != classNumStr): errorCount += 1.0
    print "\nthe total number of errors is: %d" % errorCount
    print "\nthe total error rate is: %f" % (errorCount/float(mTest))

top = Tkinter.Tk()

def TrainDataCallBack():
    TrainDataPath=tkFileDialog.askdirectory()
    handwritingClassTest(TrainDataPath)

def CustomizeTestDataCallBack():
    board = Painter()
    board.run()
    
def TestingCustomizeTestDataCallBack():
    ResizePic()
    TransformArray()
TrainDataButton = Tkinter.Button(top, text ="TrainData", command = TrainDataCallBack)
CustomizeTestDataButton = Tkinter.Button(top, text ="CustomizeTestData", command = CustomizeTestDataCallBack)
TestingButton = Tkinter.Button(top, text ="TestingCustomizeTestData", command = TestingCustomizeTestDataCallBack)

def ResizePic():
    im = Image.open("test.png")  
    w,h = im.size  
    im_ss = im.resize((int(32), int(32)))  
    im_ss.save("test.png")  

def TransformArray():
    TestArray = zeros((1,1024))
    im = Image.open("test.png")  
    width,height = im.size  
    for h in range(0, height):  
      for w in range(0, width):  
        pixel = im.getpixel((w, h))      
        if pixel!=(255,255,255):
            TestArray[0,32*h+w]=int(1)
    handwritingTesting(TestArray)


def handwritingTesting(TestArray):
  #  TrainDataPath=tkFileDialog.askdirectory()
    TrainDataPath="C:/Users/HP/Desktop/MLiA_SourceCode/machinelearninginaction/Ch02/trainingDigits"
    hwLabels = []
    trainingFileList = listdir(TrainDataPath)           #load the training set
    m = len(trainingFileList)
    trainingMat = zeros((m,1024))
    for i in range(m):
        fileNameStr = trainingFileList[i]
        fileStr = fileNameStr.split('.')[0]     #take off .txt
        classNumStr = int(fileStr.split('_')[0])
        hwLabels.append(classNumStr)
        trainingMat[i,:] = img2vector(TrainDataPath+'/%s' % fileNameStr)      
    classifierResult = classify0(TestArray, trainingMat, hwLabels, 100)
    classifierResult1 = classify1(TestArray, trainingMat, hwLabels, 100)
    print "the classifier came back with: %d"  %  classifierResult
    print "the classifier came back with: %d"  %  classifierResult1
       
TrainDataButton.pack()
CustomizeTestDataButton.pack()
TestingButton.pack()
top.mainloop()
這裡面我們採用了很笨的方法將測試樣本與所有訓練樣本進行比較,更有效的方法是採用KD樹。另外k的取值在這裡也是固定的,更好的方法是在一個區間內網格搜尋~

未完待續。。。

相關推薦

基於k近鄰KNN數字識別

作者:faaronzheng 轉載請註明出處! 最近再看Machine Learning in Action. k近鄰演算法這一章節提供了不少例子,本著Talk is cheap的原則,我們用手寫數字識別來實際測試一下。 簡單的介紹一下k近鄰演算法(KNN):給定測試樣本

《機器學習實戰》第二章:k-近鄰演算法3數字識別

這是k-近鄰演算法的最後一個例子——手寫數字識別! 怎樣?是不是聽起來很高大上? 呵呵。然而這跟影象識別沒有半毛錢的關係 因為每個資料樣本並不是手寫數字的圖片,而是有由0和1組成的文字檔案,就像這樣:         嗯,這個資料集中的每一個樣本用圖形軟體處理過,變成了寬高

Tensorflow筆記——神經網絡圖像識別數字識別

bsp 14. image 分享 img 筆記 nbsp 數字識別 com Tensorflow筆記——神經網絡圖像識別(五)手寫數字識別

TensorFlow神經網路數字識別

內容來自mooc人工智慧實踐第五講 一、MNIST資料集一些用到的基礎函式語法 ############ warm up ! ############ # 匯入imput_data模組 from tensorflow.examples.tutorials.mnist import

基於TensorFlow的K近鄰KNN分類器實現——以MNIST為例

KNN分類原理 TF的KNN程式碼 def load_mnist_data(filename,isbatch=0,train_nums=1000,test_nums=200): from tensorflow.examples.

K-近鄰KNN算法

第三章 不同 bin 挖掘 特性 訓練 屬於 博客 建立   K-近鄰算法(K-NN)   鄰近算法,或者說K最近鄰(kNN,k-NearestNeighbor)分類算法是數據挖掘分類技術中最簡單的方法之一。所謂K最近鄰,就是k個最近的鄰居的意思,說的是每個樣本都可以用它最

機器學習之K-近鄰KNN算法

實戰 http created 以及 dex mda 問題 可以轉化 占比 一 . K-近鄰算法(KNN)概述 最簡單最初級的分類器是將全部的訓練數據所對應的類別都記錄下來,當測試對象的屬性和某個訓練對象的屬性完全匹配時,便可以對其進行分類。但是怎麽可能所有測

機器學習——K-近鄰KNN演算法

本文轉載自:https://www.cnblogs.com/ybjourney/p/4702562.html 一 . K-近鄰演算法(KNN)概述 最簡單最初級的分類器是將全部的訓練資料所對應的類別都記錄下來,當測試物件的屬性和某個訓練物件的屬性完全匹配時,便可以對其進

K近鄰KNN演算法的學習

K近鄰演算法 首先應該明白K近鄰演算法是一種分類演算法,這裡不作實際問題抽象成數學模型的解釋,只在數學方面來說明。 背景 已知一系列的資料點被分成了幾個類,現在有一個新的未知點,那麼該把這個未知點歸為哪一類呢?這就是K近鄰演算法來解決的問題。 演算法 在未知點附近

K近鄰KNN演算法、KD樹及其python實現

1、k近鄰演算法 1.1 KNN基本思想 k近鄰法是基本且簡單的分類與迴歸方法,即對於輸入例項,依據給定的距離度量方式(歐式距離),以及選擇合適的k值(交叉驗證),在樣本集中找到最近鄰新例項的k個樣例,通過k個最近鄰樣例的類別表決出新例項的類別(多數表決)。

AI產品經理必懂演算法:k-近鄰KNN演算法

我們之所以要了解演算法,不僅僅有利於和演算法同學的溝通,更能深入的理解人工智慧為產品賦能的過程,只有將這個過程瞭解透徹,才能清晰明確的把握產品的方向,挖掘產品的亮點。 那麼,今天我們就從一個最為簡單、易懂的“k-近鄰(KNN)演算法”聊起,KNN屬於監督學習演算法,即可以用於分類,也可以用於迴歸,後續還會逐

機器學習——K-近鄰KNN演算法

#coding:utf-8 from numpy import * import operator from collections import Counter import matplotlib import matplotlib.pyplot as plt ###匯入特徵資料

使用K近鄰KNN對鳶尾花分類

KNN演算法的精髓在於近朱者赤近墨者黑,所以距離通過歐氏距離或者夾角餘弦來計算。 他的主要計算步驟為: 1.算距離:給定測試物件,計算它與訓練集中的每個物件的距離 2.尋找鄰居:圈定距離最近的K個訓練物件,作為測試物件的近鄰。 3.做分類:根據這K個近鄰歸屬的主要類別,來

K近鄰KNN的Python實現

KNN KNN almost has no training process. At the very beginning, we can classify the test samples. K used to be lower than sqrt(#

python機器學習-k近鄰KNN演算法例項

機器學習-k近鄰(KNN) 本篇主要是自己複習和總結機器學習演算法中最基礎入門的——k近鄰(KNN)演算法,內容由網上百度與摘抄唐宇迪老師的講義。 k近鄰介紹 ——K最近鄰(k-Nearest Neighbor,KNN),k近鄰演算法可以應用於分類場景與迴歸場

KNN數字識別

用 KNN 做手寫數字識別 目錄 用 KNN 做手寫數字識別 1. KNN的原理 2. KNN實現手寫數字識別過程 作為一個小白,寫此文章主要是為了自己記錄,方便回過頭來查詢! 本文主要參考ApacheCN(專注於優秀專案維護的開源組織)中MachineL

機器學習實戰——KNN演算法數字識別

資料來源 我們的文字是形如這樣的,每個數字都有很多txt檔案,TXT裡面是01數字,表示手寫數字的灰度圖。 現在我們要用knn演算法實現數字識別。 資料處理 每個txt檔案都是32*32的0,1矩陣,如果要使用knn,那麼還得考慮行列關係,如果能把它拉開,只有一行,就可以不必考慮數字

KNN / SVM 數字識別-PCA降維

一.問題分析採用機器學習演算法對usps和mnist兩個資料集完成手寫數字識別任務。1.1.資料集介紹MNIST MNIST 資料集來自美國國家標準與技術研究所, National Institute of Standards and Technology (NIST)。訓練

基於opencv3.4和SVM的數字識別

  本文將使用opencv3.4和SVM識別手寫數字,開發環境為vs2013和C++。 資料集   opencv安裝資料夾的 “samples/data” 下預置了一張手寫數字的圖片,其包含了5000個 0至9 的手寫數字,每個數字

【好玩的計算機視覺】KNN演算法數字識別

OCR應用非常廣泛,而且有許多方法,今天用KNN演算法實現簡單的0-9手寫數字識別。本程式使用OpenCV 3.0和Python 3。 KNN演算法是K近鄰分類演算法,屬於機器學習中的監督學習,需要一定量的帶標籤的輸入樣本資料進行“訓練”,然後就可以識別。我給“訓練”打引