1. 程式人生 > 其它 >《統計學習方法》第3章習題

《統計學習方法》第3章習題

習題3.1

習題3.2

根據例 3.2 構造的 kd 樹搜尋點 \(x=(3,4.5)^T\),可知其最近鄰點為 \((2,3)^T\)。

習題3.3

k 近鄰法主要需要構造相應的 kd 樹。這裡用 Python 實現 kd 樹的構造與搜尋

import heapq
import numpy as np

class KDNode:
    """One node of a kd tree.

    Attributes:
        data: the sample point stored at this node.
        axis: index of the coordinate this node is associated with
            (defaults to 0).
        left: left child node, or None.
        right: right child node, or None.
    """

    def __init__(self, data, axis=0, left=None, right=None):
        # Payload and splitting coordinate.
        self.data, self.axis = data, axis
        # Child links; None marks a missing subtree.
        self.left, self.right = left, right

class KDTree:
    """kd tree over the rows of a 2-D sample matrix with k-nearest-neighbour search.

    Typical use: ``tree = KDTree(data); tree.construct(); tree.search(x, k)``.
    ``search`` builds the tree on first use if ``construct`` was not called.
    """

    def __init__(self, data):
        """Remember the samples; the tree itself is built by ``construct``.

        Args:
            data: 2-D array-like with one sample per row.
        """
        data = np.asarray(data)
        self.raw_data = data
        # Number of coordinates per sample; the split axis cycles through them.
        self.k = data.shape[1]
        # Set by construct(); None until then (and for an empty data set).
        self.root = None

    def construct(self):
        """Build the tree from ``self.raw_data`` and store it in ``self.root``."""
        self.root = self._insert_node(self.raw_data, 0)

    def search(self, x, near_k=1, p=2):
        """Return the ``near_k`` stored points closest to ``x``.

        Args:
            x: query point with ``self.k`` coordinates.
            near_k: number of neighbours wanted (must be >= 1).
            p: ``ord`` argument handed to ``np.linalg.norm``. The subtree
                pruning compares against the gap along a single axis, which
                is a valid lower bound on the distance only for ``p >= 1``.

        Returns:
            Array with one neighbour per row, nearest first. Fewer than
            ``near_k`` rows are returned when the tree holds fewer points;
            an empty tree yields an array of shape ``(0, self.k)``. The
            same array is also left on ``self.knn``.

        Raises:
            ValueError: if ``near_k`` is smaller than 1.
        """
        if near_k < 1:
            raise ValueError("near_k must be a positive integer")
        if self.root is None:
            self.construct()
        x = np.asarray(x)

        self._near_k = near_k
        # Visit counter used as a tie-breaker so that equal distances never
        # make heapq fall through to comparing node objects.
        self._tick = 0
        # Heap of (-distance, -visit_index, node); the root of this min-heap
        # is therefore the worst candidate kept so far.
        self.knn = []
        self._visit(self.root, x, p)

        # Nearest first; among equal distances, earlier-visited first.
        ranked = sorted(self.knn, key=lambda entry: (-entry[0], -entry[1]))
        if ranked:
            self.knn = np.array([entry[2].data for entry in ranked])
        else:
            self.knn = np.empty((0, self.k))
        return self.knn

    def pre_order_traverse(self, node):
        """Print the points of the subtree rooted at ``node`` in pre-order."""
        if node is None:
            # Empty (sub)tree: nothing to print.
            return
        print(node.data)
        if node.left:
            self.pre_order_traverse(node.left)
        if node.right:
            self.pre_order_traverse(node.right)

    def _insert_node(self, data, depth=0):
        """Recursively build the subtree for ``data``; return its root or None.

        The split axis is ``depth % self.k``. The median point along that
        axis becomes the node, points before it in sorted order go to the
        left subtree and points after it to the right subtree.
        """
        if len(data) == 0:
            return None
        axis = depth % self.k
        data = sorted(data, key=lambda point: point[axis])
        middle = len(data) // 2
        return KDNode(
            data[middle],
            axis,
            self._insert_node(data[:middle], depth + 1),
            self._insert_node(data[middle + 1:], depth + 1),
        )

    def _visit(self, node, x, p=2):
        """Depth-first search step: update the candidate heap from ``node``.

        Descends first into the side of the split that contains ``x``, then
        considers ``node`` itself, and finally enters the other side only if
        it could still hold a better candidate.
        """
        if node is None:
            return

        # Signed gap between the query and this node along the split axis.
        dis = x[node.axis] - node.data[node.axis]
        if dis < 0:
            near, far = node.left, node.right
        else:
            near, far = node.right, node.left
        self._visit(near, x, p)

        curr_dis = np.linalg.norm(x - node.data, p)
        entry = (-curr_dis, -self._tick, node)
        self._tick += 1
        if len(self.knn) < self._near_k:
            heapq.heappush(self.knn, entry)
        else:
            # Keeps the near_k best: drops whichever of (entry, current
            # worst) is farther; on a tie the earlier-visited node stays.
            heapq.heappushpop(self.knn, entry)

        # The far side is worth a look while slots are still free, or when
        # the splitting plane is closer than the current worst candidate.
        if len(self.knn) < self._near_k or -self.knn[0][0] > abs(dis):
            self._visit(far, x, p)

if __name__ == "__main__":
    # Six two-dimensional sample points, one per row.
    data = np.array([[2, 3], [5, 4], [9, 6], [4, 7], [8, 1], [7, 2]])

    # Build the kd tree over the samples.
    tree = KDTree(data)
    tree.construct()

    # Look up the two nearest neighbours of (3, 4.5) and print them.
    query = np.array([3, 4.5])
    neighbours = tree.search(query, near_k=2)
    print(neighbours)

通過呼叫 KDTree 的 search 方法即可查詢 x 的 k 近鄰。對 \(x=(3,4.5)^T\)、\(k=2\),結果為 \([(2,3)^T, (5,4)^T]\)。