1. 程式人生 > >k近鄰法的C++實現

k近鄰法的C++實現

#include <iostream>
#include <vector>
#include <algorithm>
#include <string>
#include <cmath>
using namespace std;




struct KdTree{
    vector<double> root;
    KdTree* parent;
    KdTree* leftChild;
    KdTree* rightChild;
    //預設建構函式
    KdTree(){parent = leftChild = rightChild = NULL;}
    //判斷kd樹是否為空
    bool isEmpty()
    {
        return root.empty();
    }
    //判斷kd樹是否只是一個葉子結點
    bool isLeaf()
    {
        return (!root.empty()) &&
            rightChild == NULL && leftChild == NULL;
    }
    //判斷是否是樹的根結點
    bool isRoot()
    {
        return (!isEmpty()) && parent == NULL;
    }
    //判斷該子kd樹的根結點是否是其父kd樹的左結點
    bool isLeft()
    {
        return parent->leftChild->root == root;
    }
    //判斷該子kd樹的根結點是否是其父kd樹的右結點
    bool isRight()
    {
        return parent->rightChild->root == root;
    }
};

int data[6][2] = {{2,3},{5,4},{9,6},{4,7},{8,1},{7,2}};

template<typename T>
vector<vector<T> > Transpose(vector<vector<T> > Matrix)
{
    unsigned row = Matrix.size();
    unsigned col = Matrix[0].size();
    vector<vector<T> > Trans(col,vector<T>(row,0));
    for (unsigned i = 0; i < col; ++i)
    {
        for (unsigned j = 0; j < row; ++j)
        {
            Trans[i][j] = Matrix[j][i];
        }
    }
    return Trans;
}

template <typename T>
T findMiddleValue(vector<T> vec)
{
    sort(vec.begin(),vec.end());
    auto pos = vec.size() / 2;
    return vec[pos];
}


//構建kd樹
void buildKdTree(KdTree* tree, vector<vector<double> > data, unsigned depth)
{

    //樣本的數量
    unsigned samplesNum = data.size();
    //終止條件
    if (samplesNum == 0)
    {
        return;
    }
    if (samplesNum == 1)
    {
        tree->root = data[0];
        return;
    }
    //樣本的維度
    unsigned k = data[0].size();
    vector<vector<double> > transData = Transpose(data);
    //選擇切分屬性
    unsigned splitAttribute = depth % k;
    vector<double> splitAttributeValues = transData[splitAttribute];
    //選擇切分值
    double splitValue = findMiddleValue(splitAttributeValues);
    //cout << "splitValue" << splitValue  << endl;

    // 根據選定的切分屬性和切分值,將資料集分為兩個子集
    vector<vector<double> > subset1;
    vector<vector<double> > subset2;
    for (unsigned i = 0; i < samplesNum; ++i)
    {
        if (splitAttributeValues[i] == splitValue && tree->root.empty())
            tree->root = data[i];
        else
        {
            if (splitAttributeValues[i] < splitValue)
                subset1.push_back(data[i]);
            else
                subset2.push_back(data[i]);
        }
    }

    //子集遞迴呼叫buildKdTree函式

    tree->leftChild = new KdTree;
    tree->leftChild->parent = tree;
    tree->rightChild = new KdTree;
    tree->rightChild->parent = tree;
    buildKdTree(tree->leftChild, subset1, depth + 1);
    buildKdTree(tree->rightChild, subset2, depth + 1);
}

//逐層列印kd樹
void printKdTree(KdTree *tree, unsigned depth)
{
    for (unsigned i = 0; i < depth; ++i)
        cout << "\t";
            
    for (vector<double>::size_type j = 0; j < tree->root.size(); ++j)
        cout << tree->root[j] << ",";
    cout << endl;
    if (tree->leftChild == NULL && tree->rightChild == NULL )//葉子節點
        return;
    else //非葉子節點
    {
        if (tree->leftChild != NULL)
        {
            for (unsigned i = 0; i < depth + 1; ++i)
                cout << "\t";
            cout << " left:";
            printKdTree(tree->leftChild, depth + 1);
        }
            
        cout << endl;
        if (tree->rightChild != NULL)
        {
            for (unsigned i = 0; i < depth + 1; ++i)
                cout << "\t";
            cout << "right:";
            printKdTree(tree->rightChild, depth + 1);
        }
        cout << endl;
    }
}


