1. 程式人生 > >【SVM理論到實踐4】基於OpenCv中的SVM的手寫體數字識別

【SVM理論到實踐4】基於OpenCv中的SVM的手寫體數字識別

//由於本人每天時間非常緊張,所以細節寫的不詳細,部落格僅供各位參考,裡面的程式碼都是執行過的,直接可以執行

本章的學習目標:

     1)手寫體數字識別資料庫MNIST

     2)基於SVM訓練的具體步驟  

1)手寫體數字識別資料庫MNIST

MNIST(Mixed National Institute of Standards and Technology)是一個大型的手寫體數字識別資料庫,廣泛應用於機器學習領域的訓練和測試,由紐約大學Yann LeCun教授整理。MNIST包括60000個訓練集和10000個測試集,每張圖都已經進行了尺寸歸一化,數字居中處理,固定為28*28畫素。具體的下載地址如下所示:

2)基於SVM訓練的具體步驟

訓練的過程如下所示:

 1)讀取Mnist訓練集資料

 2)訓練

 3)讀取Mnist訓練集資料,對比預測結果,得到錯誤率

3)具體的實現如下所示:

  1)mnist給出的資料檔案是二進位制檔案,四個檔案解壓之後的情況如下所示:

      1)”train-images.idx3-ubyte”二進位制檔案,儲存了標頭檔案資訊以及60000張28*28解析度的影象資訊(用於訓練)

      2)”train-labels.idx1-ubyte”二進位制檔案,儲存了檔案頭資訊以及60000張label資訊

      3)”t10k-images.idx-ubyte”二進位制檔案,儲存了檔案頭資訊以及

10000張28*28解析度的圖想資訊(用於測試)

      4)“t10k-labels.idx-ubyte”二進位制檔案,儲存了標頭檔案資訊以及10000張影象label資訊

  2)因為OpenCv中沒有直接匯入MNIST資料的檔案,所以需要自己寫函式來讀取MNIST的資料檔案

 1)首先,要知道MNIST資料的資料格式:IMAGE_FILE---包含四個int型的頭部資料(magic_number,number_of_images,

             number_of_rows,number_of_columns)

       2)餘下的每一個byte表示一個pixel的資料,範圍是0~255(可以在讀入的時候scale到0~1的區間)

       3)LABEL_FILE包含兩個型的頭部資料(magic_number,number_of_items),餘下的每一個byte表示一個label資料,範圍是0~9

             我們可以參考下圖所示,更加具體的資訊可以去MNIST官網瞭解:

            

   3)此塊要注意的第一個坑是:MNIST是大端儲存,然而大部分的intel處理器都是小端儲存,所以對於int、long、float這些多位元組的資料型別,就要一個一個byte地翻轉過來,才能正確的顯示。

   4)此塊注意的第二個坑是:如果用第一條開啟檔案,不會報錯,但是資料會出現錯誤,頭部資料仍然正確,但是後面的pixel資料大部分都是0

不能用ifstream file(fileName);

而要改成ifstream file(fileName, ios::binary);

   5)此塊注意的第三個坑是:training時,IMAGE和LABEL的資料分別都放進一個MAT中儲存,但是隻能是CV32_F或者是CV32_S的格式,不然會報錯,OpenCv文件給出的例子是這樣的:(但是predict的時候又會要求label的格式是unsigned int),所以...可以設定data的Mat格式為CV_32FC1,label的Mat格式為CV_32SC1,當然,最好都設定為CV_32FC1

6)順便地,影象訓練資料的轉換格式,也就是說,我們都進來的影象資料都是二維的矩陣,但是我們在訓練的時候,需要把二維的影象矩陣拉為一維的向量

7)最後,為了驗證資料的正確性,一個有效的辦法就是輸出第一個和最後一個數據和原始影象的資料進行對比

8)還有需要說明的一點是,此處,我們是直接對原始影象進行訓練,並沒有對任何對影象的任何特徵進行提取;我們也可以在影象進行訓練之前,先利用Harris,SIFT,SURF,FAST,BRIRF,ORB,HOG這些提取影象的特徵,然後再把提取的特徵向量組成訓練集進行訓練。

