1. 程式人生 > 實用技巧 >k-近鄰演算法

k-近鄰演算法

from numpy import *
import operator

def classify0(inX, dataSet, labels, k):
    dataSetSize = dataSet.shape[0]#獲取資料集的行數

classify0()函式有4個引數:inX:用於分類的輸入向量;dataset:輸入的訓練集;labels:標籤向量;k:最近鄰數。其中標籤向量的元素數目和矩陣dataset的行數相同。計算兩點間的距離公式為:

d =√(Ax- Bx)2 + (Ay - By)2

寫成程式碼如下:

#計算dataSet中元素到原點的距離   
    diffMat = tile(inX, (dataSetSize, 1)) - dataSet 
    
#tile函式將inX填充至行數與dataset相同,再減去dataset,相當於dataset * (-1) sqDiffMat = diffMat ** 2 #diffMat中所有元素求平方 sqDistances = sqDiffMat.sum(axis = 1) distances = sqDistances ** 0.5

計算完所有點之間的距離後,將資料按照從小到大的次序排列:

sortedDistIndicies = distances.argsort()

確定前k個距離最小的元素所在的主要分類,輸入k總是正整數;然後將classcount字典分解為元組列表,使用itemgetter方法,按照第二個元素的次序對元組進行從大到小排序,最後返回發生頻率最高的元素標籤

classCount = {}
    for i in range(k):
        voteIlabel = labels[sortedDistIndicies[i]]
        classCount[voteIlabel] = classCount.get(voteIlabel,0) + 1
    sortedClassCount = sorted(array(classCount).iteritems(), 
                              key = operator.itemgetter(1), reverse = True)
    return
sortedClassCount[0][0]