1. 程式人生 > >KD-Tree(C++實現)

KD-Tree(C++實現)

參考資料:

https://blog.csdn.net/dymodi/article/details/46830071

https://github.com/WiseDoge/libkdtree

作為存取高維資料的一種資料結構,k-d tree 在靜態查詢和插入方面的效率還是很高的。本文在這裡對 k-d tree 的內容作一些介紹,可能也會結合自己使用 k-d tree 的一些體驗作一些點評。其實,k-d tree 是早在1975年的時候由 Stanford 的 Bentley 提出來的。本文的內容也主要來自於他的兩篇最原始的文章 [Ben75] 和 [FBF77] 。

k-d tree 概述 與 插入操作(Insertion)
首先,k-d tree 也是二叉搜尋樹的一種,與常見的平衡二叉搜尋樹(BST)不同的是,在 k-d tree 中,每個節點記憶體儲的都是一條記錄(record),或者說是多維空間中的一個點,用一個向量來表示。而且在 k-d tree中,這個點也代表了空間中的一個區域。每個節點都有兩個子節點,而且兩個子節點各自代表的區域是父節點的區域一個劃分。

在一維的情形中,每條 record 都是由一個單獨的 key 來表示的。因此,對於 k-d tree 中的每個節點,key 值小於或者等於當前節點的 key 值的點就屬於左子樹,比當前節點 key 值大的就屬於右子樹。因此,這裡的 key 值就成為了一種鑑別器(discriminator)。而在 k 維的情況中,一條 record 是由 k 個 key 值來表示的,這裡每一維的 key 值都可以作為 discriminator 來將一個點向某個節點的左右子樹來分類。而在 k-d tree 中,discriminator 的選取是和該節點所在的層數有關的,即在根節點處,即第0層,按照第一維的 key 值來進行分類,第一維的 key 值小於等於根節點的第一維的 key 值的屬於根節點的左子樹,大於根節點的第一維的 key 值的屬於根節點的右子樹。然後在根節點的左右子節點的位置上,即第一層的位置上,根據第二維的 key 值來區分,以此類推。即第 k 層要比較的 key 值的維數為 D=L mod k+1D=L mod k+1 。其中L是當前節點所在的層數,其中根節點即為第0層。

按照 k-d tree 的規則依次插入(0,0), (-10, 10), (10, -10), (-40, -20), (-20, 11), (20, 0)這幾個點,我們可以得到如下左圖所示的 k-d tree,右圖是這幾個點在平面的示意圖。其中藍線表示該點處是以第一維的 key 值進行區分,紅線表示該點處是以第二維的 key 值進行區分。 


同時我們還可以看出,k-d tree 中每一個節點其實也代表了k維空間中的一個區域(region)。我們以上述幾個二維空間中的點為例。根節點 (0,0) 代表的是全平面,即 (-50, -50, 50, 50) 這樣一個區域,這裡的區域我們用 (xmin,ymin,xmax,ymax)(xmin,ymin,xmax,ymax) 來表示,因為根節點 (0,0) 是在第一維,即 xx 軸出進行區分的,因此它的左子節點就代表了左半平面,右子節點就代表了右半平面。即點 (-10, 10) 代表的是 (-50, -50, 0, 50) 這樣一個區域,點 (10, -10) 代表的是 (0, -50, 50, 50) 這樣一個區域。以此類推,在點 (-10, 10) 處,因為是第一層,因此按照第二維來區分,所以點 (-40, -20) 的第二維比點 (-10, 10) 小,就在左面;點 (-20, 11) 的第二維比點 (-10, 10) 大,就在右面。而且,左面的點 (-40, -20) 代表的是它的父節點的下半平面,即 (-50, -50, 0, 10) 這樣一個區域;右面的點 (-20, 11) 代表的是它的父節點的上半平面,即 (-50, 10, 0, 50) 這樣一個區域。

查詢操作(Searching)
上面我們介紹了 k-d tree 的原理和插入節點的過程,現在我們介紹下搜尋節點的過程。在 k-d tree 中對點進行搜尋的方法有很多。包括:(1)對所有維度進行匹配的特定點查詢(精確匹配);(2)對部分維度進行匹配的查詢;(3)對某個特定的區域內的點進行進行查詢;(4)查詢與特定點距離最近的幾個點。

