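GrabCut in One Cut: a fast, easy-to-use image segmentation algorithm
This paper, presented at ICCV 2013, proposes a fast segmentation algorithm that needs only simple user interaction. This post is my walkthrough of the paper.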
Tang M, Gorelick L, Veksler O, et al. GrabCut in One Cut[C]// IEEE International Conference on Computer Vision. IEEE Computer Society, 2013: 1769-1776.
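GrabCut in One Cut builds on the classic graph cut segmentation method, a very popular energy-optimization technique. Methods of this family cast image segmentation as a minimum cut problem on a graph. The image to be segmented is first represented as an undirected graph G = <V, E>, where V and E are the sets of vertices and edges. Unlike an ordinary graph, the graph used by graph cuts has two extra vertices, called the source and the sink and collectively known as the terminal vertices. A graph cuts graph therefore has two kinds of vertices and, correspondingly, two kinds of edges.
First kind of vertices and edges: ordinary vertices, one for each pixel of the image. Every pair of neighboring vertices (corresponding to a pair of neighboring pixels) is joined by an edge. These edges are called n-links.
Second kind of vertices and edges: the two terminal vertices. Every ordinary vertex is connected to both of them, and these connections form the second kind of edge, called t-links.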
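To make the two kinds of edges concrete, here is a minimal sketch (the names NLink, buildNLinks4 and countTLinks are hypothetical) that enumerates the n-links of a 4-connected pixel grid. It uses the same row-major node ids, i * cols + j, as the onecut code later in this post. The terminals are left implicit: every pixel simply has one t-link to s and one to t.

```cpp
#include <vector>

// One n-link between two neighboring pixel nodes.
struct NLink { int from, to; };

// Enumerate the n-links of a rows x cols image with 4-connectivity.
// Each undirected edge is listed once: pixel (i, j) links to its right
// and lower neighbors only, node id = i * cols + j.
std::vector<NLink> buildNLinks4(int rows, int cols) {
    std::vector<NLink> links;
    for (int i = 0; i < rows; i++)
        for (int j = 0; j < cols; j++) {
            int id = i * cols + j;
            if (j + 1 < cols) links.push_back({id, id + 1});    // right neighbor
            if (i + 1 < rows) links.push_back({id, id + cols}); // lower neighbor
        }
    return links;
}

// Every pixel has exactly one t-link to the source and one to the sink.
int countTLinks(int rows, int cols) { return 2 * rows * cols; }
```

For a 3 x 4 image this gives 3 * 3 horizontal plus 2 * 4 vertical n-links (17 in total) and 24 t-links.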
Figure 1: s-t graph
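The figure above shows the s-t graph of an image. Each pixel corresponds to a vertex of the graph, and in addition there are the source and sink vertices. Blue and red edges are t-links; yellow edges are n-links. In image segmentation, the source vertex s usually stands for the foreground object and the sink vertex t for the background.
Every edge has a weight. A cut is a subset of the graph's edges, and the sum of the weights of all edges in the cut is called the cost of the cut.
The cut in graph cuts is a set of edges that contains edges of both kinds described above; removing all of them separates the remaining graph into an "S" part and a "T" part, hence the name "cut". A cut whose total edge weight is minimal is called a minimum cut, and it is the result of the graph cut. The Ford-Fulkerson (max-flow/min-cut) theorem states that the value of the maximum flow in a network equals the cost of its minimum cut, so the max-flow/min-cut algorithm developed by Boykov and Kolmogorov can be used to obtain the minimum cut of the s-t graph. This minimum cut partitions the vertices into two disjoint subsets S and T with s ∈ S, t ∈ T and S ∪ T = V. The two subsets correspond to the foreground and background pixel sets, which completes the segmentation.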
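The max-flow/min-cut relationship can be checked on a toy graph. The sketch below is a minimal illustration, assuming a tiny dense graph: it uses the simpler Edmonds-Karp algorithm (BFS augmenting paths), not the Boykov-Kolmogorov algorithm that graph.h implements, and the MaxFlow struct is hypothetical. After the flow is maximal, the vertices still reachable from s in the residual graph form the set S of the minimum cut.

```cpp
#include <algorithm>
#include <climits>
#include <queue>
#include <vector>

// Edmonds-Karp max-flow on a small dense graph (illustration only).
struct MaxFlow {
    int n;
    std::vector<std::vector<int>> cap; // residual capacities
    explicit MaxFlow(int n) : n(n), cap(n, std::vector<int>(n, 0)) {}

    // Undirected edge with capacity c in both directions.
    void addEdge(int u, int v, int c) { cap[u][v] += c; cap[v][u] += c; }

    int run(int s, int t) {
        int flow = 0;
        while (true) {
            // BFS for a shortest augmenting path in the residual graph
            std::vector<int> parent(n, -1);
            parent[s] = s;
            std::queue<int> q;
            q.push(s);
            while (!q.empty() && parent[t] == -1) {
                int u = q.front(); q.pop();
                for (int v = 0; v < n; v++)
                    if (parent[v] == -1 && cap[u][v] > 0) { parent[v] = u; q.push(v); }
            }
            if (parent[t] == -1) return flow; // no augmenting path left
            int push = INT_MAX;
            for (int v = t; v != s; v = parent[v]) push = std::min(push, cap[parent[v]][v]);
            for (int v = t; v != s; v = parent[v]) {
                cap[parent[v]][v] -= push;
                cap[v][parent[v]] += push;
            }
            flow += push;
        }
    }

    // After run(): vertices reachable from s in the residual graph = set S.
    std::vector<bool> sourceSide(int s) const {
        std::vector<bool> seen(n, false);
        std::queue<int> q;
        q.push(s);
        seen[s] = true;
        while (!q.empty()) {
            int u = q.front(); q.pop();
            for (int v = 0; v < n; v++)
                if (!seen[v] && cap[u][v] > 0) { seen[v] = true; q.push(v); }
        }
        return seen;
    }
};
```

With s = 0, t = 3 and two pixels p = 1, q = 2, linked by s-p 5, s-q 1, p-t 1, q-t 4 and p-q 2, the cheapest cut is {s, p} | {q, t} with cost 1 + 1 + 2 = 4, and run(0, 3) returns that same value as the maximum flow.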
Figure 2: s-t minimum cut
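Edge weights are chosen according to one principle: edges along the boundary between foreground and background should carry the smallest weights, so that the minimum cut runs along that boundary. The minimum cut is obtained by minimizing an energy function:

$$E(L) = aR(L) + B(L)$$

In this formula, L is the labeling produced by the cut, R(L) is the region term, B(L) is the boundary term, and a is a weighting factor that sets the relative importance of the region term and the boundary term. The region term is usually written as

$$R(L) = \sum_{p \in P} R_p(l_p)$$

where $R_p(l_p)$ is the penalty for assigning label $l_p$ to pixel p: $R_p(1)$ is the penalty for labeling p as foreground and $R_p(0)$ the penalty for labeling it as background. These values are usually obtained by comparing the gray level of p with the gray-level histogram of the given object (and, for $R_p(0)$, of the background).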
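The text only says that the penalty comes from comparing a pixel with a histogram. A common concrete choice, assumed here and taken from Boykov and Jolly's interactive graph cuts rather than from this paper, is the negative log-likelihood $R_p(l) = -\ln \Pr(I_p \mid l)$ under the label's histogram. The helper names and the eps value are hypothetical.

```cpp
#include <cmath>
#include <cstddef>
#include <vector>

// R_p(l) = -ln Pr(I_p | histogram of label l).
// hist holds counts per intensity bin; eps avoids ln(0) for empty bins.
double regionPenalty(const std::vector<int>& hist, int bin, double eps = 1e-6) {
    long total = 0;
    for (int c : hist) total += c;
    double prob = total > 0 ? static_cast<double>(hist[bin]) / total : 0.0;
    return -std::log(prob + eps);
}

// R(L) = sum over pixels of R_p(l_p); label 1 = foreground, 0 = background.
double regionTerm(const std::vector<int>& pixelBins, const std::vector<int>& labels,
                  const std::vector<int>& fgHist, const std::vector<int>& bgHist) {
    double r = 0.0;
    for (std::size_t p = 0; p < pixelBins.size(); p++)
        r += regionPenalty(labels[p] == 1 ? fgHist : bgHist, pixelBins[p]);
    return r;
}
```

With a foreground histogram {8, 2} and a background histogram {1, 9}, a pixel in bin 0 costs about -ln 0.8 ≈ 0.223 as foreground but -ln 0.1 ≈ 2.303 as background, so the region term pushes it toward the foreground.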
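Boundary term:

$$B(L) = \sum_{\{p,q\} \in N} B_{\langle p,q \rangle} \cdot \delta(l_p, l_q), \qquad \delta(l_p, l_q) = \begin{cases} 1 & l_p \neq l_q \\ 0 & l_p = l_q \end{cases}$$

$$B_{\langle p,q \rangle} \propto \exp\left(-\frac{(I_p - I_q)^2}{2\sigma^2}\right) \cdot \frac{1}{\mathrm{dist}(p,q)}$$

Here N is the set of neighboring pixel pairs. Pixel values on the two sides of an object boundary usually differ a lot, so $B_{\langle p,q \rangle}$ is designed to be small when two neighboring pixels differ strongly: cutting between them is then cheap, and the minimum cut tends to follow the boundary.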
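A minimal sketch of $B_{\langle p,q \rangle}$ for color pixels, assuming the squared color difference is the sum over the three channels (as in init() below). The Color3f type and boundaryWeight name are hypothetical. The onecut code additionally blends the Gaussian with a constant floor, (0.95 * w + 0.05) / dist, before scaling it to an integer capacity; that step is omitted here.

```cpp
#include <cmath>

// Hypothetical color type, channels in OpenCV's B, G, R order.
struct Color3f { float b, g, r; };

// B_{p,q} = exp(-||I_p - I_q||^2 / (2 sigma^2)) / dist(p, q).
// Close to 1/dist for similar neighbors, close to 0 across strong edges.
float boundaryWeight(Color3f p, Color3f q, float sigmaSquared, float dist) {
    float db = p.b - q.b, dg = p.g - q.g, dr = p.r - q.r;
    float diff2 = db * db + dg * dg + dr * dr;
    return std::exp(-diff2 / (2.0f * sigmaSquared)) / dist;
}
```

Two identical 4-neighbors (dist = 1) get weight 1, a diagonal pair gets 1/√2, and a black/white pair gets a weight that underflows to essentially 0, which is exactly where the cut should go.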
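In the GrabCut in One Cut paper, the authors replace the region term with the following appearance-overlap term, which avoids the NP-hard problem (together with the boundary term it can be minimized exactly by a single graph cut):

$$E_{L_1}(S) = -\beta \left\| \theta^S - \theta^{\bar S} \right\|_{L_1}$$

where $\theta^S$ and $\theta^{\bar S}$ are the unnormalized color histograms (bin counts) of the foreground S and the background $\bar S$. Using $|a - b| = a + b - 2\min(a, b)$ for every bin, the expression above can be rewritten as

$$-\beta \left\| \theta^S - \theta^{\bar S} \right\|_{L_1} = 2\beta \sum_k \min\left(|S_k|, |\bar S_k|\right) - \beta \sum_k |\Omega_k|$$

Here $|\Omega_k|$ is the number of pixels in bin k, $|S_k|$ is the number of those pixels that belong to the foreground, and $|\bar S_k|$ is the number that belong to the background. The last sum is the total number of pixels and therefore a constant, so minimizing the term amounts to minimizing $\sum_k \min(|S_k|, |\bar S_k|)$, i.e. keeping each color bin mostly on one side of the cut.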
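The rewrite rests only on the identity |a - b| = a + b - 2 min(a, b) applied per bin. A minimal numeric check, assuming β = 1 and hypothetical helper names; fg[k] stands for $|S_k|$ and bg[k] for $|\bar S_k|$:

```cpp
#include <algorithm>
#include <cstddef>
#include <cstdlib>
#include <vector>

// Left side: -||theta^S - theta^Sbar||_L1 = -sum_k |fg[k] - bg[k]|
long negL1(const std::vector<int>& fg, const std::vector<int>& bg) {
    long s = 0;
    for (std::size_t k = 0; k < fg.size(); k++) s += std::abs(fg[k] - bg[k]);
    return -s;
}

// Right side: 2 * sum_k min(fg[k], bg[k]) - sum_k |Omega_k|,
// with |Omega_k| = fg[k] + bg[k] (every pixel of bin k is on one side).
long minForm(const std::vector<int>& fg, const std::vector<int>& bg) {
    long s = 0, omega = 0;
    for (std::size_t k = 0; k < fg.size(); k++) {
        s += std::min(fg[k], bg[k]);
        omega += fg[k] + bg[k];
    }
    return 2 * s - omega;
}
```

For fg = {5, 0, 3} and bg = {1, 4, 3}, both sides equal -8: the L1 distance is 4 + 4 + 0 = 8, while 2 * (1 + 0 + 3) - 16 = -8.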
Figure 3: auxiliary nodes
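The authors also describe how to implement this term: add one auxiliary node $A_k$ for every bin k, and connect it to each pixel that falls into bin k with an edge of weight 1 (times the term's scale factor; the code below uses INT32_CONST * bha_slope). In the minimum cut, $A_k$ ends up on whichever side is cheaper, so the auxiliary edges of bin k contribute exactly $\min(|S_k|, |\bar S_k|)$ (times the edge weight) to the cut cost. Minimizing the cut therefore minimizes the boundary term and the region term at the same time.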
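Why the auxiliary node reproduces the min term can be checked by brute force. This sketch is not taken from the authors' code; it assumes a single bin whose pixels are all linked to $A_k$ with weight w, tries both sides for $A_k$, and counts the cut auxiliary edges (an edge is cut when the pixel and $A_k$ are on different sides). The function names are hypothetical.

```cpp
#include <algorithm>
#include <vector>

// Cheapest cost of the auxiliary edges of one bin, given which of its
// pixels are foreground (source side). A_k is tried on both sides.
int auxCutCost(const std::vector<bool>& isFg, int w) {
    int best = -1;
    for (int side = 0; side < 2; side++) { // side 1: A_k on source side, 0: sink side
        int cost = 0;
        for (bool f : isFg)
            if (f != (side == 1)) cost += w; // pixel and A_k separated -> edge is cut
        if (best < 0 || cost < best) best = cost;
    }
    return best;
}

// Check every labeling of n pixels in one bin: the auxiliary cost always
// equals w * min(|S_k|, |S_k-bar|), i.e. that bin's share of the region term.
bool checkAllLabelings(int n, int w) {
    for (int mask = 0; mask < (1 << n); mask++) {
        std::vector<bool> isFg(n);
        int fg = 0;
        for (int p = 0; p < n; p++) {
            isFg[p] = (mask >> p) & 1;
            fg += isFg[p] ? 1 : 0;
        }
        if (auxCutCost(isFg, w) != w * std::min(fg, n - fg)) return false;
    }
    return true;
}
```

For a bin with three foreground and two background pixels, $A_k$ goes to the source side and only the two background edges are cut.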
References:
http://blog.csdn.net/zouxy09/article/details/8532111
Download link for the authors' code:
Download link for my code:
http://download.csdn.net/download/zhangyumengs/10237656
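Below is part of my rewritten Qt code, with comments added to make it easier to read.
I also added new features: image cropping and rectangle-based segmentation.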
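The rectangle-based segmentation boils down to one idea: everything outside the user's rectangle becomes a hard background constraint, and pixels inside stay free for the graph cut to decide. A minimal sketch of that mask construction in plain C++, assuming a row-major mask where 255 means "definitely background" like bgScribbleMask in onSegImage; rectToBgMask is a hypothetical helper. It normalizes the corners, since the user may drag in any direction, and clamps them to the image.

```cpp
#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <vector>

// Background hard-constraint mask for a rectangle with corners (x1, y1), (x2, y2).
// 255 = background constraint, 0 = unconstrained (inside the rectangle).
std::vector<std::uint8_t> rectToBgMask(int rows, int cols, int x1, int y1, int x2, int y2) {
    int x0 = std::max(0, std::min(x1, x2)), xe = std::min(cols, std::max(x1, x2));
    int y0 = std::max(0, std::min(y1, y2)), ye = std::min(rows, std::max(y1, y2));
    std::vector<std::uint8_t> mask(static_cast<std::size_t>(rows) * cols, 255);
    for (int i = y0; i < ye; i++)
        for (int j = x0; j < xe; j++)
            mask[static_cast<std::size_t>(i) * cols + j] = 0;
    return mask;
}
```

For a 4 x 5 image and the rectangle (1, 1) to (4, 3), the 3 x 2 interior stays unconstrained and the remaining 14 pixels are marked as background; swapping the two corners gives the same mask.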
onecut.h
#ifndef ONECUT_H
#define ONECUT_H
#include <QWidget>
#include "ui_onecut.h"
#include <iostream>
#include <string>
#include <iomanip>
#include <sstream>
#include <opencv2/imgproc/imgproc.hpp>
#include <opencv2/core/core.hpp>
#include <opencv2/highgui/highgui.hpp>
#include "graph.h"
#include <QMessageBox>
#include <QKeyEvent>
// custom widgets
#include "onecutlabel.h"
#include "cutimagelabel.h"
#include "rectseglabel.h"
using namespace std;
using namespace cv;
#define NEIGHBORHOOD_4_TYPE 1
const int NEIGHBORHOOD = NEIGHBORHOOD_4_TYPE;
class onecut : public QWidget
{
Q_OBJECT
public:
onecut(QWidget *parent = 0);
~onecut();
//___________________________________________________________________________________________________
Mat inputImg, showImg, binPerPixelImg, showEdgesImg, segMask, segShowImg;
Mat fgScribbleMask, bgScribbleMask;
// used to undo the last operation
Mat fgScribbleMask_last, bgScribbleMask_last;
int lastSegState = 1;
Mat showImg_last;
int numUsedBins = 0;
float varianceSquared = 0;
int scribbleRadius = 10;
float bha_slope = 0.5f;
int numBinsPerChannel = 16;
float EDGE_STRENGTH_WEIGHT = 0.95f;
const float INT32_CONST = 1000;
const float HARD_CONSTRAINT_CONST = 1000;
int init(Mat src);
void destroyAll();
// compute the color-bin index of every pixel
void generateBinIndex(Mat& bin, Mat& inImg, int binschannel, int& binsNotEmpty);
// compute the variance used by the Gaussian edge weights
void generateEdgeVariance(Mat& inputImg, Mat& showEdgesImg, float& varianceSquared);
typedef Graph<int, int, int> GraphType;
GraphType *myGraph = NULL; // allocated in init()
private slots:
void onSegImage();
void onMouseMoveFinish(Mat bgScribbleMask,
Mat fgScribbleMask,
Mat showImg);
void onFinish();
void onCutImage();
void onRectSeg();
void onConfirmCut();
void onDrawImage();
void onLineWidthChanged(int);
signals:
void okClicked(Mat result);
private:
Ui::onecut ui;
// cropping-related variables
cutImageLabel* imageLable = NULL;
rectSegLabel* rectLabel = NULL;
Mat rectCutImage(const Mat& src, Rect rect);
void showImage(Mat image);
QImage cvMatToQImage(Mat& src);
protected:
void keyPressEvent(QKeyEvent *event);
};
#endif
onecut.cpp
#include "onecut.h"
onecut::onecut(QWidget *parent)
: QWidget(parent)
{
ui.setupUi(this);
connect(ui.label_show, SIGNAL(mouseMoveFinish(Mat, Mat, Mat)),
this, SLOT(onMouseMoveFinish(Mat ,Mat ,Mat)));
connect(ui.button_seg, SIGNAL(clicked()),
this, SLOT(onSegImage()));
connect(ui.button_ok, SIGNAL(clicked()),
this, SLOT(onFinish()));
// cropping and scribbling
connect(ui.button_cut, SIGNAL(clicked()),
this, SLOT(onCutImage()));
connect(ui.button_confirmcut, SIGNAL(clicked()),
this, SLOT(onConfirmCut()));
connect(ui.button_draw, SIGNAL(clicked()),
this, SLOT(onDrawImage()));
connect(ui.button_rectseg, SIGNAL(clicked()),
this, SLOT(onRectSeg()));
// change the brush width
ui.slider_linewidth->setRange(2, 10);
ui.slider_linewidth->setValue(10);
connect(ui.slider_linewidth, SIGNAL(valueChanged(int)),
this, SLOT(onLineWidthChanged(int)));
}
onecut::~onecut()
{
}
void onecut::destroyAll()
{
// clear all data
fgScribbleMask.release();
bgScribbleMask.release();
inputImg.release();
showImg.release();
showEdgesImg.release();
binPerPixelImg.release();
segMask.release();
segShowImg.release();
delete myGraph;
myGraph = NULL;
}
int onecut::init(Mat src)
{
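// validate the input before using it: the code below assumes an 8-bit, 3-channel image
if (src.empty() || src.type() != CV_8UC3)
{
return -1;
}
// initialize the Mats
inputImg = src.clone();
this->showImage(inputImg);
showImg = inputImg.clone();
segShowImg = inputImg.clone();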
// initialize the scribble masks and the other per-pixel buffers
fgScribbleMask.create(2, inputImg.size, CV_8UC1);
fgScribbleMask = 0;
bgScribbleMask.create(2, inputImg.size, CV_8UC1);
bgScribbleMask = 0;
segMask.create(2, inputImg.size, CV_8UC1);
segMask = 0;
showEdgesImg.create(2, inputImg.size, CV_32FC1);
showEdgesImg = 0;
binPerPixelImg.create(2, inputImg.size, CV_32F);
// numBinsPerChannel = 16; numUsedBins receives the number of bins with a nonzero count
generateBinIndex(binPerPixelImg, inputImg, numBinsPerChannel, numUsedBins);
// compute the color variance used by the edge weights
generateEdgeVariance(inputImg, showEdgesImg, varianceSquared);
// one node per pixel plus numUsedBins auxiliary nodes (one per non-empty color bin)
myGraph = new GraphType( inputImg.rows * inputImg.cols + numUsedBins,12 * inputImg.rows * inputImg.cols);
myGraph->add_node((int)inputImg.cols * inputImg.rows + numUsedBins);
for (int i = 0; i<inputImg.rows; i++)
{
for (int j = 0; j<inputImg.cols; j++)
{
// node ID of the current pixel
int currNodeId = i * inputImg.cols + j;
float b = (float)inputImg.at<Vec3b>(i, j)[0];
float g = (float)inputImg.at<Vec3b>(i, j)[1];
float r = (float)inputImg.at<Vec3b>(i, j)[2];
for (int si = -NEIGHBORHOOD; si <= NEIGHBORHOOD; si++)
{
int ni = i + si;
// skip neighbors outside the image
if (ni < 0 || ni >= inputImg.rows)
continue;
for (int sj = 0; sj <= NEIGHBORHOOD; sj++)
{
int nj = j + sj;
if (nj < 0 || nj >= inputImg.cols)
continue;
// skip the pixel itself and downward edges:
// a down pointed edge will be counted as an up edge for the other pixel
if (si >= 0 && sj == 0)
continue;
// only keep neighbors inside a circle of radius NEIGHBORHOOD
if ((si*si + sj*sj) > NEIGHBORHOOD*NEIGHBORHOOD)
continue;
// node ID of the neighbor
int nNodeId = (i + si) * inputImg.cols + (j + sj);
float nb = (float)inputImg.at<Vec3b>(i + si, j + sj)[0];
float ng = (float)inputImg.at<Vec3b>(i + si, j + sj)[1];
float nr = (float)inputImg.at<Vec3b>(i + si, j + sj)[2];
// boundary-term weight: a Gaussian of the color difference
float currEdgeStrength = exp(-((b - nb)*(b - nb) + (g - ng)*(g - ng) + (r - nr)*(r - nr)) / (2 * varianceSquared));
// distance to the current pixel
float currDist = sqrt((float)si*(float)si + (float)sj*(float)sj);
// blend with a constant floor and divide by the distance
currEdgeStrength = ((float)EDGE_STRENGTH_WEIGHT * currEdgeStrength + (float)(1 - EDGE_STRENGTH_WEIGHT)) / currDist;
int edgeCapacity = (int)ceil(INT32_CONST*currEdgeStrength + 0.5);
myGraph->add_edge(currNodeId, nNodeId, edgeCapacity, edgeCapacity);
}
}
// link the pixel to the auxiliary node of its color bin
int currBin = (int)binPerPixelImg.at<float>(i, j);
myGraph->add_edge(currNodeId, (int)(currBin + inputImg.rows * inputImg.cols), (int)ceil(INT32_CONST*bha_slope + 0.5), (int)ceil(INT32_CONST*bha_slope + 0.5));
}
}
//ui.label_show->setRelatedVar(bgScribbleMask, bgScribbleMaskAll, fgScribbleMask, fgScribbleMaskAll, scribbleRadius, showImg);
ui.label_show->setRelatedVar(bgScribbleMask, fgScribbleMask, scribbleRadius, showImg);
return 0;
}
//獲取每一個畫素的bin index
void onecut::generateBinIndex(Mat& bin, Mat & inImg, int binschannel, int & numUsedBins)
{
// maps each of the binschannel^3 color bins to a compact index; -1 means unused
vector<int> occupiedBinNewIdx((int)pow((double)binschannel, (double)3), -1);
int newBinIdx = 0;
for (int i = 0; i<inImg.rows; i++)
for (int j = 0; j<inImg.cols; j++)
{
float b = (float)inImg.at<Vec3b>(i, j)[0];
float g = (float)inImg.at<Vec3b>(i, j)[1];
float r = (float)inImg.at<Vec3b>(i, j)[2];
// color-bin index: each channel is quantized into binschannel levels
int bin_index = (int)(floor(b / 256.0 *(float)binschannel) + (float)binschannel * floor(g / 256.0*(float)binschannel)
+ (float)binschannel * (float)binschannel * floor(r / 256.0*(float)binschannel));
// this bin has not been seen yet
if (occupiedBinNewIdx[bin_index] == -1)
{
// give it the next compact index
occupiedBinNewIdx[bin_index] = newBinIdx;
newBinIdx++;
}
bin.at<float>(i, j) = (float)occupiedBinNewIdx[bin_index];
}
double maxBin;
minMaxLoc(bin, NULL, &maxBin);
numUsedBins = (int)maxBin + 1;
occupiedBinNewIdx.clear();
}
// estimate sigma^2 for the edge weights: the mean squared color difference between neighboring pixels
void onecut::generateEdgeVariance(Mat & inputImg, Mat & showEdgesImg, float & varianceSquared)
{
varianceSquared = 0;
int counter = 0;
for (int i = 0; i<inputImg.rows; i++)
{
for (int j = 0; j<inputImg.cols; j++)
{
float b = (float)inputImg.at<Vec3b>(i, j)[0];
float g = (float)inputImg.at<Vec3b>(i, j)[1];
float r = (float)inputImg.at<Vec3b>(i, j)[2];
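for (int si = -NEIGHBORHOOD; si <= NEIGHBORHOOD; si++)
{
// skip rows outside the image (continue, not a loop condition, so row 0 is not skipped entirely)
if (si + i < 0 || si + i >= inputImg.rows)
continue;
for (int sj = 0; sj <= NEIGHBORHOOD && sj + j < inputImg.cols; sj++)
{
// same neighbor set as init(): skip the pixel itself and downward edges,
// and keep only neighbors within radius NEIGHBORHOOD
if ((si >= 0 && sj == 0) || (si*si + sj*sj > NEIGHBORHOOD*NEIGHBORHOOD))
continue;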
float nb = (float)inputImg.at<Vec3b>(i + si, j + sj)[0];
float ng = (float)inputImg.at<Vec3b>(i + si, j + sj)[1];
float nr = (float)inputImg.at<Vec3b>(i + si, j + sj)[2];
varianceSquared += (b - nb)*(b - nb) + (g - ng)*(g - ng) + (r - nr)*(r - nr);
counter++;
}
}
}
}
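// varianceSquared holds the sum over all neighbor pairs; divide to get the mean
if (counter > 0)
varianceSquared /= counter;
// guard against a constant image, which would make the edge weights divide by zero
if (varianceSquared <= 0)
varianceSquared = 1.0f;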
}
void onecut::showImage(Mat image){
//Show Image
ui.label_show->clear();
ui.label_show->setPixmap(QPixmap::fromImage(cvMatToQImage(image)));
ui.label_show->resize(ui.label_show->pixmap()->size());
}
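QImage onecut::cvMatToQImage(Mat& src){
QImage qImage;
if (src.channels() == 3) // BGR color image
{
Mat rgb;
cvtColor(src, rgb, COLOR_BGR2RGB);
// copy() makes the QImage own its pixels; rgb is released when this block ends
qImage = QImage((const uchar*)(rgb.data), rgb.cols, rgb.rows, (int)rgb.step, QImage::Format_RGB888).copy();
}
else // gray image
{
qImage = QImage((const uchar*)(src.data), src.cols, src.rows, (int)src.step, QImage::Format_Indexed8).copy();
// Format_Indexed8 needs a grayscale color table to display as gray
QVector<QRgb> grayTable;
for (int i = 0; i < 256; i++)
grayTable.push_back(qRgb(i, i, i));
qImage.setColorTable(grayTable);
}
return qImage;
}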
void onecut::onSegImage(){
if (rectLabel != NULL && rectLabel->isVisible())
{
int sx = rectLabel->startPnt.x(), sy = rectLabel->startPnt.y();
int ex = rectLabel->endPnt.x(), ey = rectLabel->endPnt.y();
// reject a missing, out-of-range or degenerate rectangle
if (rectLabel->startPnt.isNull() || rectLabel->endPnt.isNull() ||
sx < 0 || sy < 0 || ex < 0 || ey < 0 ||
sx > rectLabel->width() || sy > rectLabel->height() ||
ex > rectLabel->width() || ey > rectLabel->height() ||
sx == ex || sy == ey)
{
QMessageBox::information(NULL, QString("Information"), QString("draw error, please draw again!"));
return;
}
// normalize the corners (the user may drag in any direction) and clamp them to the image,
// so the loop below never indexes outside the masks
int x0 = std::max(0, std::min(sx, ex)), x1 = std::min(inputImg.cols, std::max(sx, ex));
int y0 = std::max(0, std::min(sy, ey)), y1 = std::min(inputImg.rows, std::max(sy, ey));
// everything outside the rectangle becomes a hard background constraint
for (int i = 0; i < inputImg.rows; i++)
for (int j = 0; j < inputImg.cols; j++)
if (i < y0 || i >= y1 || j < x0 || j >= x1)
bgScribbleMask.at<uchar>(i, j) = 255;
this->rectLabel->close();
this->rectLabel = NULL;
}
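// hard constraints: scribbled pixels get a huge t-link to the source (foreground) or the sink (background)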
for (int i = 0; i<inputImg.rows; i++)
{
for (int j = 0; j<inputImg.cols; j++)
{
int currNodeId = i * inputImg.cols + j;
if (fgScribbleMask.at<uchar>(i, j) == 255)
myGraph->add_tweights(currNodeId, (int)ceil(INT32_CONST * HARD_CONSTRAINT_CONST + 0.5), 0);
else if (bgScribbleMask.at<uchar>(i, j) == 255)
myGraph->add_tweights(currNodeId, 0, (int)ceil(INT32_CONST * HARD_CONSTRAINT_CONST + 0.5));
}
}
// run max-flow / min-cut
myGraph->maxflow();
segMask = 0;
inputImg.copyTo(segShowImg);
fgScribbleMask = 0;
bgScribbleMask = 0;
for (int i = 0; i<inputImg.rows * inputImg.cols; i++)
{
int row = i / inputImg.cols, col = i % inputImg.cols;
if (myGraph->what_segment(i) == GraphType::SOURCE)
{
segMask.at<uchar>(row, col) = 255; // foreground
}
else
{
// background: clear the mask and black the pixel out in the preview
segMask.at<uchar>(row, col) = 0;
segShowImg.at<Vec3b>(row, col) = Vec3b(0, 0, 0);
}
}
this->showImage(segShowImg);
this->showImg = segShowImg.clone();
//ui.label_show->setRelatedVar(bgScribbleMask, bgScribbleMaskAll, fgScribbleMask, fgScribbleMaskAll, scribbleRadius, showImg);
ui.label_show->setRelatedVar(bgScribbleMask, fgScribbleMask, scribbleRadius, showImg);
this->lastSegState = 1;
}
void onecut::onMouseMoveFinish(Mat bgScribbleMask,
Mat fgScribbleMask,
Mat showImg){
if (this->lastSegState == 1)
{
fgScribbleMask_last = this->fgScribbleMask.clone();
bgScribbleMask_last = this->bgScribbleMask.clone();
showImg_last = this->showImg.clone();
this->lastSegState = 0;
}
this->bgScribbleMask = bgScribbleMask.clone();
this->fgScribbleMask = fgScribbleMask.clone();
this->showImg = showImg.clone();
}
void onecut::onFinish(){
emit okClicked(this->segShowImg);
this->close();
}
void onecut::onCutImage(){
if (imageLable == NULL)
{
imageLable = new cutImageLabel();
imageLable->show();
ui.gridLayout->addWidget(imageLable, 0, 0);
}
}
void onecut::onRectSeg(){
rectLabel = new rectSegLabel();
rectLabel->resize(inputImg.cols,inputImg.rows);
rectLabel->show();
ui.gridLayout->addWidget(rectLabel, 0, 0);
}
void onecut::onConfirmCut(){
if (imageLable == NULL)
return;
if (imageLable->startPnt.isNull())
return;
if (imageLable->endPnt.isNull())
return;
if (imageLable->startPnt.x() < 0)
return;
if (imageLable->startPnt.y() < 0)
return;
if (imageLable->endPnt.x() < 0)
return;
if (imageLable->endPnt.y() < 0)
return;
if (imageLable->startPnt.x() > imageLable->width())
return;
if (imageLable->startPnt.y() > imageLable->height())
return;
if (imageLable->endPnt.x() > imageLable->width())
return;
if (imageLable->endPnt.y() > imageLable->height())
return;
if (imageLable->startPnt.x() == imageLable->endPnt.x())
return;
if (imageLable->startPnt.y() == imageLable->endPnt.y())
return;
Point cvP1(imageLable->startPnt.x(), imageLable->startPnt.y());
Point cvP2(imageLable->endPnt.x(), imageLable->endPnt.y());
// Rect(p1, p2) accepts the corners in any order; clip it to the image so the crop stays in bounds
Rect rect = Rect(cvP1, cvP2) & Rect(0, 0, showImg.cols, showImg.rows);
if (rect.area() == 0)
return;
Mat cropped = this->rectCutImage(showImg, rect);
// the graph, bin indices and masks were all built for the old image size,
// so rebuild everything for the cropped image (init() also refreshes the display)
delete myGraph;
myGraph = NULL;
this->init(cropped);
imageLable->close();
imageLable = NULL;
}
void onecut::onDrawImage(){
if (imageLable != NULL)
imageLable->close();
if (rectLabel != NULL)
rectLabel->close();
}
void onecut::keyPressEvent(QKeyEvent *event){
if ((event->modifiers() == Qt::ControlModifier) && (event->key() == Qt::Key_Z))
{
this->fgScribbleMask = fgScribbleMask_last.clone();
this->bgScribbleMask = bgScribbleMask_last.clone();
this->showImg = showImg_last.clone();
onSegImage();
}
if ((event->modifiers() == Qt::ControlModifier) && (event->key() == Qt::Key_R))
{
// reset: free the old graph before init() allocates a new one
delete myGraph;
myGraph = NULL;
init(inputImg);
}
}
void onecut::onLineWidthChanged(int value){
scribbleRadius = value;
ui.label_show->setRadius(value);
}
Mat onecut::rectCutImage(const Mat& src, Rect rect){
// copy the region of interest into a new, independent Mat
return src(rect).clone();
}