//計算空間中兩個點的距離
double measureDistance(vector<double> point1, vector<double> point2, unsigned method)
{
    if (point1.size() != point2.size())
    {
        cerr << "Dimensions don't match!!" ;
        exit(1);
    }
    switch (method)
    {
        case 0://歐氏距離
            {
                double res = 0;
                for (vector<double>::size_type i = 0; i < point1.size(); ++i)
                {
                    res += pow((point1[i] - point2[i]), 2);
                }
                return sqrt(res);
            }
        case 1://曼哈頓距離
            {
                double res = 0;
                for (vector<double>::size_type i = 0; i < point1.size(); ++i)
                {
                    res += abs(point1[i] - point2[i]);
                }
                return res;
            }
        default:
            {
                cerr << "Invalid method!!" << endl;
                return -1;
            }
    }
}
//在kd樹tree中搜索目標點goal的最近鄰
//輸入:目標點;已構造的kd樹
//輸出:目標點的最近鄰
vector<double> searchNearestNeighbor(vector<double> goal, KdTree *tree)
{
    /*第一步:在kd樹中找出包含目標點的葉子結點:從根結點出發,
    遞迴的向下訪問kd樹,若目標點的當前維的座標小於切分點的
    座標,則移動到左子結點,否則移動到右子結點,直到子結點為
    葉結點為止,以此葉子結點為“當前最近點”
    */
    unsigned k = tree->root.size();//計算出資料的維數
    unsigned d = 0;//維度初始化為0,即從第1維開始
    KdTree* currentTree = tree;
    vector<double> currentNearest = currentTree->root;
    while(!currentTree->isLeaf())
    {
        unsigned index = d % k;//計算當前維
        if (currentTree->rightChild->isEmpty() || goal[index] < currentNearest[index])
        {
            currentTree = currentTree->leftChild;
        }
        else
        {
            currentTree = currentTree->rightChild;
        }
        ++d;
    }
    currentNearest = currentTree->root;

    /*第二步:遞迴地向上回退, 在每個結點進行如下操作:
    (a)如果該結點儲存的例項比當前最近點距離目標點更近,則以該例點為“當前最近點”
    (b)當前最近點一定存在於某結點一個子結點對應的區域,檢查該子結點的父結點的另
    一子結點對應區域是否有更近的點(即檢查另一子結點對應的區域是否與以目標點為球
    心、以目標點與“當前最近點”間的距離為半徑的球體相交);如果相交,可能在另一
    個子結點對應的區域記憶體在距目標點更近的點,移動到另一個子結點,接著遞迴進行最
    近鄰搜尋;如果不相交,向上回退*/

    //當前最近鄰與目標點的距離
    double currentDistance = measureDistance(goal, currentNearest, 0);

    //如果當前子kd樹的根結點是其父結點的左孩子,則搜尋其父結點的右孩子結點所代表
    //的區域,反之亦反
    KdTree* searchDistrict;
    if (currentTree->isLeft())
    {
        if (currentTree->parent->rightChild == NULL)
            searchDistrict = currentTree;
        else
            searchDistrict = currentTree->parent->rightChild;
    }
    else
    {
        searchDistrict = currentTree->parent->leftChild;
    }

    //如果搜尋區域對應的子kd樹的根結點不是整個kd樹的根結點,繼續回退搜尋
    while (searchDistrict->parent != NULL)
    {
        //搜尋區域與目標點的最近距離
        double districtDistance = abs(goal[(d+1)%k] - searchDistrict->parent->root[(d+1)%k]);

        //如果“搜尋區域與目標點的最近距離”比“當前最近鄰與目標點的距離”短,表明搜尋
        //區域內可能存在距離目標點更近的點
        if (districtDistance < currentDistance )//&& !searchDistrict->isEmpty()
        {

            double parentDistance = measureDistance(goal, searchDistrict->parent->root, 0);

            if (parentDistance < currentDistance)
            {
                currentDistance = parentDistance;
                currentTree = searchDistrict->parent;
                currentNearest = currentTree->root;
            }
            if (!searchDistrict->isEmpty())
            {
                double rootDistance = measureDistance(goal, searchDistrict->root, 0);
                if (rootDistance < currentDistance)
                {
                    currentDistance = rootDistance;
                    currentTree = searchDistrict;
                    currentNearest = currentTree->root;
                }
            }
            if (searchDistrict->leftChild != NULL)
            {
                double leftDistance = measureDistance(goal, searchDistrict->leftChild->root, 0);
                if (leftDistance < currentDistance)
                {
                    currentDistance = leftDistance;
                    currentTree = searchDistrict;
                    currentNearest = currentTree->root;
                }
            }
            if (searchDistrict->rightChild != NULL)
            {
                double rightDistance = measureDistance(goal, searchDistrict->rightChild->root, 0);
                if (rightDistance < currentDistance)
                {
                    currentDistance = rightDistance;
                    currentTree = searchDistrict;
                    currentNearest = currentTree->root;
                }
            }
        }//end if

        if (searchDistrict->parent->parent != NULL)
        {
            searchDistrict = searchDistrict->parent->isLeft()?
                            searchDistrict->parent->parent->rightChild:
                            searchDistrict->parent->parent->leftChild;
        }
        else
        {
            searchDistrict = searchDistrict->parent;
        }
        ++d;
    }//end while
    return currentNearest;
}