上面的幾種搜尋演算法都在 [Ben75] 和 [FBF77] 兩篇文章中有詳細介紹,在這裡我們主要介紹(3),也就是我自己用過的區域查詢(Region Query)。區域查詢的目標是,在 k-d tree 所代表的空間內,如上面例子中提到的二維平面中的 (-50, -50, 50, 50) 這樣一個區域,給定一個矩形的區域(即在各個維度上給出這個區域的上下界),如在上面的例子中我們可以給定 (-45, -30, -30, -10) 這樣一個區域,查詢所有落在這個區域內的點。

區域查詢的主要方法如下:從根節點開始,考察該節點的 key 值所代表的點是否在待查詢的區域內,如果在待查區域內,就將這個節點放入一個全域性的列表中;在這之後,分別考察該節點的左右子節點所代表區域與待查詢的區域是否有交集,如果有,就遞迴地以該子節點作為根節點,進行上述操作,如果沒有就返回。在所有遞迴函式執行完後,我們可以得到一個全域性的列表,這個列表裡儲存的都是落在待查詢區域內的點。從上面的闡述中可以看出,查詢演算法的複雜度與待查的區域大小有很大關係。雖然根據 [LW77] 的結論,最差情況下區域查詢的複雜度會達到 O(kN1−1k)O(kN1−1k),其中 kk 是資料點的維度, NN 是 k-d tree 內節點的總個數,但 [Ben75, FB74] 的大量模擬都表明在進行超矩形(hyper-rectangular)區域的搜尋時,k-d tree 上的區域搜尋的表現相當不錯(reasonably well)。
 

刪除操作(Deletion)
其實,k-d tree 對刪除操作的支援並不很好,因為 k-d tree 本身不具備平衡性,動態進行的插入和刪除操作可能使得 k-d tree 退化成一個線性表。實際上也有關於平衡 k-d tree 的研究,如 [Rob81]。但可能是因為實現起來太複雜的原因,K-D-B tree 似乎沒有得到很多應用。

下面我們主要講一下刪除操作,對 k-d tree 內的節點進行刪除的原則是,對於一個沒有後繼結點的外部節點,刪除操作可以直接進行;對於有後繼結點的內部節點 PP,要做的就是從它的子節點中找到一個合適的節點 QQ 來放置到這個需要被刪除的節點的位置上。而所謂合適的節點,就是說如果 PP 節點是在第 JJ 個維度上進行分界的,那麼 QQ 就是 PP 的左子樹中 JJ 維上最大的節點,或者是 PP 的右子樹中 JJ 維上最小的節點,二者均可以。要將 QQ 節點替換到 PP 節點的位置上去,需要先將 QQ 節點從它原來的位置上刪除,因此上面所述的刪除操作也是一個遞迴實現的過程。

我自己寫的刪除操作因為要結合自己其他的應用,因此寫得很冗長,就不在這裡放出來了。

優化操作(Optimize)
優化操作是 k-d tree 的一種離線操作。我們都知道,當二叉樹隨著插入操作的進行,如果無法保證樹的平衡性,那麼在二叉樹上進行操作的複雜度會逐漸變差,極端情況下二叉樹會退化成為一個線性表。針對一個不平衡的 k-d tree,可以通過優化的操作來使其恢復平衡,以保證後續查詢操作的效率。

所謂優化操作,其實就是按照維度的次序,分別將節點進行排序。比如,對於一個需要優化的 k-d tree,對其所有的節點按照第一維度元素進行升序排序,然後最中間的一個作為根節點,然後左半部分的節點作為主子樹的節點,有半部分的節點作為又子樹的節點。然後分別對左右子樹的節點進行上述的處理,只不過參考的維度分別為第二維,第三維……

經過上面的處理,可以使得一個任意的 k-d tree 成為平衡的 k-d tree。

在這裡我們對 k-d tree 的內容進行一個小結,針對已有的 N 個數據點,每個點由一個 k 維的資料表徵,建立一個 k-d tree 的複雜度為 O(NlogN)O(Nlog⁡N),對已有的 k-d tree 進行優化的複雜度為 O(NlogN)O(NlogN),插入一個節點的複雜度為 O(logN)O(log⁡N),刪除一個節點的複雜度為 O(logN)O(log⁡N),,進行精確匹配的複雜度為 O(logN)O(log⁡N) ,查詢一個特定的區域的最差情況的複雜度為 O(kN1−1k)O(kN1−1k),但區域查詢的複雜度與區域大小有關,而且平均意義下的效果不錯。

 

程式碼如下:

kdtree.h

