《統計學習方法》第3章習題
阿新 • • 發佈:2021-06-24
習題3.1
略
習題3.2
根據例 3.2 構造的 kd 樹,可求得點 \(x=(3,4.5)^T\) 的最近鄰點為 \((2,3)^T\),兩者的歐氏距離為 \(\sqrt{1^2+1.5^2}=\sqrt{3.25}\approx 1.80\)。
習題3.3
k 近鄰法的關鍵是構造相應的 kd 樹並在其上進行搜尋。下面用 Python 實現 kd 樹的構造與 k 近鄰搜尋:
import heapq
import numpy as np


class KDNode:
    """One node of a kd-tree.

    Attributes:
        data: the point stored at this node (one row of the training data).
        axis: index of the coordinate this node splits on.
        left: subtree built from the points sorted before ``data`` on ``axis``
            (``None`` if empty), as produced by ``KDTree._insert_node``.
        right: subtree built from the points sorted after ``data`` on ``axis``
            (``None`` if empty).
    """

    def __init__(self, data, axis=0, left=None, right=None):
        self.data = data
        self.axis = axis
        self.left = left
        self.right = right


class KDTree:
    """kd-tree over the rows of a 2-D point set, with k-nearest-neighbour search.

    Typical use::

        tree = KDTree(points)   # points: (n_samples, n_features)
        tree.construct()        # optional; search() builds the tree if needed
        tree.search(x, near_k=2)
    """

    def __init__(self, data):
        """Store the training points.

        Args:
            data: array-like of shape (n_samples, n_features); each row is a point.
        """
        self.raw_data = np.asarray(data)
        self.k = self.raw_data.shape[1]  # dimensionality; split axes cycle 0..k-1
        self.root = None  # filled in by construct()

    def construct(self):
        """Build the tree from ``raw_data`` and store its root in ``self.root``."""
        self.root = self._insert_node(self.raw_data, 0)

    def search(self, x, near_k=1, p=2):
        """Return the ``near_k`` stored points closest to ``x``.

        Args:
            x: query point with ``k`` coordinates.
            near_k: number of neighbours wanted. If it exceeds the number of
                stored points, every point is returned; if it is <= 0, none are.
            p: ``ord`` argument passed to ``np.linalg.norm`` (2 = Euclidean).

        Returns:
            ndarray whose rows are the neighbours, nearest first, or an empty
            array of shape ``(0, k)`` when nothing is found. The result is also
            kept in ``self.knn``.
        """
        if self.root is None:
            self.construct()
        x = np.asarray(x)
        # Bounded heap of (-distance, seq, node): heap[0] is the worst of the
        # current candidates. It starts filled with -inf placeholders so that
        # _visit never prunes before near_k real points have been seen. seq is
        # a unique tie-breaker, so equal distances never compare KDNode objects
        # (which define no ordering) or a node against a None placeholder.
        self.knn = [(-np.inf, i, None) for i in range(near_k)]
        self._seq = len(self.knn)
        if self.knn:
            self._visit(self.root, x, p)
        neighbours = [
            node.data
            for _, _, node in sorted(self.knn, reverse=True)  # nearest first
            if node is not None  # drop unused placeholders
        ]
        if neighbours:
            self.knn = np.array(neighbours)
        else:
            self.knn = np.empty((0, self.k), dtype=self.raw_data.dtype)
        return self.knn

    def pre_order_traverse(self, node):
        """Print the points of the subtree rooted at ``node`` in pre-order."""
        if node is None:
            return
        print(node.data)
        self.pre_order_traverse(node.left)
        self.pre_order_traverse(node.right)

    def _insert_node(self, data, depth=0):
        """Recursively build the subtree for ``data`` and return its root node.

        The split axis is ``depth % k``. After sorting on that axis, the row at
        index ``len(data) // 2`` becomes the node, the rows before it form the
        left subtree and the rows after it the right subtree. Returns ``None``
        for empty ``data``.
        """
        if len(data) == 0:
            return None
        axis = depth % self.k
        data = sorted(data, key=lambda point: point[axis])
        middle = len(data) // 2
        return KDNode(
            data[middle],
            axis,
            self._insert_node(data[:middle], depth + 1),
            self._insert_node(data[middle + 1:], depth + 1),
        )

    def _visit(self, node, x, p=2):
        """Depth-first k-NN search below ``node``, updating the ``self.knn`` heap."""
        if node is None:
            return
        # Signed offset of the query from this node's splitting hyperplane.
        dis = x[node.axis] - node.data[node.axis]
        near, far = (node.left, node.right) if dis < 0 else (node.right, node.left)
        self._visit(near, x, p)
        curr_dis = np.linalg.norm(x - node.data, p)
        # Push this point and evict the current worst, keeping the heap size fixed.
        heapq.heappushpop(self.knn, (-curr_dis, self._seq, node))
        self._seq += 1
        # |dis| lower-bounds the distance (for p >= 1 or np.inf) to any point on
        # the far side, so that side is only worth visiting if it could beat the
        # current k-th best distance.
        if -self.knn[0][0] > abs(dis):
            self._visit(far, x, p)


if __name__ == "__main__":
    data = np.array([
        [2, 3], [5, 4], [9, 6], [4, 7], [8, 1], [7, 2],
    ])
    tree = KDTree(data)
    tree.construct()
    print(tree.search(np.array([3, 4.5]), 2))
呼叫 KDTree 的 search 方法即可查詢 x 的 k 近鄰。以 \(x=(3,4.5)^T\)、\(k=2\) 為例,結果為 \([(2,3)^T, (5,4)^T]\)。