1. 程式人生 > >opencv2.4.9中KNN演算法理解

opencv2.4.9中KNN演算法理解

KNN演算法

  • opencv中文版原文描述是:K近鄰可能是最簡單的分類器。訓練資料跟類別標籤放在一起,離測試資料最近的(歐氏距離最近)K個樣本進行投票,確定測試資料的分類結果。這可能是你想到的最賤的方法,該方法比較有效,但是速度比較慢且對記憶體的需求比較大(因為它需要儲存所有訓練集)。
  • 離測試資料是否最近要求計算測試樣本與所有點的距離,可以把這個過程看成搜尋過程,是個求top-K的問題。由於要進行投票,那麼就可以分為帶權投票還是不帶權投票,不帶權投票可能會出現近鄰中不足K個或K個距離和非常大,從而使得分類出錯。另外,引數K該怎麼取應該就很關鍵了,我看到的大部分程式時人為規定引數K值,應該存在自動確定K值的演算法。K近鄰中的K含義是在將測試樣本分到某個類別時要知道其K個近鄰樣本的類別,把測試樣本歸類到樣本佔多數的類別,這說明在分類新樣本之前需要至少有K個樣本且知道其類別,因此可以說KNN是監督學習演算法。
  • Opencv2.4.9中實現的KNN是在CvKNearest類中,繼承於CvStatModel,放在了機器學習模組。既可以用來進行分類,又可以用來做迴歸,並且支援增量學習,可以用新樣本來更新模型。但是opencv的KNN演算法不像決策樹一樣支援變數子集選擇和屬性缺失的情況。
    -使用方法可以採用定義CvKNearest類的物件方法如下:
CvKNearest knn(trainData, trainClasses, 0, false, K);
 ,該方法會在建構函式中執行訓練過程,訓練的過程就是給CvMat *smaple賦予資料的過程,將樣本有序的放在記憶體的前面,label放在記憶體的後面。構造程式碼如下:
CvKNearest::CvKNearest( const CvMat* _train_data, const CvMat* _responses,
                        const CvMat* _sample_idx, bool _is_regression, int _max_k )
{
    samples = 0;
    train( _train_data, _responses, _sample_idx, _is_regression, _max_k, false );
}
  • 來了一個新樣本只需要呼叫response = knn.find_nearest(&sample, K, 0, 0, nearests,
    0);函式就可以得到類別了。

opencv中的實現過程

假如已經將一批帶標籤的資料通過建構函式的方式存入到CvKNearest的成員變數samples所指記憶體區,現在來了一個新的測試樣本,首先計算測試樣本與所有訓練樣本的歐氏距離,呼叫find_neighbors_direct計算所有測試樣本到所有訓練集的距離。

void CvKNearest::find_neighbors_direct( const CvMat* _samples, int k, int start, int end,
                    float* neighbor_responses, const float** neighbors, float* dist ) const
{
    int i, j, count = end - start, k1 = 0, k2 = 0, d = var_count;
    CvVectors* s = samples;

    for( ; s != 0; s = s->next )
    {
        int n = s->count;//訓練集中樣本個數
        for( j = 0; j < n; j++ )
        {
            for( i = 0; i < count; i++ )//count表示測試樣本個數,無標籤
            {
                double sum = 0;
                Cv32suf si;
                const float* v = s->data.fl[j];//指向已帶標籤樣本j
                const float* u = (float*)(_samples->data.ptr + _samples->step*(start + i));//指向測試樣本
                Cv32suf* dd = (Cv32suf*)(dist + i*k);//儲存樣本i的到k個近鄰的距離(Cv32suf的f域)和近鄰的類別(Cv32suf的i域)
                float* nr;
                const float** nn;
                int t, ii, ii1;

                for( t = 0; t <= d - 4; t += 4 )
                {//計算樣本i與樣本j的歐式距離,最小維數為4
                    double t0 = u[t] - v[t], t1 = u[t+1] - v[t+1];
                    double t2 = u[t+2] - v[t+2], t3 = u[t+3] - v[t+3];
                    sum += t0*t0 + t1*t1 + t2*t2 + t3*t3;
                }

                for( ; t < d; t++ )
                {//計算樣本i與樣本j的歐式距離,維數小於4
                    double t0 = u[t] - v[t];
                    sum += t0*t0;
                }

                si.f = (float)sum;
                for( ii = k1-1; ii >= 0; ii-- )//將類別i從小到大排序,插入到ii+1位置
                    if( si.i > dd[ii].i )
                        break;
                if( ii >= k-1 )
                    continue;

                nr = neighbor_responses + i*k;
                nn = neighbors ? neighbors + (start + i)*k : 0;
                for( ii1 = k2 - 1; ii1 > ii; ii1-- )//插入前資料後移
                {
                    dd[ii1+1].i = dd[ii1].i;
                    nr[ii1+1] = nr[ii1];
                    if( nn ) nn[ii1+1] = nn[ii1];
                }
                dd[ii+1].i = si.i;//給ii+1位置的樣本賦予插入的樣本的類別編號,由於是union結構,距離也儲存了
                nr[ii+1] = ((float*)(s + 1))[j];
                if( nn )
                    nn[ii+1] = v;
            }
            k1 = MIN( k1+1, k );
            k2 = MIN( k1, k-1 );
        }
    }
}

