1. 程式人生 > >學習OpenCV2——MeanShift之圖形分割

學習OpenCV2——MeanShift之圖形分割

1. 原理

    用meanshift做影象平滑和分割,其實是一回事。其本質是經過迭代,將收斂點的畫素值代替原來的畫素值,從而去除了區域性相似的紋理,同時保留了邊緣等差異較大的特徵。


        OpenCV中自帶有基於meanshift的分割方法pyrMeanShiftFiltering()。由函式名pyrMeanShiftFiltering可知,這裡是將meanshift演算法和影象金字塔相結合用來分割的。

<span style="font-size:18px;">void PyrMeanShiftFiltering( const CvArr* srcarr,          //輸入影象
				    CvArr* dstarr,        //輸出影象
				   double  sp,            //顏色域半徑
				    double sr,            //空間域半徑
				       int max_level,     //金字塔最大層數                    
			    CvTermCriteria termcrit )     //迭代終止條件</span>

    要求輸入和輸出影象都是CV_8UC3型別,而且兩者尺寸一樣。實際上並不需要去先定義dstarr,因為程式裡會將srcarr的格式賦值給dstarr。

    termcrit有三種情況,迭代次數、迭代精度和兩者同時滿足。預設為迭代次數為5同時迭代精度為1。termcrit是個結構體,其結構如下

<span style="font-size:18px;">typedef struct CvTermCriteria
{
    int    type;        /*CV_TERMCRIT_ITER或CV_TERMCRIT_EPS 或二者都是*/
    int    max_iter;   /* 最大迭代次數 */
    double epsilon;    /* 結果的精確性 */
}
CvTermCriteria;</span>
     使用pyrMeanShiftFiltering()進行影象分割非常簡單,只需要定義sp0,sr,max_level和termrit,然後呼叫pyrMeanShiftFiltering()就行了。

    在實際操作時,為了使分割的結果顯示得更明顯,經常用floodFill( )將不同連通域塗上不同的顏色。具體情況參看下 面的例項。

2. 程式例項

    來看看OpenCV自帶的一個用meanshift進行分割的例子

    原程式見   “  .\OpenCV249\sources\samples\cpp\meanshift_segmentation.cpp”

<span style="font-size:18px;">#include "opencv2/highgui/highgui.hpp"
#include "opencv2/core/core.hpp"
#include "opencv2/imgproc/imgproc.hpp"

#include <iostream>

using namespace cv;
using namespace std;

static void help(char** argv)
{
    cout << "\nDemonstrate mean-shift based color segmentation in spatial pyramid.\n"
    << "Call:\n   " << argv[0] << " image\n"
    << "This program allows you to set the spatial and color radius\n"
    << "of the mean shift window as well as the number of pyramid reduction levels explored\n"
    << endl;
}

//This colors the segmentations
static void floodFillPostprocess( Mat& img, const Scalar& colorDiff=Scalar::all(1) )
{
    CV_Assert( !img.empty() );
    RNG rng = theRNG();
    Mat mask( img.rows+2, img.cols+2, CV_8UC1, Scalar::all(0) );
    for( int y = 0; y < img.rows; y++ )
    {
        for( int x = 0; x < img.cols; x++ )
        {
            if( mask.at<uchar>(y+1, x+1) == 0 )
            {
                Scalar newVal( rng(256), rng(256), rng(256) );
                floodFill( img, mask, Point(x,y), newVal, 0, colorDiff, colorDiff );
            }
        }
    }
}

string winName = "meanshift";
int spatialRad, colorRad, maxPyrLevel;
Mat img, res;

static void meanShiftSegmentation( int, void* )
{
    cout << "spatialRad=" << spatialRad << "; "
         << "colorRad=" << colorRad << "; "
         << "maxPyrLevel=" << maxPyrLevel << endl;
    pyrMeanShiftFiltering( img, res, spatialRad, colorRad, maxPyrLevel );
	//Mat imgGray;
	//cvtColor(res,imgGray,CV_RGB2GRAY);
	//imshow("res",res);
    floodFillPostprocess( res, Scalar::all(2) );
    imshow( winName, res );
}

