K-D樹 C++實現
阿新 • • 發佈:2019-01-07
K-D樹主要是為了實現機器學習演算法中的K近鄰演算法。單純的K-D樹只能實現最近鄰搜尋,但是結合優先佇列就可以實現K近鄰了。這裡只是把K-D樹簡單地實現了一下,經過簡單測試,暫時沒有發現重大bug。
#ifndef KDTREE_H
#define KDTREE_H

#include <cassert>
#include <algorithm>
#include <cstddef>
#include <vector>
#include <cmath>
#include <iostream>
#include <limits>

// Kept at global scope because existing includers may rely on these names.
using ::std::vector;
using ::std::cout;
using ::std::endl;

namespace sx {

typedef float DataType;
typedef unsigned int UInt;

// A point stored in the tree: its coordinates plus a caller-chosen identifier.
struct Feature {
    vector<DataType> data;  // coordinates, one entry per dimension
    int id;                 // identifier returned by the nearest-neighbour queries

    // Default-constructed features carry id -1 instead of an indeterminate value.
    Feature() : id(-1) {}
    Feature(const vector<DataType> & d, int i) : data(d), id(i) {}
} /* optional variable list */;

// K-d tree over K-dimensional features supporting single nearest-neighbour
// lookup by Euclidean distance.
//
// The tree owns its nodes: copying a tree duplicates every node, and the
// destructor releases them.
template <UInt K>
class KDTree {
public:
    KDTree();
    virtual ~KDTree();
    KDTree(const KDTree & rhs);
    const KDTree & operator = (const KDTree & rhs);

    // Releases every node and leaves the tree empty.
    void Clean();

    // Replaces the current contents with a tree built from matrix_feature.
    // Every feature is expected to have exactly K coordinates (checked with
    // assert). An empty vector yields an empty tree.
    void Build(const vector<Feature> & matrix_feature);

    // Returns the id of the stored feature closest to target, or -1 when the
    // tree is empty.
    int FindNearestFeature(const Feature & target) const;

    // Same as above; additionally writes the distance to the returned feature
    // into min_difference. When nothing is found, min_difference holds the
    // largest finite DataType value.
    int FindNearestFeature(const Feature & target, DataType & min_difference) const;

    // Prints the tree to std::cout in pre-order; prints nothing when empty.
    void Show() const;

private:
    struct KDNode {
        KDNode * left;
        KDNode * right;
        Feature feature;
        int depth;  // distance from the root; selects the split axis (depth % K)

        // Initialisers are listed in declaration order.
        KDNode(const Feature & f, KDNode * lt, KDNode * rt, int d)
            : left(lt), right(rt), feature(f), depth(d) {}
    } /* optional variable list */;

    KDNode * root_;

    // Orders features by a single coordinate.
    struct Comparator {
        int index_comparator;
        Comparator(int ix) : index_comparator(ix) {}
        bool operator () (const Feature & lhs, const Feature & rhs) const {
            return lhs.data[index_comparator] < rhs.data[index_comparator];
        }
    } /* optional variable list */;

    // Starting bound for a search: any finite distance beats it.
    static DataType Unbounded() {
        return std::numeric_limits<DataType>::max();
    }

