數組索引的kdtree建立及簡明快速的k近鄰搜索方法
1. kdtree概念
kd樹(k-dimensional樹的簡稱),是一種分割k維數據空間的數據結構,主要應用於多維空間關鍵數據的搜索,如範圍搜索和最近鄰搜索。
如下圖所示,在既定的分割維度上,每一個根節點的值均大於其左子樹,並小於其右子樹。這樣的二叉樹,對於搜索某個點的最臨近點或k近鄰點,是十分高效快速的。
2. 建立kdtree
建立kdtree,主要有兩步操作:選擇合適的分割維度,選擇中值節點作為分割節點。分割維度的選擇遵循的原則是,選擇範圍最大的緯度,也即是方差最大的緯度作為分割維度;分割節點的選擇原則是,將這一維度的數據進行排序,選擇正中間的節點作為分割節點,確保節點左邊的點的維度值小於節點的維度值,節點右邊的點的維度值大於節點的維度值。
建立kdtree可遵循以下步驟:
1) 建立一維數組,存儲每一個點的索引,並進行隨機打亂。
2) 定義合適的kdtree函數定義,方便進行遞歸建樹。
3) 編寫分割維度函數
4) 編寫選擇分割節點函數
5) kdtree函數功能實現:選擇分割維度,選擇分割節點,將節點左邊的數據進行遞歸建立左子樹,將節點右邊的數據進行遞歸建立右子樹
下面通過實際代碼,講解kdtree建立的過程:
1)數據及索引的存儲定義
無論是數據還是索引均存儲在一維數組中,通過二維指針數組來索引,用一個指針數組來存儲每一維數據的起始地址,用另一個指針數組來存儲每一類索引的起始位置,比如分割維度、父節點、左子樹、右子樹
/* * dataPtr一維數組表示多維數組 * 數據排布方式:{[x1, x2, x3……], [y1, y2, y3……], [z1, z2, z3……], ……} */ /* * 所有數據存儲在一維數組dataPtr裏,data分別是x/y/z等數據的起始地址 * 因此,建樹及knn只需傳遞數據的索引編號即可 */ float **data; float *dataPtr; int **tree; // 4 * n :分割維度、父節點、左子樹、右子樹int *treePtr; // 使用一維數據表示二維數組,存儲建立的kdtree索引
對定義的數組進行初始化操作:
1 int ZtKDTree::setSize(int dimension, unsigned int sz) 2 { 3 nDimension = dimension; // 數據的維度 4 treeSize = sz; // 數據的總數 5 6 if (nDimension > 0 && treeSize > 0) 7 { 8 offset = new double[nDimension]; 9 10 tree = new int *[4]; 11 treePtr = new int[4 * treeSize]; 12 for (int i = 0; i < 4; i++) 13 { 14 tree[i] = treePtr + i * treeSize; 15 } 16 17 data = new float *[nDimension]; 18 dataPtr = new float[nDimension * sz]; 19 for (int i = 0; i < nDimension; i++) 20 { 21 data[i] = dataPtr + i * treeSize; 22 } 23 } 24 25 return 0; 26 }
2) kdtree建立準備,建立一維數組存儲數據索引,定義建樹函數
使用一維數組存儲每一個數據的索引,並進行隨機打亂,建樹過程中,可以通過索引來訪問數據,並且不會打亂原來數據的順序,快速排序等操作也不必操作數據,只需操作索引即可
1 int buildTree() 2 { 3 std::vector<int> vtr(treeSize); 4 5 for (int i = 0; i < treeSize; i++) 6 { 7 vtr[i] = i; 8 } 9 10 std::random_shuffle(vtr.begin(), vtr.end()); 11 12 treeRoot = buildTree(&vtr[0], treeSize, -1); // 根節點的父節點是-1 13 14 return treeRoot; 15 } 16 17 // 建立kdtree函數 18 int buildTree(int *indices, int count, int parent)
3)分割維度函數編寫
分割維度的選擇至關重要,選擇合適的維度,可提高建樹效率及搜索效率。計算當前空間的所有數據每一維度的方差,選擇方差最大的維度作為分割維度,並順便傳出維度均值,以用於節點選擇函數。
1 int chooseSplitDimension(int *ids, int sz, float &key) 2 { 3 int split = 0; 4 5 float *var = new float[nDimension]; 6 float *mean = new float[nDimension]; 7 8 int cnt = std::min((int)SAMPLE_MEAN, sz);/* cnt = sz;*/ 9 double rt = 1.0 / cnt; 10 11 for (int i = 0; i < nDimension; i++) 12 { 13 double sum1 = 0, sum2 = 0; 14 for (int j = 0; j < cnt; j++) 15 { 16 sum1 += rt * data[i][ids[j]] * data[i][ids[j]]; 17 sum2 += rt * data[i][ids[j]]; 18 } 19 var[i] = sum1 - sum2 * sum2; 20 mean[i] = sum2; 21 } 22 23 double max = 0; 24 25 for (int i = 0; i < nDimension; i++) 26 { 27 if (var[i] > max) 28 { 29 key = mean[i]; 30 max = var[i]; 31 split = i; 32 } 33 } 34 35 delete[] var; 36 delete[] mean; 37 38 return split; 39 }View Code
4)節點選擇函數編寫
這步操作主要是選擇中值節點,但是並不是說要把全部數據進行排序,排序太費時了。使用維度均值進行一趟快速排序,將數據分為兩部分,大於均值的數據、小於均值的數據,然後從小於均值的空間中選擇最大的節點作為父節點,這樣就保證左子樹所有節點小於父節點,右子樹所有節點大於父節點。
1 int chooseMiddleNode(int *ids, int sz, int dim, float key) 2 { 3 int left = 0; 4 int right = sz - 1; 5 6 while (1) 7 { 8 while (left <= right && data[dim][ids[left]] <= key) //左邊找比key大的值 9 ++left; 10 11 while (left <= right && data[dim][ids[right]] >= key) //右邊找比key小的值 12 --right; 13 14 if (left > right) 15 break; 16 17 std::swap(ids[left], ids[right]); 18 ++left; 19 --right; 20 } 21 22 23 // 找出左子樹的最大值作為根節點 24 float max = -9999999; 25 int maxIndex = 0; 26 for (int i = 0; i < left; i++) 27 { 28 if (data[dim][ids[i]] > max) 29 { 30 max = data[dim][ids[i]]; 31 maxIndex = i; 32 } 33 } 34 35 if (maxIndex != left - 1) 36 { 37 std::swap(ids[maxIndex], ids[left - 1]); 38 } 39 40 return left - 1; 41 }View Code
5)建樹
完成以上工作後,建樹就很簡單了
1 int buildTree(int *indices, int count, int parent) 2 { 3 if (count == 1) 4 { 5 int rd = indices[0]; 6 tree[0][rd] = 0; 7 tree[1][rd] = parent; 8 tree[2][rd] = -1; 9 tree[3][rd] = -1; 10 11 return rd; 12 } 13 else 14 { 15 float key = 0; 16 int split = chooseSplitDimension(indices, count, key); 17 int idx = chooseMiddleNode(indices, count, split, key); 18 19 // rd 是實際點的下標, idx是點的索引數組的下標 20 int rd = indices[idx]; 21 22 tree[0][rd] = split; // 分割維度 23 tree[1][rd] = parent; 24 25 if (idx > 0) 26 { 27 tree[2][rd] = buildTree(indices, idx, rd); 28 } 29 else 30 { 31 tree[2][rd] = -1; 32 } 33 34 if (idx + 1 < count) 35 { 36 tree[3][rd] = buildTree(indices + idx + 1, count - idx - 1, rd); 37 } 38 else 39 { 40 tree[3][rd] = -1; 41 } 42 43 return rd; 44 } 45 }View Code
3. k近鄰搜索
最臨近搜索即是查找距離查找點最近的k個點。在講述k臨近搜索之前,先講述下最近鄰搜索的概念。
最近鄰搜索的基本思路是:從根節點開始,通過二叉樹搜索,如果節點的分割維度值小於查找點的維度值表示查找點位於左子樹空間中,則進入左子樹,如果大於則進入右子樹,直到達到葉子節點為止,將搜索路徑上的每一個節點都加入到路徑中;然後再回溯搜索路徑,並判斷未加入路徑的其他子節點空間中是否可能有距離搜索點更近的節點,如果有可能,則遍歷子節點空間,並將遍歷到的節點加入到搜索路徑中,重復這個過程直到搜索路徑為空。
理解了最近鄰搜索的思路,就很容易實現k近鄰搜索了,k近鄰搜索的思路是:同樣是先遍歷kdtree,將遍歷到的節點加入到搜索路徑中,然後回溯路徑;建立最大堆,在回溯路徑中,將小於堆頂最大距離的節點加入堆,直到搜索路徑為空。
實際實現過程中,需要註意的是,先出隊列的是葉子節點,距離查找點比較近,最先加入最大堆,從而堆頂距離比較小,在最大堆不滿時,進行距離判斷,可能會將在k近鄰範圍內的節點排除掉,因此預先加入一個極大距離節點,可避免最大堆不滿時,排除掉正確的節點。
1 struct NearestNode 2 { 3 int node; 4 float distance; 5 NearestNode() 6 { 7 node = 0; 8 distance = 0; 9 } 10 NearestNode(int n, float d) 11 { 12 node = n; 13 distance = d; 14 } 15 }; 16 17 struct cmp // 將最大的元素放在隊首 18 { 19 bool operator()(NearestNode a, NearestNode b) 20 { 21 return a.distance < b.distance; 22 } 23 }; 24 25 int findKNearests(float *p, int k, int *res) 26 { 27 std::priority_queue<NearestNode, std::vector<NearestNode>, cmp> kNeighbors; 28 std::stack<int> paths; 29 30 // 記錄查找路徑 31 int node = treeRoot; 32 while (node > -1) 33 { 34 paths.emplace(node); 35 36 node = p[tree[0][node]] <= data[tree[0][node]][node] ? tree[2][node] : tree[3][node]; 37 } 38 39 // 預先加入一個極大節點 40 kNeighbors.emplace(-1, 9999999); 41 42 // 回溯路徑 43 float distance = 0; 44 while (!paths.empty()) 45 { 46 node = paths.top(); 47 paths.pop(); 48 49 distance = computeDistance(p, node); 50 if (kNeighbors.size() < k) 51 { 52 kNeighbors.emplace(node, distance); 53 } 54 else 55 { 56 if (distance < kNeighbors.top().distance) 57 { 58 kNeighbors.pop(); 59 kNeighbors.emplace(node, distance); 60 } 61 } 62 63 if (tree[2][node] + tree[3][node] > -2) 64 { 65 int dim = tree[0][node]; 66 if (p[dim] > data[dim][node]) 67 { 68 if (p[dim] - data[dim][node] < kNeighbors.top().distance && tree[2][node] > -1) 69 { 70 int reNode = tree[2][node]; 71 while (reNode > -1) 72 { 73 paths.emplace(reNode); 74 75 reNode = p[tree[0][reNode]] <= data[tree[0][reNode]][reNode] ? tree[2][reNode] : tree[3][reNode]; 76 } 77 } 78 } 79 else 80 { 81 if (data[dim][node] - p[dim] < kNeighbors.top().distance && tree[3][node] > -1) 82 { 83 int reNode = tree[3][node]; 84 while (reNode > -1) 85 { 86 paths.emplace(reNode); 87 88 reNode = p[tree[0][reNode]] <= data[tree[0][reNode]][reNode] ? tree[2][reNode] : tree[3][reNode]; 89 } 90 } 91 } 92 } 93 } 94 95 if (!res) 96 { 97 res = new int[k]; 98 } 99 100 int i = kNeighbors.size(); 101 while (!kNeighbors.empty()) 102 { 103 res[--i] = kNeighbors.top().node; 104 kNeighbors.pop(); 105 } 106 107 return 0; 108 }View Code
數組索引的kdtree建立及簡明快速的k近鄰搜索方法