1. 程式人生 > >opencv 基於KNN的手寫數字字元識別

opencv 基於KNN的手寫數字字元識別

樓主為武漢市某科技大學的機械小碩,由於某種原因,開始學習和使用opencv,所以算是半路出家和非科班出身,如有描述的不夠專業地方,還請多多包涵和批評指正。

本文主要實現對手寫數字字元的識別,主要用到的方法為k-近鄰分類方法,用到opencv提供的KNearest類。

也是在網上看到的程式碼,覺得很好玩,然後下載了工程,原工程是opencv2早期版本,還是cvmat的時代,看起來很不方便和習慣。

樓主花了點兒時間,好好學習了下,然後修改成了opencv2後期mat時代的程式碼,opencv2.4.9-2.4.13應該都可以跑起來。

先講下原理吧

1、得到訓練的資料,一般都會是兩個矩陣,一個矩陣存放著資料,另一個矩陣存放資料對應的標記(如數字0,1,2,3....)

2、訓練資料,這一步,很簡單,一個函式就可以搞定

3、根據需要識別的圖片,預測其屬於哪一類。

總結來講:既然opencv都為我們封裝好了演算法,提供了一個可供呼叫的類,使用起來,必然是比較簡單的。大部分的精力和程式碼,都花在得到標準化的資料上。

話不多說,先上一部分程式碼吧

class basicOCR
{
public:
	float classify(Mat img, int showResult);
	basicOCR();
	void test();
private:
	char file_path[255];
	int train_samples;
	int classes;
	Mat trainData;
	Mat trainClasses;
	int size;
	static const int K = 5;//最大鄰居個數
	KNearest *knn;
	void getData();
	void train();
	void preprocessing(Mat &srcimage, int new_width, int new_height);
};
封裝成一個類
</pre><pre name="code" class="cpp">basicOCR::basicOCR()//建構函式
{

	//initial
	sprintf(file_path, "OCR/");
	train_samples = 50;//訓練樣本,總共100個,50個訓練,50個測試
	classes = 10;//暫時識別十個數字

	size = 128;//

	trainData.create(train_samples*classes, size*size, CV_32FC1);//訓練資料的矩陣
	trainClasses.create(train_samples*classes, 1, CV_32FC1);

	//Get data (get images and process it)
	getData();

	//train	
	train();
	//Test	
	test();

	printf(" ------------------------------------------------------------------------\n");
	printf("|\t識別結果\t|\t 測試精度\t|\t  準確率\t|\n");
	printf(" ------------------------------------------------------------------------\n");
}
類的建構函式,可以看到,存放訓練資料的矩陣trainData和存放對應標記的矩陣trainClasses
void basicOCR::getData()
{
	Mat src_image;
	char file[255];
	int i, j;
	for (i = 0; i<classes; i++)
	{
		for (j = 0; j< train_samples; j++)
		{

			//載入pbm格式影象,作為訓練
			if (j<10)
				sprintf(file, "%s%d/%d0%d.pbm", file_path, i, i, j);
			else
				sprintf(file, "%s%d/%d%d.pbm", file_path, i, i, j);
			src_image = imread(file, 0);
			if (src_image.empty())
			{
				printf("Error: Cant load image %s\n", file);
				//exit(-1);
			}
			//process file
			preprocessing(src_image, size, size);

			//Set data 
			float* data1 = trainData.ptr<float>(i*train_samples+j);
			float* data2 = src_image.ptr<float>(0);
			for (int k = 0; k < src_image.cols; k++)
			{
				data1[k] = data2[k];
			}

			//Set class label
			trainClasses.at<float>(i*train_samples + j, 0) = i;
		}
	}
}

得到資料矩陣和相應的標記矩陣。

好了,只貼部分程式碼吧,如需要完整的工程,可以到hust平凡之路下載。

請原諒我這種賺取積分的行為。畢竟,人生已經如此的艱難。