1. 程式人生 > >2. KNN和KdTree算法實現

2. KNN和KdTree算法實現

nload bug bre pen 操作性 nod 系列 特點 最近鄰

1. K近鄰算法(KNN)

2. KNN和KdTree算法實現

1. 前言

KNN一直是一個機器學習入門需要接觸的第一個算法,它有著簡單,易懂,可操作性強的一些特點。今天我久帶領大家先看看sklearn中KNN的使用,在帶領大家實現出自己的KNN算法。

2. KNN在sklearn中的使用

knn在sklearn中是放在sklearn.neighbors的包中的,我們今天主要介紹KNeighborsClassifier的分類器。

KNeighborsClassifier的主要參數是:

參數 意義
n_neighbors K值的選擇與樣本分布有關,一般選擇一個較小的K值,可以通過交叉驗證來選擇一個比較優的K值,默認值是5
weights ‘uniform’是每個點權重一樣,‘distance’則權重和距離成反比例,即距離預測目標更近的近鄰具有更高的權重
algorithm ‘brute’對應第一種蠻力實現,‘kd_tree’對應第二種KD樹實現,‘ball_tree’對應第三種的球樹實現, ‘auto’則會在上面三種算法中做權衡,選擇一個擬合最好的最優算法。
leaf_size 這個值控制了使用KD樹或者球樹時, 停止建子樹的葉子節點數量的閾值。
metric K近鄰法和限定半徑最近鄰法類可以使用的距離度量較多,一般來說默認的歐式距離(即p=2的閔可夫斯基距離)就可以滿足我們的需求。
p p是使用距離度量參數 metric 附屬參數,只用於閔可夫斯基距離和帶權重閔可夫斯基距離中p值的選擇,p=1為曼哈頓距離, p=2為歐式距離。默認為2

我個人認為這些個參數,比較重要的應該屬n_neighbors、weights了,其他默認的也都沒太大問題。

3. KNN基礎版實現

直接看代碼如下,完整代碼GitHub:

def fit(self, X_train, y_train):
    self.X_train = X_train
    self.y_train = y_train

def predict(self, X):
    # 取出n個點
    knn_list = []
    for i in range(self.n):
        dist = np.linalg.norm(X - self.X_train[i], ord=self.p)
        knn_list.append((dist, self.y_train[i]))

    for i in range(self.n, len(self.X_train)):
        max_index = knn_list.index(max(knn_list, key=lambda x: x[0]))
        dist = np.linalg.norm(X - self.X_train[i], ord=self.p)
        if knn_list[max_index][0] > dist:
            knn_list[max_index] = (dist, self.y_train[i])

    # 統計
    knn = [k[-1] for k in knn_list]
    return Counter(knn).most_common()[0][0]

我的接口設計都是按照sklearn的樣子設計的,fit方法其實主要用來接收參數了,沒有進行任何的處理。所有的操作都在predict中,著就會導致,我們對每個點預測的時候,時間消耗比較大。這個基礎版本大家看看就好,我想大家自己去寫,肯定也沒問題的。

4. KdTree版本實現

kd樹算法包括三步,第一步是建樹,第二部是搜索最近鄰,最後一步是預測。

4.1 構建kd樹

kd樹是一種對n維空間的實例點進行存儲,以便對其進行快速檢索的樹形結構。kd樹是二叉樹,構造kd樹相當於不斷的用垂直於坐標軸的超平面將n維空間進行劃分,構成一系列的n維超矩陣區域。

下面的流程圖更加清晰的描述了kd樹的構建過程:

技術分享圖片

kdtree樹的生成代碼:

# 建立kdtree
def create(self, dataSet, label, depth=0):
    if len(dataSet) > 0:
        m, n = np.shape(dataSet)
        self.n = n
        axis = depth % self.n
        mid = int(m / 2)
        dataSetcopy = sorted(dataSet, key=lambda x: x[axis])
        node = Node(dataSetcopy[mid], label[mid], depth)
        if depth == 0:
            self.KdTree = node
        node.lchild = self.create(dataSetcopy[:mid], label, depth+1)
        node.rchild = self.create(dataSetcopy[mid+1:], label, depth+1)
        return node
    return None

4.2 kd樹搜索最近鄰和預測

當我們生成kd樹以後,就可以去預測測試集裏面的樣本目標點了。預測的過程如下:

  1. 對於一個目標點,我們首先在kd樹裏面找到包含目標點的葉子節點。以目標點為圓心,以目標點到葉子節點樣本實例的距離為半徑,得到一個超球體,最近鄰的點一定在這個超球體內部。
  2. 然後返回葉子節點的父節點,檢查另一個子節點包含的超矩形體是否和超球體相交,如果相交就到這個子節點尋找是否有更加近的近鄰,有的話就更新最近鄰,並且更新超球體。如果不相交那就簡單了,我們直接返回父節點的父節點,在另一個子樹繼續搜索最近鄰。
  3. 當回溯到根節點時,算法結束,此時保存的最近鄰節點就是最終的最近鄰。
    kdtree樹的搜索代碼:
# 搜索kdtree的前count個近的點
def search(self, x, count = 1):
    nearest = []
    for i in range(count):
        nearest.append([-1, None])
    # 初始化n個點,nearest是按照距離遞減的方式
    self.nearest = np.array(nearest)

    def recurve(node):
        if node is not None:
            # 計算當前點的維度axis
            axis = node.depth % self.n
            # 計算測試點和當前點在axis維度上的差
            daxis = x[axis] - node.data[axis]
            # 如果小於進左子樹,大於進右子樹
            if daxis < 0:
                recurve(node.lchild)
            else:
                recurve(node.rchild)
            # 計算預測點x到當前點的距離dist
            dist = np.sqrt(np.sum(np.square(x - node.data)))
            for i, d in enumerate(self.nearest):
                # 如果有比現在最近的n個點更近的點,更新最近的點
                if d[0] < 0 or dist < d[0]:
                    # 插入第i個位置的點
                    self.nearest = np.insert(self.nearest, i, [dist, node], axis=0)
                    # 刪除最後一個多出來的點
                    self.nearest = self.nearest[:-1]
                    break

            # 統計距離為-1的個數n
            n = list(self.nearest[:, 0]).count(-1)
            '''
            self.nearest[-n-1, 0]是當前nearest中已經有的最近點中,距離最大的點。
            self.nearest[-n-1, 0] > abs(daxis)代表以x為圓心,self.nearest[-n-1, 0]為半徑的圓與axis
            相交,說明在左右子樹裏面有比self.nearest[-n-1, 0]更近的點
            '''
            if self.nearest[-n-1, 0] > abs(daxis):
                if daxis < 0:
                    recurve(node.rchild)
                else:
                    recurve(node.lchild)

    recurve(self.KdTree)

    # nodeList是最近n個點的
    nodeList = self.nearest[:, 1]

    # knn是n個點的標簽
    knn = [node.label for node in nodeList]
    return self.nearest[:, 1], Counter(knn).most_common()[0][0]

這段代碼其實比較好的實現了上面搜索的思想。如果讀者對遞歸的過程想不太清楚,可以畫下圖,或者debug下我完整的代碼GitHub

5. 總結

本文實現了KNN的基礎版和KdTree版本,讀者可以去嘗試下ballTree的版本,理論上效率比KdTree還要好一些。

2. KNN和KdTree算法實現