    KDNode * Clone(KDNode * t) const;
    void Clean(KDNode * & t);
    void SortFeature(vector<Feature> & features, int index);
    void Build(const vector<Feature> & matrix_feature, KDNode * & t, int depth);
    DataType Feature2FeatureDifference(const Feature & f1, const Feature & f2) const;
    int FindNearestFeature(const Feature & target, DataType & min_difference, KDNode * t) const;
    void Show(KDNode * t) const;
};

template <UInt K>
KDTree<K>::KDTree() : root_(NULL) {}

template <UInt K>
KDTree<K>::~KDTree() {
    Clean();
}

// Builds the copy directly; root_ is never read before it is initialised.
template <UInt K>
KDTree<K>::KDTree(const KDTree & rhs) : root_(Clone(rhs.root_)) {}

template <UInt K>
const KDTree<K> & KDTree<K>::operator = (const KDTree & rhs) {
    if (this != &rhs) {
        // Copy first so *this is left untouched if an allocation throws.
        KDNode * copy = Clone(rhs.root_);
        Clean();
        root_ = copy;
    }
    return *this;
}

template <UInt K>
void KDTree<K>::Clean() {
    Clean(root_);
}

template <UInt K>
void KDTree<K>::Build(const vector<Feature> & matrix_feature) {
    for (std::size_t i = 0; i < matrix_feature.size(); ++i) {
        assert(matrix_feature[i].data.size() == K);
    }
    Clean();  // release any previously built tree instead of leaking it
    Build(matrix_feature, root_, 0);
}

template <UInt K>
int KDTree<K>::FindNearestFeature(const Feature & target) const {
    DataType min_difference;
    return FindNearestFeature(target, min_difference);
}

template <UInt K>
int KDTree<K>::FindNearestFeature(const Feature & target, DataType & min_difference) const {
    min_difference = Unbounded();
    return FindNearestFeature(target, min_difference, root_);
}

template <UInt K>
void KDTree<K>::Show() const {
    Show(root_);
}

// Returns a deep copy of the subtree rooted at t (NULL for an empty subtree).
template <UInt K>
typename KDTree<K>::KDNode * KDTree<K>::Clone(KDNode * t) const {
    if (NULL == t) {
        return NULL;
    }
    KDNode * left_copy = Clone(t->left);
    KDNode * right_copy = Clone(t->right);
    return new KDNode(t->feature, left_copy, right_copy, t->depth);
}

// Deletes the subtree rooted at t and resets t to NULL.
template <UInt K>
void KDTree<K>::Clean(KDNode * & t) {
    if (t != NULL) {
        Clean(t->left);
        Clean(t->right);
        delete t;
    }
    t = NULL;
}

// Sorts features in ascending order of coordinate `index`.
template <UInt K>
void KDTree<K>::SortFeature(vector<Feature> & features, int index) {
    std::sort(features.begin(), features.end(), Comparator(index));
}

// Recursively builds the subtree for matrix_feature into t. The features are
// sorted along axis depth % K; the element at position size / 2 becomes the
// node, the elements before it form the left subtree and those after it the
// right subtree.
template <UInt K>
void KDTree<K>::Build(const vector<Feature> & matrix_feature, KDNode * & t, int depth) {
    if (matrix_feature.empty()) {
        t = NULL;
        return;
    }
    vector<Feature> sorted_feature(matrix_feature);
    SortFeature(sorted_feature, depth % K);

    const int length = static_cast<int>(sorted_feature.size());
    const int middle_position = length / 2;
    t = new KDNode(sorted_feature[middle_position], NULL, NULL, depth);

    const vector<Feature> left_feature(sorted_feature.begin(),
                                       sorted_feature.begin() + middle_position);
    const vector<Feature> right_feature(sorted_feature.begin() + middle_position + 1,
                                        sorted_feature.end());
    Build(left_feature, t->left, depth + 1);
    Build(right_feature, t->right, depth + 1);
}

// Euclidean distance between two features of equal dimension.
template <UInt K>
DataType KDTree<K>::Feature2FeatureDifference(const Feature & f1, const Feature & f2) const {
    assert(f1.data.size() == f2.data.size());
    DataType diff = 0.0;
    for (std::size_t i = 0; i < f1.data.size(); ++i) {
        const DataType delta = f1.data[i] - f2.data[i];
        diff += delta * delta;
    }
    return std::sqrt(diff);
}

// Searches the subtree rooted at t for a feature strictly closer to target
// than min_difference. On success min_difference is lowered to that distance
// and the feature's id is returned; otherwise min_difference is unchanged and
// -1 is returned.
template <UInt K>
int KDTree<K>::FindNearestFeature(const Feature & target, DataType & min_difference, KDNode * t) const {
    if (NULL == t) {
        return -1;
    }

    int result = -1;
    const DataType diff_here = Feature2FeatureDifference(target, t->feature);
    if (diff_here < min_difference) {
        min_difference = diff_here;
        result = t->feature.id;
    }
    if (NULL == t->left && NULL == t->right) {
        return result;
    }

    const UInt axis = static_cast<UInt>(t->depth) % K;
    const DataType target_coord = target.data[axis];
    const DataType split_coord = t->feature.data[axis];
    const bool target_is_left = target_coord < split_coord;
    KDNode * near_child = target_is_left ? t->left : t->right;
    KDNode * far_child = target_is_left ? t->right : t->left;

    // Improvement is detected through the bound rather than the returned id,
    // so features whose id happens to be -1 are still handled correctly.
    DataType bound_before = min_difference;
    const int near_result = FindNearestFeature(target, min_difference, near_child);
    if (min_difference < bound_before) {
        result = near_result;
    }

    // Every feature on the far side is at least diff_boundary away along the
    // split axis, so that side can only help if the boundary is closer than
    // the best distance found so far.
    const DataType diff_boundary = std::fabs(target_coord - split_coord);
    if (far_child != NULL && diff_boundary < min_difference) {
        bound_before = min_difference;
        const int far_result = FindNearestFeature(target, min_difference, far_child);
        if (min_difference < bound_before) {
            result = far_result;
        }
    }
    return result;
}

// Prints the subtree rooted at t in pre-order; a NULL subtree prints nothing.
template <UInt K>
void KDTree<K>::Show(KDNode * t) const {
    if (NULL == t) {
        return;
    }
    cout << "ID: " << t->feature.id << endl;
    cout << "Data: ";
    for (std::size_t i = 0; i < t->feature.data.size(); ++i) {
        cout << t->feature.data[i] << " ";
    }
    cout << endl;
    if (t->left != NULL) {
        cout << "Left: " << t->feature.id << " -> " << t->left->feature.id << endl;
        Show(t->left);
    }
    if (t->right != NULL) {
        cout << "Right: " << t->feature.id << " -> " << t->right->feature.id << endl;
        Show(t->right);
    }
}

} /* sx */

#endif /* end of include guard: KDTREE_H */
下面是測試的main.cc檔案
// ============================================================================= // // Filename: main.cc // // Description: K-D Tree // // Version: 1.0 // Created: 04/11/2013 04:28:02 PM // Revision: none // Compiler: g++ // // Author: Geek SphinX(Perstare et Praestare),
[email protected] // Organization: Hefei University of Technology // // ============================================================================= #include "KDTree.h" #include <vector> using namespace sx; using namespace std; int main(int argc, char *argv[]) { KDTree<2> kdtree; vector<Feature> vf; int idx = 0; for (int i = 0; i < 5; ++i) { for (int j = 0; j < 5; ++j) { vector<DataType> vd; vd.push_back(i); vd.push_back(j); vf.push_back(Feature(vd, idx++)); } } kdtree.Build(vf); kdtree.Show(); Feature target; int n; DataType x, y; cin >> n; while (n--) { cin >> x >> y; vector<DataType> vd; vd.push_back(x); vd.push_back(y); target = Feature(vd, 0); DataType md; int t = kdtree.FindNearestFeature(target, md); cout << "Result is " << t << endl; for (int i = 0; i < (int)vf[t].data.size(); ++i) { cout << vf[t].data[i] << " "; } cout << endl; cout << "Minimum Difference is " << md << endl; } return 0; }