int main(int argc, char** argv)
{    	
	img = imread("rubberwhale1.png");
	//img = imread("pic2.png");	
	
    
    if( img.empty() )
        return -1;

    spatialRad = 10;  
    colorRad = 10;
    maxPyrLevel = 1;

    namedWindow( winName, WINDOW_AUTOSIZE );
    //imshow("img",img);	


    createTrackbar( "spatialRad", winName, &spatialRad, 80, meanShiftSegmentation );
    createTrackbar( "colorRad", winName, &colorRad, 60, meanShiftSegmentation );
    createTrackbar( "maxPyrLevel", winName, &maxPyrLevel, 5, meanShiftSegmentation );

    meanShiftSegmentation(0, 0);
    //floodFillPostprocess( img, Scalar::all(2) );
    //imshow("img2",img);
    waitKey();
    return 0;
}</span>
程式很簡單,來看看floodFill()函式,有兩種形式
    int floodFill( InputOutputArray image, Point seedPoint, Scalar newVal, CV_OUT Rect* rect=0, Scalar loDiff=Scalar(), Scalar upDiff=Scalar(), int flags=4 );
    int floodFill( InputOutputArray image,  InputOutputArray mask, Point seedPoint,  Scalar newVal, 
CV_OUT Rect* rect=0,  Scalar loDiff=Scalar(),  Scalar upDiff=Scalar(),  int flags=4 );

     InputOutputArray image    輸入輸出影象,要求格式為1通道或3通道,8位或浮點

     InputOutputArray mask   掩膜,比image的寬和高各大兩畫素點

     Point seedPoint    填充的起始點

Scalar newVal   畫素點被染色的值

CV_OUT Rect* rect=0  可選引數,設定floodFill()要重繪區域的最小邊界矩形區域

Scalar loDiff=Scalar()  定義當前畫素值與起始點畫素值的亮度或顏色負差的最大值

Scalar upDiff=Scalar()  定義當前畫素值與起始點畫素值的亮度或顏色正差的最大值

flags 操作標誌符    

程式結果


    處理後一些細小的紋理都平滑掉了,例如圖中綠色線條所指示的區域。未填充時,很多地方看得並不明顯,填充後就能明顯看出差別來了。填充後的圖很好地體現了meanshift聚類的思想!

    再來看一組更“誇張”的效果圖


    使用meanshift方法進行處理後,原來的三個矩形區域消失了!平滑掉了!

    meanshift演算法的兩個關鍵引數是空間域半徑sr和顏色域半徑sp,別說max_level,那是構建影象金字塔的引數好吧。最後,我們來看看sr和sp對結果的影響。


       顯然顏色域半徑sp對結果的影響比空間域半徑sr對結果的影響大。sp和sr越小,細節保留得越多,sp和sr越大,平滑力度越大。邊緣和顏色突變的區域的特徵保留的較好。因為meanshift要對每個畫素點進行操作,所以演算法的時間花銷很大。

3. 深入程式碼

<span style="font-size:14px;">/****************************************************************************************\
*                                         Meanshift                                      *
\****************************************************************************************/