#ifndef KDTREE_H
#define KDTREE_H

// set dynamic link library
#if defined(_MSC_VER)
#define DLLExport __declspec(dllexport)
#else
#define DLLExport
#endif

// set c++
#ifdef __cplusplus
extern "C" {
#endif

#include <stdio.h>

struct DLLExport tree_node
{
    size_t id;
    size_t split;
    tree_node *left, *right;
};

struct DLLExport tree_model
{
    tree_node *root;
    const float *datas;
    const float *labels;
    size_t n_samples;
    size_t n_features;
    float p;
};

DLLExport void free_tree_memory(tree_node *root);
DLLExport tree_model* build_kdtree(const float *datas, const float *labels,
                                   size_t rows, size_t cols, float p);
DLLExport float* k_nearests_neighbor(const tree_model *model, const float *X_test,
                                     size_t len, size_t k, bool clf);
DLLExport void find_k_nearests(const tree_model *model, const float *coor,
                               size_t k, size_t *args, float *dists);

#ifdef __cplusplus
}
#endif


#endif

kdtree.cpp

#include "kdtree.h"

#include <algorithm>
#include <vector>
#include <cmath>
#include <tuple>
#include <unordered_map>
#include <stack>
#include <queue>
#include <cstring>
#include <cassert>
#include <cstdlib>

// Example:
//     int x = Malloc(int, 10);
//     int y = (int *)malloc(10 * sizeof(int));
#define Malloc(type, n) (type *)malloc((n)*sizeof(type))

// If you need to use Intel MKL to accelerate,
// you can cancel the next line comment.

//#define USE_INTEL_MKL


#ifdef USE_INTEL_MKL
#include <mkl.h>
#endif

// Clang does not support OpenMP.
#ifndef __clang__

#include <omp.h>

#endif

// 釋放一顆二叉樹記憶體的非遞迴演算法
DLLExport void free_tree_memory(tree_node *root) {
    std::stack<tree_node *> node_stack;
    tree_node *p;
    node_stack.push(root);
    while (!node_stack.empty()) {
        p = node_stack.top();
        node_stack.pop();
        if (p->left)
            node_stack.push(p->left);
        if (p->right)
            node_stack.push(p->right);
        free(p);
    }
}


class KDTree {
public:
    KDTree(){}

    KDTree(tree_node *root, const float *datas, size_t rows, size_t cols, float p);

    KDTree(const float *datas, const float *labels,
           size_t rows, size_t cols, float p, bool free_tree = true);

    ~KDTree();

    tree_node *GetRoot() { return root; }

    std::vector<std::tuple<size_t, float>> FindKNearests(const float *coor, size_t k);

    std::tuple<size_t, float> FindNearest(const float *coor, size_t k) { return FindKNearests(coor, k)[0]; }

    void CFindKNearests(const float *coor, size_t k, size_t *args, float *dists);


private:
    // The sample with the largest distance from point `coor`
    // is always at the top of the heap.
    struct neighbor_heap_cmp {
        bool operator()(const std::tuple<size_t, float> &i,
                        const std::tuple<size_t, float> &j) {
            return std::get<1>(i) < std::get<1>(j);
        }
    };

    typedef std::tuple<size_t, float> neighbor;
    typedef std::priority_queue<neighbor,
            std::vector<neighbor>, neighbor_heap_cmp> neighbor_heap;

    // 搜尋 K-近鄰時的堆(大頂堆),堆頂始終是 K-近鄰中樣本點最遠的點
    neighbor_heap k_neighbor_heap_;
    // 求距離時的 p, dist(x, y) = pow((x^p + y^p), 1/p)
    float p;
    // 析構時是否釋放樹的記憶體
    bool free_tree_;
    // 樹根結點
    tree_node *root;
    // 訓練集
    const float *datas;
    // 訓練集的樣本數
    size_t n_samples;
    // 每個樣本的維度
    size_t n_features;
    // 訓練集的標籤
    const float *labels;
    // 尋找中位數時用到的快取池
    std::tuple<size_t, float> *get_mid_buf_;
    // 搜尋 K 近鄰時的快取池,如果已經搜尋過點 i,令 visited_buf[i] = True
    bool *visited_buf_;

#ifdef USE_INTEL_MKL
    // 使用 Intel MKL 庫時的快取
    float *mkl_buf_;
#endif


    // 初始化快取
    void InitBuffer();

    // 建樹
    tree_node *BuildTree(const std::vector<size_t> &points);

