1. 程式人生 > 手把手教你在Kaggle Digit Recognizer比賽用Caffe達到準確率99.1%

手把手教你在Kaggle Digit Recognizer比賽用Caffe達到準確率99.1%

最近實驗室師兄佈置了參加Kaggle上的Digit Recognizer訓練比賽,任務是識別手寫數字。在經過兩週的探索之後,發現了一個非常surprising的結論:直接使用Caffe自帶mnist範例中LeNet的訓練腳本train_lenet.sh搭配預設設定(default setting),並且不劃分validation set,是我目前取得的最好結果,高達0.991。leaderboard上結果如下
162 new Mokii 0.99100
linux下的具體過程:

step1:lmdb格式資料生成

1)仿照Caffe自帶的examples中的mnist,新建一個目錄examples/newmnist。從比賽官網https://www.kaggle.com/c/digit-recognizer/data下載訓練與測試資料,分別存為examples/newmnist/train.csv與examples/newmnist/test.csv。

2)使用下面的C++程式convert_data_to_lmdb.cpp將csv格式轉換為lmdb格式(train.csv每行第一欄為label,其後為784個像素值;test.csv沒有label欄,只有784個像素值)。

#include <cerrno>
#include <climits>
#include <cstdio>
#include <cstdlib>
#include <fstream>
#include <iostream>
#include <sstream>
#include <string>
#include <vector>

#include <gflags/gflags.h>

#include "boost/scoped_ptr.hpp"
#include "gflags/gflags.h"
#include "glog/logging.h"

#include "caffe/proto/caffe.pb.h"
#include "caffe/util/db.hpp"
#include "caffe/util/io.hpp"
#include "caffe/util/rng.hpp"

using namespace caffe;
using namespace std;
using std::pair;
using boost::scoped_ptr;
  
/* Originally edited by Zack.
 * Positional arguments: argv[1] is the input CSV file, argv[2] is the output
 * database path. */

// --backend: database type name that main() passes to db::GetDB() as
// FLAGS_backend; defaults to "lmdb". NOTE: a value given on the command line
// only takes effect if main() parses flags with gflags::ParseCommandLineFlags.
DEFINE_string(backend, "lmdb", "The backend for storing the result");
  
int main(int argc, char **argv){  
    ::google::InitGoogleLogging(argv[0]);  
  
    #ifndef GFLAGS_GFLAGS_H_  
      namespace gflags = google;  
    #endif  
  
    if(argc < 3){  
        LOG(ERROR)<< "please check the input arguments!";  
        return 1;  
    }  
    ifstream infile(argv[1]);  
    if(!infile){  
        LOG(ERROR)<< "please check the input arguments!";  
        return 1;  
    }  
    string str;  
    int count = 0;  
    int rows = 28;  
    int cols = 28;  
    unsigned char *buffer = new  unsigned char[rows*cols];  
    stringstream ss;  
  
    Datum datum;             // this data structure store the data and label  
    datum.set_channels(1);    // the channels  
    datum.set_height(rows);    // rows  
    datum.set_width(cols);     // cols  
  
    scoped_ptr<db::DB> db(db::GetDB(FLAGS_backend));         // new DB object  
    db->Open(argv[2], db::NEW);                    // open the lmdb file to store the data  
    scoped_ptr<db::Transaction> txn(db->NewTransaction());   // new Transaction object to put and commit the data  
  
    const int kMaxKeyLength = 256;           // to save the key  
    char key_cstr[kMaxKeyLength];  
  
    bool flag= false;  
    while(getline(infile, str)){  
        if(flag == false){  
            flag = true;  
            continue;  
        }  
        int beg = 0;  
        int end = 0;  
        int str_index = 0;  
        //test  need to add this----------1  
        //datum.set_label(0);  
        while((end = str.find_first_of(',', beg)) != string::npos){  
            //cout << end << endl;  
            string dig_str = str.substr(beg, end - beg);  
            int pixes;  
            ss.clear();  
            ss << dig_str;  
            ss >> pixes;  
            // test need to delete this--------------2  
            if(beg == 0){  
                datum.set_label(pixes);  
                beg = ++ end;  
                continue;  
            }  
            buffer[str_index++] = (unsigned char)pixes;  
            beg = ++end;  
        }  
        string dig_str = str.substr(beg);  
        int pixes;  
        ss.clear();  
        ss << dig_str;  
        ss >> pixes;  
        buffer[str_index++] = (unsigned char)pixes;  
        datum.set_data(buffer, rows*cols);  
  
        int length = snprintf(key_cstr, kMaxKeyLength, "%08d", count);  
  
            // Put in db  
        string out;  
        CHECK(datum.SerializeToString(&out));              // serialize to string  
        txn->Put(string(key_cstr, length), out);        // put it, both the key and value  
  
        if (++count % 1000 == 0) {       // to commit every 1000 iteration  
          // Commit db  
          txn->Commit();  
          txn.reset(db->NewTransaction());  
          LOG(ERROR) << "Processed " << count << " files.";  
        }  
  
    }  
    // write the last batch  
      if (count % 1000 != 0) {            // commit the last batch  
        txn->Commit();  
        LOG(ERROR) << "Processed " << count << " files.";  
      }  
  
    return 0;  
}  

具體過程:將上面的程式碼儲存為examples/newmnist/convert_data_to_lmdb.cpp(Caffe的Makefile只會編譯examples目錄下各子目錄中的.cpp檔案,放在caffe根目錄下不會被編譯),然後在caffe根目錄下輸入命令列 make all(這樣就會自動編譯該cpp檔)。編譯完成後會在build/examples/newmnist資料夾下生成幾個convert_data_to_lmdb開頭的檔案,其中convert_data_to_lmdb.bin為可執行檔。接著在caffe根目錄下分別轉換訓練資料與測試資料:

./build/examples/newmnist/convert_data_to_lmdb.bin examples/newmnist/train.csv examples/newmnist/mnist_train_lmdb --backend=lmdb

./build/examples/newmnist/convert_data_to_lmdb.bin examples/newmnist/test.csv examples/newmnist/mnist_test_lmdb --backend=lmdb

執行完將會在examples/newmnist下生成mnist_test_lmdb和mnist_train_lmdb兩個資料夾。