/***************************************************************************************************** 
檔案描述: 
        標頭檔案mnist.h
開發環境: 
        VS2012 + OpenGl(GLUT3.7) + OpenCv2.4.9 + Halcon10.0 
時間地點: 
        陝西師範大學----2017.3.3
作    者: 
        九月 
*****************************************************************************************************/ 
#ifndef MNIST_H
#define MNIST_H
#include<iostream>
#include<string>
#include<fstream>
#include<ctime>
#include<opencv2/core/core.hpp>
#include<opencv2/highgui/highgui.hpp>
#include<opencv2/imgproc/imgproc.hpp>
#include<opencv2/ml/ml.hpp>

using namespace std;
using namespace cv;

int     ReverseInt(int i);                      //[1]大小端儲存轉換
cv::Mat ReadMnistImage(const string fileName);  //[2]讀取Image的資料資訊
cv::Mat ReadMnistLabel(const string fileName);  //[3]讀取Label資料資訊
#endif

/***************************************************************************************************** 
檔案描述: 
        標頭檔案mnist.h的實現檔案mnist.cpp
開發環境: 
        VS2012 + OpenGl(GLUT3.7) + OpenCv2.4.9 + Halcon10.0 
時間地點: 
        陝西師範大學----2017.3.3
作    者: 
        九月 
*****************************************************************************************************/ 
#include"mnist.h"
#include<ctime>
#include<iostream>

using namespace std;