    // 求一組數的中位數
    std::tuple<size_t, float> MidElement(const std::vector<size_t> &points, size_t dim);

    // 入堆
    void HeapStackPush(std::stack<tree_node *> &paths, tree_node *node, const float *coor, size_t k);

    // 獲取訓練集中第 sample 個樣本點第 dim 的值
    float GetDimVal(size_t sample, size_t dim) {
        return datas[sample * n_features + dim];
    }

    // 求點 coor 距離訓練集第 i 個點的距離
    float GetDist(size_t i, const float *coor);

    // 尋找切分點
    size_t FindSplitDim(const std::vector<size_t> &points);

};

// 找到一棵樹的 K近鄰。Ki 的 id 和  Ki 與 coor 之間的距離 分別儲存在   args 和 dists 中
DLLExport
void find_k_nearests(const tree_model *model, const float *coor,
                     size_t k, size_t *args, float *dists) {
    KDTree tree(model->root, model->datas, model->n_samples, model->n_features, model->p);
    std::vector<std::tuple<size_t, float>> k_nearest = tree.FindKNearests(coor, k);
    for (size_t i = 0; i < k; ++i) {
        args[i] = std::get<0>(k_nearest[i]);
        dists[i] = std::get<1>(k_nearest[i]);
    }
}

// 建立一棵 KD-Tree
DLLExport
tree_model *build_kdtree(const float *datas, const float *labels,
                         size_t rows, size_t cols, float p) {
    KDTree tree(datas, labels, rows, cols, p, false);
    tree_model *model = Malloc(tree_model, 1);
    model->datas = datas;
    model->labels = labels;
    model->n_features = cols;
    model->n_samples = rows;
    model->root = tree.GetRoot();
    model->p = p;
    return model;
}

// 求平均值,用於迴歸問題
float mean(const float *arr, size_t len) {
    float ans = 0.0;
    for (size_t i = 0; i < len; ++i)
        ans += arr[i];
    return ans / len;
}

// 投票,用於分類問題
float vote(const float *arr, size_t len) {
    std::unordered_map<int, size_t> counter;
    for (size_t i = 0; i < len; ++i) {
        auto t = static_cast<int>(arr[i]);
        if (counter.find(t) == counter.end())
            counter.insert(std::unordered_map<int, size_t>::value_type(t, 1));
        else
            counter[t] += 1;
    }
    float cur_arg_max = 0;
    size_t cur_max = 0;
    for (auto &i : counter) {
        if (i.second >= cur_max) {
            cur_arg_max = static_cast<float>(i.first);
            cur_max = i.second;
        }
    }
    return cur_arg_max;
}

DLLExport float *
k_nearests_neighbor(const tree_model *model, const float *X_test, size_t len, size_t k, bool clf) {
    float *ans = Malloc(float, len);
    size_t *args;
    float *dists, *y_pred;
    size_t arr_len;
    int i, j;

#ifdef USE_INTEL_MKL
    int n_procs = omp_get_num_procs();
    assert(n_procs < 100);
    KDTree *trees[100];
    for (size_t i = 0; i < n_procs; ++i)
        trees[i] = new KDTree(model->root, model->datas, model->n_samples, model->n_features, model->p);
    arr_len = k * n_procs;
#else
    arr_len = k;
    KDTree tree(model->root, model->datas, model->n_samples, model->n_features, model->p);
#endif

    args = Malloc(size_t, arr_len);
    dists = Malloc(float, arr_len);
    y_pred = Malloc(float, arr_len);

#ifdef USE_INTEL_MKL
#pragma omp parallel for
    for (i = 0; i < len; ++i)
    {
        int thread_num = omp_get_thread_num();
        trees[thread_num]->CFindKNearests(X_test + i * model->n_features,
            k, args + k * thread_num, dists + k * thread_num);
        for (j = 0; j < k; ++j)
            y_pred[j + k * thread_num] = model->labels[args[j + k * thread_num]];
        if (clf)
            ans[i] = vote(y_pred + k * thread_num, k);
        else
            ans[i] = mean(y_pred + k * thread_num, k);
    }
    for (size_t i = 0; i < n_procs; ++i)
        delete trees[i];

#else
    for (i = 0; i < len; ++i) {
        tree.CFindKNearests(X_test + i * model->n_features, k, args, dists);
        for (j = 0; j < k; ++j)
            y_pred[j] = model->labels[args[j]];
        if (clf)
            ans[i] = vote(y_pred, k);
        else
            ans[i] = mean(y_pred, k);
    }
#endif
    free(args);
    free(y_pred);
    free(dists);
    return ans;
}


