k-近鄰演算法
阿新 • • 發佈:2020-08-10
from numpy import * import operator def classify0(inX, dataSet, labels, k): dataSetSize = dataSet.shape[0]#獲取資料集的行數
classify0()函式有4個引數:inX:用於分類的輸入向量;dataset:輸入的訓練集;labels:標籤向量;k:最近鄰數。其中標籤向量的元素數目和矩陣dataset的行數相同。計算兩點間的距離公式為:
d =√(Ax- Bx)2 + (Ay - By)2
寫成程式碼如下:
#計算dataSet中元素到原點的距離 diffMat = tile(inX, (dataSetSize, 1)) - dataSet#tile函式將inX填充至行數與dataset相同,再減去dataset,相當於dataset * (-1) sqDiffMat = diffMat ** 2 #diffMat中所有元素求平方 sqDistances = sqDiffMat.sum(axis = 1) distances = sqDistances ** 0.5
計算完所有點之間的距離後,將資料按照從小到大的次序排列:
sortedDistIndicies = distances.argsort()
確定前k個距離最小的元素所在的主要分類,輸入k總是正整數;然後將classcount字典分解為元組列表,使用itemgetter方法,按照第二個元素的次序對元組進行從大到小排序,最後返回發生頻率最高的元素標籤
classCount = {} for i in range(k): voteIlabel = labels[sortedDistIndicies[i]] classCount[voteIlabel] = classCount.get(voteIlabel,0) + 1 sortedClassCount = sorted(array(classCount).iteritems(), key = operator.itemgetter(1), reverse = True) returnsortedClassCount[0][0]