int  testNum = 10000;
/***************************************************************************************************** 
函式功能:
         大小端儲存模式的資料轉換
*****************************************************************************************************/
int  ReverseInt(int i)
{
	unsigned char c1;
	unsigned char c2;
	unsigned char c3;
	unsigned char c4;

	c1 = i&255;
	c2 = (i>>8)&255;
	c3 = (i>>16)&255;
	c4 = (i>>24)&255;

	return ((int)c1<<24)+((int)c2<<16)+((int)c3<<8)+c4;
}
/***************************************************************************************************** 
函式功能:
         讀取Minst資料庫的影象二進位制檔案
注意問題:
         此塊我們需要注意的問題是:當我們從MINIST資料庫中讀進來影象檔案後,我們將讀進來的檔案儲存在
		 dataMat矩陣容器中,這就是我們送給SVM的訓練樣本;要注意的是,在這個矩陣容器中,矩陣dataMat
		 中的一行,就代表一個樣本,就是實際中的一幅圖片;
		 我們有多少張圖片,我們就有多少訓練樣本,這個矩陣就有多少行。
*****************************************************************************************************/
cv::Mat ReadMnistImage(const string fileName)
{
	double      constTime;
    std::clock_t startTime;
    std::clock_t endTime;

	int magicNumber    = 0;
	int numberOfImages = 0;
	int nRows          = 0;
	int nCols          = 0;
	
	cv::Mat dataMat;
	std::ifstream file(fileName,ios::binary);

	if(file.is_open())
	{
		std::cout<<"[NOTICE]The set of Images is opened sucessfully!"<<std::endl;
		file.read((char*)&magicNumber,sizeof(magicNumber));
		file.read((char*)&numberOfImages,sizeof(numberOfImages));
		file.read((char*)&nRows,sizeof(nRows));
		file.read((char*)&nCols,sizeof(nCols));

		magicNumber    = ReverseInt(magicNumber);
		numberOfImages = ReverseInt(numberOfImages);
		nRows          = ReverseInt(nRows);
		nCols          = ReverseInt(nCols);

		std::cout<<"[1]magicNumber    = "<<magicNumber<<std::endl;
		std::cout<<"[2]numberOfImages = "<<numberOfImages<<std::endl;
		std::cout<<"[3]nRows          = "<<nRows<<std::endl;
		std::cout<<"[4]nCols          = "<<nCols<<std::endl;

		//輸出第一張和最後一張圖片,檢查資料無誤
		cv::Mat s = cv::Mat::zeros(nCols,nRows*nCols,CV_32FC1);
		cv::Mat e = cv::Mat::zeros(nCols,nRows*nCols,CV_32FC1);

		std::cout<<"[NOTICE]Read the data of Imagess---->>Start!"<<std::endl;
		startTime = std::clock();
		
		dataMat =  cv::Mat::zeros(numberOfImages,nRows*nCols,CV_32FC1);
		//for(int i=0;i<numberOfImages;i++)
		for(int i=0;i<1000;i++)
		{
			for(int j=0;j<nRows*nCols;j++)
			{
				unsigned char temp = 0;
				file.read((char*)&temp,sizeof(temp));
				//std::cout<<"temp = "<<temp<<std::endl;
				float pixelValue = (float)((temp+0.0)/255.0);
				//std::cout<<"pixelValue = "<<pixelValue<<std::endl;
				dataMat.at<float>(i,j) = pixelValue;
				//列印第一張和最後一張影象資料
				if(i==0)
				{
					s.at<float>(j/nCols,j%nCols) = pixelValue;
				}
				else if(i==numberOfImages-1)
				{
					e.at<float>(j/nCols,j%nCols) = pixelValue;
				}
			}//for j
		}//for i
		endTime  = std::clock();
		constTime= (endTime-startTime);
		std::cout<<"[NOTICE]Read the data of Images---->>Finish!"<<std::endl;
		std::cout<<"[NOTICE]Running Time = "<<constTime<<"ms"<<std::endl;
		cv::imshow("firstImage",s);
		cv::imshow("last image",e);
		cv::waitKey(0);
	}
	file.close();
	return dataMat;
}
/***************************************************************************************************** 
函式功能:
         讀Mnist資料庫中的標籤
*****************************************************************************************************/
 cv::Mat ReadMnistLabel(const string fileName) 
 {
	  double constTime;
      std::clock_t startTime;
      std::clock_t endTime;

      int magicNumber   = 0;
      int numberOfItems = 0;
 
      cv::Mat labelMat;
 
      std::ifstream file(fileName,ios::binary);
      if (file.is_open())
      {
          std::cout<<"[NOTICE]The set of Label is opened sucessfully!"<<std::endl;
          file.read((char*)&magicNumber,sizeof(magicNumber));
          file.read((char*)&numberOfItems,sizeof(numberOfItems));
          magicNumber   = ReverseInt(magicNumber);
          numberOfItems = ReverseInt(numberOfItems);
 
		  std::cout<<"[1]magicNumber    = "<<magicNumber<<std::endl;
		  std::cout<<"[2]numberOfItems  = "<<numberOfItems<<std::endl;


          //記錄第一個label和最後一個label
          unsigned int s = 0;
		  unsigned int e = 0;

          std::cout<<"[NOTICE]Read the data of Labels---->>Start!"<<std::endl;
		  startTime = std::clock();
          labelMat = Mat::zeros(numberOfItems,1,CV_32SC1);
          //for (int i = 0; i < numberOfItems; i++) 
		  for (int i = 0; i < 1000; i++) 
		  {
             unsigned char temp = 0;

             file.read((char*)&temp,sizeof(temp));

             labelMat.at<unsigned int>(i, 0) = (unsigned int)temp;
 
             //列印第一個和最後一個label
             if(i == 0)
			 {
					 s = (unsigned int)temp;
			  }
             else if(i == numberOfItems-1)
			 {
					 e = (unsigned int)temp;
			  }
          }
		  endTime = clock();
		  constTime= (endTime-startTime);
		  std::cout<<"[NOTICE]Read the data of Images---->>Finish!"<<std::endl;
		  std::cout<<"[NOTICE]Running Time = "<<constTime<<"ms"<<std::endl;
          std::cout<<"[1]first label = " << s << endl;
          std::cout<<"[2]last  label = " << e << endl;
     }
     file.close();
     return labelMat;
 }