inline KDTree::KDTree(tree_node *root, const float *datas, size_t rows, size_t cols, float p) :
        root(root), datas(datas), n_samples(rows),
        n_features(cols), p(p), free_tree_(false) {
    InitBuffer();
    labels = nullptr;
}

inline KDTree::KDTree(const float *datas, const float *labels, size_t rows, size_t cols, float p, bool free_tree) :
        datas(datas), labels(labels), n_samples(rows), n_features(cols), p(p), free_tree_(free_tree) {
    std::vector<size_t> points;
    for (size_t i = 0; i < n_samples; ++i)
        points.emplace_back(i);
    InitBuffer();
    root = BuildTree(points);
}

inline KDTree::~KDTree() {
    delete[]get_mid_buf_;
    delete[]visited_buf_;
#ifdef USE_INTEL_MKL
    free(mkl_buf_);
#endif
    if (free_tree_)
        free_tree_memory(root);
}

std::vector<std::tuple<size_t, float>> KDTree::FindKNearests(const float *coor, size_t k) {
    std::memset(visited_buf_, 0, sizeof(bool) * n_samples);
    std::stack<tree_node *> paths;
    tree_node *p = root;

    while (p) {
        HeapStackPush(paths, p, coor, k);
        p = coor[p->split] <= GetDimVal(p->id, p->split) ? p = p->left : p = p->right;
    }
    while (!paths.empty()) {
        p = paths.top();
        paths.pop();

        if (!p->left && !p->right)
            continue;

        if (k_neighbor_heap_.size() < k) {
            if (p->left)
                HeapStackPush(paths, p->left, coor, k);
            if (p->right)
                HeapStackPush(paths, p->right, coor, k);
        } else {
            float node_split_val = GetDimVal(p->id, p->split);
            float coor_split_val = coor[p->split];
            float heap_top_val = std::get<1>(k_neighbor_heap_.top());
            if (coor_split_val > node_split_val) {
                if (p->right)
                    HeapStackPush(paths, p->right, coor, k);

                if ((coor_split_val - node_split_val) < heap_top_val && p->left)
                    HeapStackPush(paths, p->left, coor, k);
            } else {
                if (p->left)
                    HeapStackPush(paths, p->left, coor, k);
                if ((node_split_val - coor_split_val) < heap_top_val && p->right)
                    HeapStackPush(paths, p->right, coor, k);
            }
        }
    }
    std::vector<std::tuple<size_t, float>> res;

    while (!k_neighbor_heap_.empty()) {
        res.emplace_back(k_neighbor_heap_.top());
        k_neighbor_heap_.pop();
    }
    return res;
}

void KDTree::CFindKNearests(const float *coor, size_t k, size_t *args, float *dists) {
    std::vector<std::tuple<size_t, float>> k_nearest = FindKNearests(coor, k);
    for (size_t i = 0; i < k; ++i) {
        args[i] = std::get<0>(k_nearest[i]);
        dists[i] = std::get<1>(k_nearest[i]);
    }
}


// 初始化快取

inline void KDTree::InitBuffer() {
    get_mid_buf_ = new std::tuple<size_t, float>[n_samples];
    visited_buf_ = new bool[n_samples];

#ifdef USE_INTEL_MKL
    // 要與 C 程式碼互動,所以用 C 的方式申請記憶體
    mkl_buf_ = Malloc(float, n_features);
#endif
}

tree_node *KDTree::BuildTree(const std::vector<size_t> &points) {
    size_t dim = FindSplitDim(points);
    std::tuple<size_t, float> t = MidElement(points, dim);
    size_t arg_mid_val = std::get<0>(t);
    float mid_val = std::get<1>(t);

    tree_node *node = Malloc(tree_node, 1);
    node->left = nullptr;
    node->right = nullptr;
    node->id = arg_mid_val;
    node->split = dim;
    std::vector<size_t> left, right;
    for (auto &i : points) {
        if (i == arg_mid_val)
            continue;
        if (GetDimVal(i, dim) <= mid_val)
            left.emplace_back(i);
        else
            right.emplace_back(i);
    }
    if (!left.empty())
        node->left = BuildTree(left);
    if (!right.empty())
        node->right = BuildTree(right);
    return node;
}

