opencv2.4.9中KNN演算法理解
阿新 • • 發佈:2019-02-06
KNN演算法
- opencv中文版原文描述是:K近鄰可能是最簡單的分類器。訓練資料跟類別標籤放在一起,離測試資料最近的(歐氏距離最近)K個樣本進行投票,確定測試資料的分類結果。這可能是你想到的最賤的方法,該方法比較有效,但是速度比較慢且對記憶體的需求比較大(因為它需要儲存所有訓練集)。
- 離測試資料是否最近要求計算測試樣本與所有點的距離,可以把這個過程看成搜尋過程,是個求top-K的問題。由於要進行投票,那麼就可以分為帶權投票還是不帶權投票,不帶權投票可能會出現近鄰中不足K個或K個距離和非常大,從而使得分類出錯。另外,引數K該怎麼取應該就很關鍵了,我看到的大部分程式時人為規定引數K值,應該存在自動確定K值的演算法。K近鄰中的K含義是在將測試樣本分到某個類別時要知道其K個近鄰樣本的類別,把測試樣本歸類到樣本佔多數的類別,這說明在分類新樣本之前需要至少有K個樣本且知道其類別,因此可以說KNN是監督學習演算法。
- Opencv2.4.9中實現的KNN是在CvKNearest類中,繼承於CvStatModel,放在了機器學習模組。既可以用來進行分類,又可以用來做迴歸,並且支援增量學習,可以用新樣本來更新模型。但是opencv的KNN演算法不像決策樹一樣支援變數子集選擇和屬性缺失的情況。
-使用方法可以採用定義CvKNearest類的物件方法如下:
CvKNearest knn(trainData, trainClasses, 0, false, K);
,該方法會在建構函式中執行訓練過程,訓練的過程就是給CvMat *smaple賦予資料的過程,將樣本有序的放在記憶體的前面,label放在記憶體的後面。構造程式碼如下:
CvKNearest::CvKNearest( const CvMat* _train_data, const CvMat* _responses,
const CvMat* _sample_idx, bool _is_regression, int _max_k )
{
samples = 0;
train( _train_data, _responses, _sample_idx, _is_regression, _max_k, false );
}
- 來了一個新樣本只需要呼叫response = knn.find_nearest(&sample, K, 0, 0, nearests,
0);函式就可以得到類別了。
opencv中的實現過程
假如已經將一批帶標籤的資料通過建構函式的方式存入到CvKNearest的成員變數samples所指記憶體區,現在來了一個新的測試樣本,首先計算測試樣本與所有訓練樣本的歐氏距離,呼叫find_neighbors_direct計算所有測試樣本到所有訓練集的距離。
void CvKNearest::find_neighbors_direct( const CvMat* _samples, int k, int start, int end,
float* neighbor_responses, const float** neighbors, float* dist ) const
{
int i, j, count = end - start, k1 = 0, k2 = 0, d = var_count;
CvVectors* s = samples;
for( ; s != 0; s = s->next )
{
int n = s->count;//訓練集中樣本個數
for( j = 0; j < n; j++ )
{
for( i = 0; i < count; i++ )//count表示測試樣本個數,無標籤
{
double sum = 0;
Cv32suf si;
const float* v = s->data.fl[j];//指向已帶標籤樣本j
const float* u = (float*)(_samples->data.ptr + _samples->step*(start + i));//指向測試樣本
Cv32suf* dd = (Cv32suf*)(dist + i*k);//儲存樣本i的到k個近鄰的距離(Cv32suf的f域)和近鄰的類別(Cv32suf的i域)
float* nr;
const float** nn;
int t, ii, ii1;
for( t = 0; t <= d - 4; t += 4 )
{//計算樣本i與樣本j的歐式距離,最小維數為4
double t0 = u[t] - v[t], t1 = u[t+1] - v[t+1];
double t2 = u[t+2] - v[t+2], t3 = u[t+3] - v[t+3];
sum += t0*t0 + t1*t1 + t2*t2 + t3*t3;
}
for( ; t < d; t++ )
{//計算樣本i與樣本j的歐式距離,維數小於4
double t0 = u[t] - v[t];
sum += t0*t0;
}
si.f = (float)sum;
for( ii = k1-1; ii >= 0; ii-- )//將類別i從小到大排序,插入到ii+1位置
if( si.i > dd[ii].i )
break;
if( ii >= k-1 )
continue;
nr = neighbor_responses + i*k;
nn = neighbors ? neighbors + (start + i)*k : 0;
for( ii1 = k2 - 1; ii1 > ii; ii1-- )//插入前資料後移
{
dd[ii1+1].i = dd[ii1].i;
nr[ii1+1] = nr[ii1];
if( nn ) nn[ii1+1] = nn[ii1];
}
dd[ii+1].i = si.i;//給ii+1位置的樣本賦予插入的樣本的類別編號,由於是union結構,距離也儲存了
nr[ii+1] = ((float*)(s + 1))[j];
if( nn )
nn[ii+1] = v;
}
k1 = MIN( k1+1, k );
k2 = MIN( k1, k-1 );
}
}
}
得多所有距離後,採用氣泡排序方式得到topK近鄰,並計算樣本數最多的類別。
float CvKNearest::write_results( int k, int k1, int start, int end,
const float* neighbor_responses, const float* dist,
CvMat* _results, CvMat* _neighbor_responses,
CvMat* _dist, Cv32suf* sort_buf ) const
{
float result = 0.f;
int i, j, j1, count = end - start;
double inv_scale = 1./k1;
int rstep = _results && !CV_IS_MAT_CONT(_results->type) ? _results->step/sizeof(result) : 1;
for( i = 0; i < count; i++ )//count=1
{
const Cv32suf* nr = (const Cv32suf*)(neighbor_responses + i*k);
float* dst;
float r;
if( _results || start+i == 0 )
{
if( regression )
{//不執行
double s = 0;
for( j = 0; j < k1; j++ )
s += nr[j].f;
r = (float)(s*inv_scale);
}
else
{
int prev_start = 0, best_count = 0, cur_count;
Cv32suf best_val;
for( j = 0; j < k1; j++ )//複製前K1個數據
sort_buf[j].i = nr[j].i;
for( j = k1-1; j > 0; j-- )//表示排序次數k1-1次,使類別標籤有序
{
bool swap_fl = false;
for( j1 = 0; j1 < j; j1++ )//c從前往後,比較淺j個數
if( sort_buf[j1].i > sort_buf[j1+1].i )//從小到大排序
{
int t;
CV_SWAP( sort_buf[j1].i, sort_buf[j1+1].i, t );
swap_fl = true;
}
if( !swap_fl )//如果已經有序,則跳出迴圈
break;
}
best_val.i = 0;//記錄樣本數最多的類別
for( j = 1; j <= k1; j++ )//氣泡排序k1-1次
if( j == k1 || sort_buf[j].i != sort_buf[j-1].i )
{//遇到新的類別
cur_count = j - prev_start;//類別計算
if( best_count < cur_count )
{
best_count = cur_count;
best_val.i = sort_buf[j-1].i;
}
prev_start = j;
}
r = best_val.f;
}
if( start+i == 0 )
result = r;
if( _results )
_results->data.fl[(start + i)*rstep] = r;
}
if( _neighbor_responses )
{
dst = (float*)(_neighbor_responses->data.ptr +
(start + i)*_neighbor_responses->step);
for( j = 0; j < k1; j++ )
dst[j] = nr[j].f;//從nr到dst
for( ; j < k; j++ )
dst[j] = 0.f;
}
if( _dist )
{
dst = (float*)(_dist->data.ptr + (start + i)*_dist->step);
for( j = 0; j < k1; j++ )
dst[j] = dist[j + i*k];
for( ; j < k; j++ )
dst[j] = 0.f;
}
}
return result;
}