/***************************************************************************************************** 
程式功能: 
        基於OpenCv中SVM的Minist手寫體字元識別
開發環境: 
        VS2012 + OpenGl(GLUT3.7) + OpenCv2.4.9 + Halcon10.0 
時間地點: 
        陝西師範大學----2017.3.3
作    者: 
        九月 
*****************************************************************************************************/ 
#include"mnist.h"
#include<opencv2/core/core.hpp>
#include<opencv2/imgproc/imgproc.hpp>
#include<opencv2/highgui/highgui.hpp>
#include<opencv2/ml/ml.hpp>
#include<ctime>
#include<string>
#include<iostream>
  
using namespace std;
using namespace cv;


std::string trainImage = "mnist_dataset/train-images.idx3-ubyte";
std::string trainLabel = "mnist_dataset/train-labels.idx1-ubyte";
std::string testImage  = "mnist_dataset/t10k-images.idx3-ubyte";
std::string testLabel  = "mnist_dataset/t10k-labels.idx1-ubyte";

int main()
{
	double consumeTime    = 0;
	std::clock_t startTime = 0;
	std::clock_t endTime   = 0;

	cv::Mat trainData;
	cv::Mat trainDataLabels;
	//【1】讀入訓練樣本
	trainData       = ReadMnistImage(trainImage);                        
	trainDataLabels = ReadMnistLabel(trainLabel);
	std::cout<<"[1]trainData.rows*trainData.cols             = "<<trainData.rows<<"*"<<trainData.cols<<std::endl;
	std::cout<<"[2]trainDataLabels.rows*trainDataLabels.cols = "<<trainDataLabels.rows<<"*"<<trainDataLabels.cols<<std::endl;
	//【2】設定支援向量機的引數,SVM中的引數有很多,但是與C_SVC有關的就只有gamma和C,所以只要設定好這兩個就可以了
	//     其實,很多資料將gamma設定為0.01,這樣訓練的收斂速度就會快很多
	CvSVMParams params;
	params.svm_type    = SVM::C_SVC;
	params.kernel_type = SVM::RBF;
	params.degree      = 10.0;
	params.gamma       = 0.01;
	params.coef0       = 1.0;
	params.C           = 10.0;
	params.nu          = 0.5;
	params.p           = 0.1;
	params.term_crit   = cv::TermCriteria(CV_TERMCRIT_EPS,1000,FLT_EPSILON);
	//【3】訓練SVM
	std::cout<<"[NOTICE]Starting training process!"<<std::endl;
	startTime = std::clock();
	CvSVM svm;
	svm.train(trainData,trainDataLabels,cv::Mat(),cv::Mat(),params);
	endTime   = std::clock();
	consumeTime   = (endTime - startTime);
	std::cout<<"[NOTICE]Finished training process...consumeTime = "<<consumeTime<<"ms"<<std::endl;
	svm.save("mnist_dataset/mnist_svm.xml");
	std::cout<<"[NOTICE]Save as /mnist_dataset/mnist_svm.xml"<<std::endl;
    //【4】開始匯入預測樣本
	std::cout<<"[NOTICE]Loading the predict sample!"<<std::endl;
    cv::Mat     testData;
	cv::Mat     testLabels;
	std::cout<<"[NOTICE]Loading sucessfully!"<<std::endl;
	testData  = ReadMnistImage(testImage);
	testLabels= ReadMnistLabel(testLabel);
    //【5】SVM利用訓練好的模型開始進行預測
	float count = 0;
	//for(int i=0;i<testData.rows;i++)
	for(int i=0;i<1000;i++)
	{
		cv::Mat sample = testData.row(i);
		float result  = svm.predict(sample);
		result = std::abs(result-testLabels.at<unsigned int>(i,0)<=FLT_EPSILON?1.f:0.f);
		count += result;
	}
	//【6】統計預測的正確個數和錯誤率
	std::cout<<"[NOTICE]Correct identification number = "<<count<<std::endl;
	std::cout<<"[NOTICE]Error rate = "<<(1000-count+0.0)/1000 * 100.0<<"%..."<<std::endl;
	std::system("pause");
	return 0;
}
下面兩幅圖片是博主訓練1000張圖片的準確率,具體怎樣設定,請看程式碼