std::tuple<size_t, float> KDTree::MidElement(const std::vector<size_t> &points, size_t dim) {
    size_t len = points.size();
    for (size_t i = 0; i < points.size(); ++i)
        get_mid_buf_[i] = std::make_tuple(points[i], GetDimVal(points[i], dim));
    std::nth_element(get_mid_buf_,
                     get_mid_buf_ + len / 2,
                     get_mid_buf_ + len,
                     [](const std::tuple<size_t, float> &i, const std::tuple<size_t, float> &j) {
                         return std::get<1>(i) < std::get<1>(j);
                     });
    return get_mid_buf_[len / 2];
}


inline void KDTree::HeapStackPush(std::stack<tree_node *> &paths, tree_node *node, const float *coor, size_t k) {
    paths.emplace(node);
    size_t id = node->id;
    if (visited_buf_[id])
        return;
    visited_buf_[id] = true;
    float dist = GetDist(id, coor);
    std::tuple<size_t, float> t(id, dist);
    if (k_neighbor_heap_.size() < k)
        k_neighbor_heap_.push(t);

    else if (std::get<1>(t) < std::get<1>(k_neighbor_heap_.top())) {
        k_neighbor_heap_.pop();
        k_neighbor_heap_.push(t);
    }
}

#ifdef USE_INTEL_MKL
inline float KDTree::GetDist(size_t i, const float *coor) {
    size_t idx = i * n_features;
    vsSub(n_features, datas + idx, coor, mkl_buf_);
    vsPowx(n_features, mkl_buf_, p, mkl_buf_);
    float dist = cblas_sasum(n_features, mkl_buf_, 1);
    return static_cast<float>(pow(dist, 1.0 / p));
}
#else

inline float KDTree::GetDist(size_t i, const float *coor) {
    float dist = 0.0;
    size_t idx = i * n_features;
#pragma omp parallel for reduction(+:dist)
    for (int t = 0; t < n_features; ++t)
        dist += pow(datas[idx + t] - coor[t], p);
    return static_cast<float>(pow(dist, 1.0 / p));
}

#endif

size_t KDTree::FindSplitDim(const std::vector<size_t> &points) {
    if (points.size() == 1)
        return 0;
    size_t cur_best_dim = 0;
    float cur_largest_spread = -1;
    float cur_min_val;
    float cur_max_val;
    for (size_t dim = 0; dim < n_features; ++dim) {
        cur_min_val = GetDimVal(points[0], dim);
        cur_max_val = GetDimVal(points[0], dim);
        for (const auto &id : points) {
            if (GetDimVal(id, dim) > cur_max_val)
                cur_max_val = GetDimVal(id, dim);
            else if (GetDimVal(id, dim) < cur_min_val)
                cur_min_val = GetDimVal(id, dim);
        }

        if (cur_max_val - cur_min_val > cur_largest_spread) {
            cur_largest_spread = cur_max_val - cur_min_val;
            cur_best_dim = dim;
        }
    }
    return cur_best_dim;
}

main.cpp

#include "kdtree.h"
#include <stdio.h>
#include <stdlib.h>


int main() {
    float datas[100] = {1.3, 1.3, 1.3,
                         8.3, 8.3, 8.3,
                         2.3, 2.3, 2.3,
                         1.2, 1.2, 1.2,
                         7.3, 7.3, 7.3,
                         9.3, 9.3, 9.3,
                         15, 15, 15,
                         3, 3, 3,
                         1.1, 1.1, 1.1,
                         12, 12, 12,
                         4, 4, 4,
                         5, 5, 5};
    float labels[100];
    for(size_t i = 0; i < 12; ++i)
        labels[i] = (float)i;
    tree_model *model = build_kdtree(datas, labels, 12, 3, 2);
    float test[6] = {3, 3, 3, 4, 4, 4};
    size_t args[100];
    float dists[100];
    find_k_nearests(model, test, 5, args, dists);  // 這裡只搜尋了(3,3,3)的K鄰近點

    printf("K-Nearest: \n");
    for (size_t i = 0; i < 5; ++i) {
        printf("ID %d, Dist %.2f\n", args[i], dists[i]);
    }

    float *ans = k_nearests_neighbor(model, test, 2, 5, false);  // 形參2表示:test中有2個樣本待測
    printf("k Nearest Neighbor Regressor: \n%.2f %.2f\n", ans[0], ans[1]);

//    tree_node *root = model->root;


    free(ans);
    free_tree_memory(model->root);

    return 0;
}

執行結果