1. 程式人生 > 實用技巧 >opencv筆記--Kmeans

opencv筆記--Kmeans

在影象分割中,使用 kmeans 演算法可以實現影象區域基本分割。如果一幅影象被分為兩類,kmeans 分割效果與 ostu 演算法基本一致,具體如下圖:

kmeans 將影象灰度聚類為 k 類,ostu 將影象灰度分割為 2 類,當 k = 2 時,兩種演算法最終目的基本趨於一致。

kmeans 演算法基本思路如下:

1)隨機選取第一個聚類中心點,之後的聚類中心點選取有兩種方法;

a. 隨機選取其他 k - 1 個聚類中心點;

b. 根據已經選取的聚類中心點,計算所有點到已經選取的聚類中心點的距離,選擇到所有已經選取的聚類中心點的最遠點作為下一個聚類中心點;

2)根據點到已經選取的聚類中心點的距離對其進行分類;

3)重新求各個分類的聚類中心點,然後回到 2);

4)當不再滿足迭代條件時給出最終聚類結果,迭代條件包括:

a. 聚類中心點在迭代過程中的偏移量;

b. 迭代次數;

對於聚類中心點的選擇,一般情況下,方法 b 會得到更好的聚類,且迭代速度較快。

opencv 提供的 kmean 函式為:

double kmeans( InputArray data, int K, InputOutputArray bestLabels,TermCriteria criteria, int attempts,

int flags, OutputArray centers=noArray() );

引數如下:

data: 待分類點矩陣,其型別必須為CV_32F;

K,bestLabels: 聚類數與待分類點所屬分類;

criteria:停止條件;

attempts:使用不同的隨機聚類中心點嘗試聚類次數;

flags:聚類中心點選擇方案,包括完全隨機選擇,kmeans++選擇方案(b),使用者輸入;

centers:最終聚類中心點;

以下給出 kmeans 演算法使用程式碼:

 1 void UseKmeans(cv::Mat& src, cv::Mat& rst)
 2 {
 3     int width = src.cols;
4 int height = src.rows; 5 int dims = src.channels(); 6 int sampleCount = width * height; 7 8 int clusterCount = 2; 9 Mat points(sampleCount, dims, CV_32F, Scalar(10)); 10 cv::Mat pos(sampleCount, 2, CV_16S, Scalar(0, 0)); 11 Mat labels; 12 Mat centers(clusterCount, 1, points.type()); 13 14 // invert to data points 15 int index = 0; 16 for (int row = 0; row < height; row++) { 17 for (int col = 0; col < width; col++) { 18 points.at<float>(index, 0) = static_cast<int>(src.ptr<uchar>(row)[col]); 19 pos.at<short>(index, 0) = static_cast<short>(row); 20 pos.at<short>(index, 1) = static_cast<int>(col); 21 ++index; 22 } 23 } 24 25 // k-mean algorithm 26 TermCriteria criteria = TermCriteria(CV_TERMCRIT_EPS + CV_TERMCRIT_ITER, 100, 1.0); 27 kmeans(points, clusterCount, labels, criteria, 3, KMEANS_PP_CENTERS, centers); 28 29 int bright_val = -1; 30 for (int i = 0; i < centers.rows; ++i) 31 { 32 int val = centers.at<float>(i, 0); 33 if (val > bright_val) 34 bright_val = val; 35 } 36 37 int bright_label = -1; 38 for (int idx = 0; idx < sampleCount; ++idx) 39 { 40 float *datapoint = points.ptr<float>(idx); 41 int *datalabel = labels.ptr<int>(idx); 42 if (datapoint[0] >= bright_val) 43 { 44 bright_label = datalabel[0]; 45 break; 46 } 47 } 48 49 // save result 50 rst.create(src.size(), CV_8UC1); 51 rst.rowRange(0, rst.rows) = 0; 52 for (int idx = 0; idx < sampleCount; ++idx) 53 { 54 int *datalabel = labels.ptr<int>(idx); 55 if (datalabel[0] == bright_label) 56 { 57 int row = pos.at<short>(idx, 0); 58 int col = pos.at<short>(idx, 1); 59 rst.ptr<uchar>(row)[col] = 255; 60 } 61 } 62 }