得多所有距離後,採用氣泡排序方式得到topK近鄰,並計算樣本數最多的類別。

float CvKNearest::write_results( int k, int k1, int start, int end,
    const float* neighbor_responses, const float* dist,
    CvMat* _results, CvMat* _neighbor_responses,
    CvMat* _dist, Cv32suf* sort_buf ) const
{
    float result = 0.f;
    int i, j, j1, count = end - start;
    double inv_scale = 1./k1;
    int rstep = _results && !CV_IS_MAT_CONT(_results->type) ? _results->step/sizeof(result) : 1;

    for( i = 0; i < count; i++ )//count=1
    {
        const Cv32suf* nr = (const Cv32suf*)(neighbor_responses + i*k);
        float* dst;
        float r;
        if( _results || start+i == 0 )
        {
            if( regression )
            {//不執行
                double s = 0;
                for( j = 0; j < k1; j++ )
                    s += nr[j].f;
                r = (float)(s*inv_scale);
            }
            else
            {
                int prev_start = 0, best_count = 0, cur_count;
                Cv32suf best_val;

                for( j = 0; j < k1; j++ )//複製前K1個數據
                    sort_buf[j].i = nr[j].i;

                for( j = k1-1; j > 0; j-- )//表示排序次數k1-1次,使類別標籤有序
                {
                    bool swap_fl = false;
                    for( j1 = 0; j1 < j; j1++ )//c從前往後,比較淺j個數
                        if( sort_buf[j1].i > sort_buf[j1+1].i )//從小到大排序
                        {
                            int t;
                            CV_SWAP( sort_buf[j1].i, sort_buf[j1+1].i, t );
                            swap_fl = true;
                        }
                    if( !swap_fl )//如果已經有序,則跳出迴圈
                        break;
                }

                best_val.i = 0;//記錄樣本數最多的類別
                for( j = 1; j <= k1; j++ )//氣泡排序k1-1次
                    if( j == k1 || sort_buf[j].i != sort_buf[j-1].i )
                    {//遇到新的類別
                        cur_count = j - prev_start;//類別計算
                        if( best_count < cur_count )
                        {
                            best_count = cur_count;
                            best_val.i = sort_buf[j-1].i;
                        }
                        prev_start = j;
                    }
                r = best_val.f;
            }

            if( start+i == 0 )
                result = r;

            if( _results )
                _results->data.fl[(start + i)*rstep] = r;
        }

        if( _neighbor_responses )
        {
            dst = (float*)(_neighbor_responses->data.ptr +
                (start + i)*_neighbor_responses->step);
            for( j = 0; j < k1; j++ )
                dst[j] = nr[j].f;//從nr到dst
            for( ; j < k; j++ )
                dst[j] = 0.f;
        }

        if( _dist )
        {
            dst = (float*)(_dist->data.ptr + (start + i)*_dist->step);
            for( j = 0; j < k1; j++ )
                dst[j] = dist[j + i*k];
            for( ; j < k; j++ )
                dst[j] = 0.f;
        }
    }

    return result;
}