CV_IMPL void
cvPyrMeanShiftFiltering( const CvArr* srcarr, CvArr* dstarr,
                         double sp0, double sr, int max_level,
                         CvTermCriteria termcrit )
{
    const int cn = 3;
    const int MAX_LEVELS = 8;

    if( (unsigned)max_level > (unsigned)MAX_LEVELS )
        CV_Error( CV_StsOutOfRange, "The number of pyramid levels is too large or negative" );   //限定max_level不超過8

    std::vector<cv::Mat> src_pyramid(max_level+1);    //+1是因為原始圖和最終圖都定義為影象金字塔的第0層
    std::vector<cv::Mat> dst_pyramid(max_level+1);
    cv::Mat mask0;
    int i, j, level;
    //uchar* submask = 0;

    #define cdiff(ofs0) (tab[c0-dptr[ofs0]+255] + \
        tab[c1-dptr[(ofs0)+1]+255] + tab[c2-dptr[(ofs0)+2]+255] >= isr22)

    double sr2 = sr * sr;
    int isr2 = cvRound(sr2), isr22 = MAX(isr2,16);
    int tab[768];
    cv::Mat src0 = cv::cvarrToMat(srcarr);     //arr轉Mat
    cv::Mat dst0 = cv::cvarrToMat(dstarr);

    //確保src和dst都是CV_8UC3,且同尺寸
    if( src0.type() != CV_8UC3 )
        CV_Error( CV_StsUnsupportedFormat, "Only 8-bit, 3-channel images are supported" );
    if( src0.type() != dst0.type() )
        CV_Error( CV_StsUnmatchedFormats, "The input and output images must have the same type" );
    if( src0.size() != dst0.size() )
        CV_Error( CV_StsUnmatchedSizes, "The input and output images must have the same size" );

	//確保迭代次數在1到100次,預設則為5;迭代精度預設為1.
    if( !(termcrit.type & CV_TERMCRIT_ITER) )
        termcrit.max_iter = 5;
    termcrit.max_iter = MAX(termcrit.max_iter,1);
    termcrit.max_iter = MIN(termcrit.max_iter,100);
    if( !(termcrit.type & CV_TERMCRIT_EPS) )
        termcrit.epsilon = 1.f;
    termcrit.epsilon = MAX(termcrit.epsilon, 0.f);

    for( i = 0; i < 768; i++ )
        tab[i] = (i - 255)*(i - 255);  //tab[]存的是(-255)^2到512^2

    // 1. 構造金字塔
    src_pyramid[0] = src0;
    dst_pyramid[0] = dst0;
    for( level = 1; level <= max_level; level++ )
    {
		//src_pyramid和dst_pyramid尺寸一樣,下一層是上一層尺寸的一半
        src_pyramid[level].create( (src_pyramid[level-1].rows+1)/2,
                        (src_pyramid[level-1].cols+1)/2, src_pyramid[level-1].type() );
        dst_pyramid[level].create( src_pyramid[level].rows,
                        src_pyramid[level].cols, src_pyramid[level].type() );
		//對src_pyramid[level-1]下采樣,結果存入src_pyramid[level]
        cv::pyrDown( src_pyramid[level-1], src_pyramid[level], src_pyramid[level].size() );
        //CV_CALL( cvResize( src_pyramid[level-1], src_pyramid[level], CV_INTER_AREA ));
    }

    mask0.create(src0.rows, src0.cols, CV_8UC1);
    //CV_CALL( submask = (uchar*)cvAlloc( (sp+2)*(sp+2) ));

    // 2. 從頂層(最小層)開始應用meanshift演算法。
    for( level = max_level; level >= 0; level-- )
    {
        cv::Mat src = src_pyramid[level];
        cv::Size size = src.size();
        uchar* sptr = src.data;        //sptr指向影象矩陣的起始地址,也就是第一行的起始地址
        int sstep = (int)src.step;     //sstep是影象矩陣每一行的長度(以位元組為單位),以便後面計算地址
        uchar* mask = 0;
        int mstep = 0;
        uchar* dptr;
        int dstep;
        float sp = (float)(sp0 / (1 << level));   
        sp = MAX( sp, 1 );           //這裡保證了sp≥1,那麼視窗最小是3×3

		//這段語句主要作用1、通過上取樣得到dst_pyramid[level];2、得到掩碼mask
        if( level < max_level )
        {
            cv::Size size1 = dst_pyramid[level+1].size();
            cv::Mat m( size.height, size.width, CV_8UC1, mask0.data );
            dstep = (int)dst_pyramid[level+1].step;
            dptr = dst_pyramid[level+1].data + dstep + cn;
            mstep = (int)m.step;
            mask = m.data + mstep;
            //cvResize( dst_pyramid[level+1], dst_pyramid[level], CV_INTER_CUBIC );
            cv::pyrUp( dst_pyramid[level+1], dst_pyramid[level], dst_pyramid[level].size() ); //上取樣
            m.setTo(cv::Scalar::all(0));

            for( i = 1; i < size1.height-1; i++, dptr += dstep - (size1.width-2)*3, mask += mstep*2 )
            {
                for( j = 1; j < size1.width-1; j++, dptr += cn )
                {
                    int c0 = dptr[0], c1 = dptr[1], c2 = dptr[2];
                    mask[j*2 - 1] = cdiff(-3) || cdiff(3) || cdiff(-dstep-3) || cdiff(-dstep) ||
                        cdiff(-dstep+3) || cdiff(dstep-3) || cdiff(dstep) || cdiff(dstep+3);
                }
            }

            cv::dilate( m, m, cv::Mat() );  //對m膨脹
            mask = m.data;
        }

        dptr = dst_pyramid[level].data;        //dptr指向影象矩陣起始地址
        dstep = (int)dst_pyramid[level].step;  //dstep表示影象矩陣每一行的佔記憶體的位元組數

        for( i = 0; i < size.height; i++, sptr += sstep - size.width*3,  
                                          dptr += dstep - size.width*3,  //每處理完一行,sptr和dptr都指向下一行的起始地址
                                          mask += mstep )
        {
            for( j = 0; j < size.width; j++, sptr += 3, dptr += 3 )   //每處理完一列,sptr和dptr都指向同行下一列畫素的起始地址,所以sptr和dptr實際就是每個畫素點的地址
            {
                int x0 = j, y0 = i, x1, y1, iter;
                int c0, c1, c2;

                if( mask && !mask[j] )
                    continue;

                c0 = sptr[0], c1 = sptr[1], c2 = sptr[2];    //分別對應畫素點三通道的地址

                // iterate meanshift procedure
                for( iter = 0; iter < termcrit.max_iter; iter++ )
                {
                    uchar* ptr;
                    int x, y, count = 0;
                    int minx, miny, maxx, maxy;
                    int s0 = 0, s1 = 0, s2 = 0, sx = 0, sy = 0;      //(x,y)的迭代的座標值,(s0,s1,s2)是迭代的3通道分量值
                    double icount;
                    int stop_flag;

                    //mean shift: process pixels in window (p-sigmaSp)x(p+sigmaSp)
                    minx = cvRound(x0 - sp); minx = MAX(minx, 0);              //若j-sp>=0,則minx=(j-sp),否則,minx=0;
                    miny = cvRound(y0 - sp); miny = MAX(miny, 0);              //若i-sp>=0,則miny=(i-sp),否則,miny=0;
                    maxx = cvRound(x0 + sp); maxx = MIN(maxx, size.width-1);   //若j+sp<=width+1,則maxx=j+sp,否則,maxx=width-1;
                    maxy = cvRound(y0 + sp); maxy = MIN(maxy, size.height-1);  //若i+sp<=height+1,則maxy=i+sp,否則,maxy=height-1;
                    ptr = sptr + (miny - i)*sstep + (minx - j)*3;  //sptr指向(i,j),ptr則指向當前視窗第一個畫素點

                    for( y = miny; y <= maxy; y++, ptr += sstep - (maxx-minx+1)*3 )  //視窗內,每處理完一行,ptr指向下一行首地址
                    {
                        int row_count = 0;
                        x = minx;
                        #if CV_ENABLE_UNROLLED
                        for( ; x + 3 <= maxx; x += 4, ptr += 12 )  //這兩次for迴圈是什麼意思?顏色限定和空間限定?
                        {
                            int t0 = ptr[0], t1 = ptr[1], t2 = ptr[2];  
                            if( tab[t0-c0+255] + tab[t1-c1+255] + tab[t2-c2+255] <= isr2 )
                            {
                                s0 += t0; s1 += t1; s2 += t2;
                                sx += x; row_count++;
                            }
                            t0 = ptr[3], t1 = ptr[4], t2 = ptr[5];
                            if( tab[t0-c0+255] + tab[t1-c1+255] + tab[t2-c2+255] <= isr2 )
                            {
                                s0 += t0; s1 += t1; s2 += t2;
                                sx += x+1; row_count++;
                            }
                            t0 = ptr[6], t1 = ptr[7], t2 = ptr[8];
                            if( tab[t0-c0+255] + tab[t1-c1+255] + tab[t2-c2+255] <= isr2 )
                            {
                                s0 += t0; s1 += t1; s2 += t2;
                                sx += x+2; row_count++;
                            }
                            t0 = ptr[9], t1 = ptr[10], t2 = ptr[11];
                            if( tab[t0-c0+255] + tab[t1-c1+255] + tab[t2-c2+255] <= isr2 )
                            {
                                s0 += t0; s1 += t1; s2 += t2;
                                sx += x+3; row_count++;
                            }
                        }
                        #endif
                        for( ; x <= maxx; x++, ptr += 3 )
                        {
                            int t0 = ptr[0], t1 = ptr[1], t2 = ptr[2];
                            if( tab[t0-c0+255] + tab[t1-c1+255] + tab[t2-c2+255] <= isr2 )
                            {
                                s0 += t0; s1 += t1; s2 += t2;
                                sx += x; row_count++;
                            }
                        }
                        count += row_count;
                        sy += y*row_count;
                    }

                    if( count == 0 )
                        break;

                    icount = 1./count;
                    x1 = cvRound(sx*icount);
                    y1 = cvRound(sy*icount);
                    s0 = cvRound(s0*icount);
                    s1 = cvRound(s1*icount);
                    s2 = cvRound(s2*icount);

                    stop_flag = (x0 == x1 && y0 == y1) || abs(x1-x0) + abs(y1-y0) +
                        tab[s0 - c0 + 255] + tab[s1 - c1 + 255] +
                        tab[s2 - c2 + 255] <= termcrit.epsilon;

                    x0 = x1; y0 = y1;
                    c0 = s0; c1 = s1; c2 = s2;

                    if( stop_flag )
                        break;
                }

                dptr[0] = (uchar)c0;
                dptr[1] = (uchar)c1;
                dptr[2] = (uchar)c2;
            }
        }
    }
}

void cv::pyrMeanShiftFiltering( InputArray _src, OutputArray _dst,
                                double sp, double sr, int maxLevel,
                                TermCriteria termcrit )
{
    Mat src = _src.getMat();

    if( src.empty() )
        return;

    _dst.create( src.size(), src.type() );
    CvMat c_src = src, c_dst = _dst.getMat();
    cvPyrMeanShiftFiltering( &c_src, &c_dst, sp, sr, maxLevel, termcrit );
}</span><span style="font-size:18px;">
</span>