int main()
{
    vector<vector<double> > train(6, vector<double>(2, 0));
    for (unsigned i = 0; i < 6; ++i)
        for (unsigned j = 0; j < 2; ++j)
            train[i][j] = data[i][j];

    KdTree* kdTree = new KdTree;
    buildKdTree(kdTree, train, 0);

    printKdTree(kdTree, 0);

    vector<double> goal;
    goal.push_back(3);
    goal.push_back(4.5);
    vector<double> nearestNeighbor = searchNearestNeighbor(goal, kdTree);
    vector<double>::iterator beg = nearestNeighbor.begin();
    cout << "The nearest neighbor is: ";
    while(beg != nearestNeighbor.end()) cout << *beg++ << ",";
    cout << endl;
    return 0;
}

相關推薦

統計學習方法筆記(一):K近鄰實現:kd樹

  實現k近鄰演算法時,首要考慮的問題是如何對訓練資料進行快速的k近鄰搜尋。這點在特徵空間的維數大於訓練資料容量時尤為重要。 構造kd樹   kd 樹是一種對k為空間中的例項點進行儲存的一邊對其進行快速檢索的樹形資料結構。kd樹是二叉樹,表示對k維空間的一個劃分(parti

機器學習基礎(四十三)—— kd 樹( k 近鄰實現

實現 k 近鄰法時,主要考慮的問題是如何對訓練資料進行快速 k 近鄰搜尋,這點在如下的兩種情況時,顯得尤為必要: (1)特徵空間的維度大 (2)訓練資料的容量很大時 k 近鄰法的最簡單的實現是現行掃描(linear scan),這時需計算輸入例項與每一個

k近鄰C++實現

#include <iostream> #include <vector> #include <algorithm> #include <string> #include <cmath> using namespace std; struct KdT

統計學習三:2.K近鄰代碼實現(以最近鄰為例)

數據集 learning pytho port 4.3 @property 存儲 uil github 通過上文可知感知機模型的基本原理,以及算法的具體流程。本文實現了感知機模型算法的原始形式,通過對算法的具體實現,我們可以對算法有進一步的了解。具體代碼可以在我的githu

機器學習系列:k 近鄰k-NN)的原理及實現

  本內容將介紹機器學習中的 k k k 近鄰法(

統計學習方法c++實現之二 k近鄰

統計學習方法c++實現之二 k近鄰演算法 前言 k近鄰演算法可以說概念上很簡單,即:“給定一個訓練資料集,對新的輸入例項,在訓練資料集中找到與這個例項最鄰近的k個例項,這k個例項的多數屬於某個類,就把該輸入分為這個類。”其中我認為距離度量最關鍵,但是距離度量的方法也很簡單,最長用的就是歐氏距離,其他的距離

k近鄰:R實現(一)

KNN是有監督的學習演算法,其特點有: 1、精度高,對異常值不敏感 2、只能處理數值型屬性 3、計算複雜度高(如已知分類的樣本數為n,那麼對每個未知分類點要計算n個距離) KNN演算法步驟: 需對所有樣本點(已知分類+未知分類)進行歸一化處理。 然後,對未知分類的資料

K近鄰之kd樹及其Python實現

作為機器學習中一種基本的分類方法,K近鄰(KNN)法是一種相對簡單的方法。其中一個理由是K近鄰法不需要對訓練集進行學習。然而,不需要對訓練集進行學習,反過來也會造成對測試集進行判定時,計算與空間複雜度的增加。 K近鄰法最簡單的實現方法是對需要分類的目標點,計算出訓練集中每一

AVLTree的實現(C++實現)

pen nod util ron bool allocator cti tor utili #include<stack>#include<utility>#include<allocators>#include<functiona

K近鄰

數據集 量化 學習過程 要求 過程 nbsp k近鄰 實例 數據   K近鄰法是機器學習所有算法中理論最簡單,最好理解的算法。它是一種基本的分類與回歸方法,它的輸入為實例的特征向量,通過計算新數據與訓練數據特征值之間的距離,然後選取K(K>=1)個距離最近的鄰居進行分

基於私鑰加密公鑰解密的RSA算C#實現方法

第一個 inter tro 十進制 函數 軟件 產生 ++ 原創 本文實例講述了基於私鑰加密公鑰解密的RSA算法C#實現方法,是一種應用十分廣泛的算法。分享給大家供大家參考之用。具體方法如下: 一、概述 RSA算法是第一個能同時用於加密和數字簽名的算法,也易於理解和操

[劍指offer] 最小的K個數,C++實現

urn bubuko 存儲 best 9.png clas master 代碼 wan 原創博文,轉載請註明出處!github地址# 題目 輸入n個整數,找出其中最小的K個數。例如輸入4,5,1,6,2,7,3,8這8個數字,則最小的4個數字是1,2,3,4# 思

《統計學習方法》筆記三 k近鄰

學習 屬於 基本 mage 容易 向量 規則 統計學 圖片 k近鄰是一種基本分類與回歸方法,書中只討論分類情況。輸入為實例的特征向量,輸出為實例的類別。k值的選擇、距離度量及分類決策規則是k近鄰法的三個基本要素。 k近鄰算法 給定一個訓練數據集,對新的輸入實例,在訓練數

機器學習實戰——k-近鄰演算法Python實現問題記錄

  準備 kNN.py 的python模組 from numpy import * import operator def createDataSet(): group = array([[1.0,1.1],[1.0,1.0],[0,0],[0,0.1]])

第三章k近鄰(接上篇)

3.3k近鄰法的實現:kd樹 3.3.1構造kd樹, (1)構造跟節點,以訓練集T中的一維度的中位點作為切分點,將超矩形區域劃分為兩部分, (2)重複:對深度為j的節點選擇切分座標的中位值, (3)直到子區域沒有例項存在為止,從而形成kd樹的劃分 3.3.2搜尋kd樹 用kd樹進行最近鄰

【統計學習方法-李航-筆記總結】三、k近鄰

本文是李航老師《統計學習方法》第三章的筆記,歡迎大佬巨佬們交流。 主要參考部落格:https://blog.csdn.net/u013358387/article/details/53327110 主要包括以下幾部分: 1. k近鄰演算法 2. k近鄰模型 3. kd樹 1.

資料結構與演算法之列舉(窮舉) C++實現

列舉法的本質就是從所有候選答案中去搜索正確的解,使用該演算法需要滿足兩個條件: 1、可以先確定候選答案的數量; 2、候選答案的範圍在求解之前必須是一個確定的集合。 列舉是最簡單,最基礎,也是最沒效率的演算法 列舉法優點: 1、列舉有超級無敵準確性,只要時間足夠,正確的列舉得出的結

列主元Gauss消去(C++實現)

列主元Gauss消去法(C++) 目的:編寫解n階線性方程組AX=b的列主元三角分解法的通用程式; 原理:列主元素消去法是為控制舍入誤差而提出來的一種演算法,列主元素消去法計算基本上能控制舍入誤差的影響,其基本思想是:在進行第 k(k=1,2,...,n-1)步消元時,從第k列的 akk及其

K近鄰(KNN)原理小結

tel .get ack 索引 觀察 運用 oob import port    一、緒論    K近鄰法(k-nearest neighbors,KNN)是一種很基本的機器學習方法了,在我們平常的生活中也會不自主的應用。比如,我們判斷一個人的人品,只需要觀察他來往最密切的

K近鄰-k-nearest neighbor,KNN

WIKI In pattern recognition, the k-nearest neighbors algorithm (k-NN) is a non-parametric method used for classification and regression.[