A step-by-step guide to reaching 99.1% accuracy with Caffe in the Kaggle Digit Recognizer competition
阿新 • Published: 2019-01-29
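Recently a senior labmate assigned us the Kaggle Digit Recognizer practice competition, which involves recognizing handwritten digits. After two weeks of exploration I reached a rather surprising conclusion: running train_lenet.sh from Caffe's MNIST LeNet example with the default settings, and without splitting off a validation set, gave the best result I have obtained so far, as high as 0.994. The result on the leaderboard is as follows: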
162 | new | Mokii | 0.99100 |
Detailed procedure on Linux:
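Step 1: Generate the data in LMDB format
1) Following the mnist example that ships with Caffe, create a new directory examples/newmnist. Download the training and test data from the competition page https://www.kaggle.com/c/digit-recognizer/data and save them as examples/newmnist/train.csv and examples/newmnist/test.csv.
2) Use the following C++ program, convert_data_to_lmdb.cpp, to convert the CSV files to LMDB format.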
#include <cstdio>
#include <cstdlib>
#include <fstream>
#include <string>
#include <vector>

#include "boost/scoped_ptr.hpp"
#include "gflags/gflags.h"
#include "glog/logging.h"

#include "caffe/proto/caffe.pb.h"
#include "caffe/util/db.hpp"

using namespace caffe;
using namespace std;
using boost::scoped_ptr;

/* edited by Zack
 * argv[1]: the input CSV file, argv[2]: the output database directory */

// Storage backend for the result (FLAGS_backend defaults to "lmdb").
DEFINE_string(backend, "lmdb", "The backend for storing the result");

int main(int argc, char** argv) {
  ::google::InitGoogleLogging(argv[0]);
  FLAGS_alsologtostderr = 1;  // also print log messages to stderr
#ifndef GFLAGS_GFLAGS_H_
  namespace gflags = google;
#endif
  // Parse --backend and remove it from argv, leaving the two positional paths.
  gflags::ParseCommandLineFlags(&argc, &argv, true);

  if (argc < 3) {
    LOG(ERROR) << "usage: convert_data_to_lmdb.bin <input.csv> <output_db>";
    return 1;
  }
  ifstream infile(argv[1]);
  if (!infile) {
    LOG(ERROR) << "cannot open input file " << argv[1];
    return 1;
  }

  const int rows = 28;
  const int cols = 28;
  std::vector<unsigned char> buffer(rows * cols);  // pixels of one image

  Datum datum;  // stores one image together with its label
  datum.set_channels(1);
  datum.set_height(rows);
  datum.set_width(cols);

  scoped_ptr<db::DB> db(db::GetDB(FLAGS_backend));
  db->Open(argv[2], db::NEW);  // create the output database
  scoped_ptr<db::Transaction> txn(db->NewTransaction());

  const int kMaxKeyLength = 256;
  char key_cstr[kMaxKeyLength];

  string str;
  int count = 0;
  bool header_skipped = false;
  while (getline(infile, str)) {
    if (!header_skipped) {  // the first line holds the column names
      header_skipped = true;
      continue;
    }
    if (str.empty()) continue;  // ignore blank lines

    size_t beg = 0;
    size_t end = 0;
    int str_index = 0;
    // For test.csv (no label column), change 1 of 2: uncomment the next line.
    // datum.set_label(0);
    while ((end = str.find(',', beg)) != string::npos) {
      int pixel = atoi(str.substr(beg, end - beg).c_str());
      // For test.csv, change 2 of 2: delete this if block.
      if (beg == 0) {  // the first field of a train.csv row is the label
        datum.set_label(pixel);
        beg = end + 1;
        continue;
      }
      CHECK_LT(str_index, rows * cols) << "too many fields in record " << count;
      buffer[str_index++] = static_cast<unsigned char>(pixel);
      beg = end + 1;
    }
    // The last field on the line has no trailing comma.
    CHECK_LT(str_index, rows * cols) << "too many fields in record " << count;
    buffer[str_index++] =
        static_cast<unsigned char>(atoi(str.substr(beg).c_str()));
    datum.set_data(&buffer[0], buffer.size());

    // Zero-padded record index used as the database key.
    int length = snprintf(key_cstr, kMaxKeyLength, "%08d", count);
    string out;
    CHECK(datum.SerializeToString(&out));
    txn->Put(string(key_cstr, length), out);

    if (++count % 1000 == 0) {  // commit every 1000 records
      txn->Commit();
      txn.reset(db->NewTransaction());
      LOG(INFO) << "Processed " << count << " files.";
    }
  }
  if (count % 1000 != 0) {  // commit the last partial batch
    txn->Commit();
    LOG(INFO) << "Processed " << count << " files.";
  }
  return 0;
}
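The converter above has to be edited by hand when switching between train.csv (first field is the label) and test.csv (pixels only), as its two test-related comments indicate. The following is only a sketch of how the row parsing could be factored out so that a single boolean selects the layout. `ParsedRow`, `ParseCsvRow`, `has_label`, and `expected_pixels` are hypothetical names introduced here for illustration; they are not part of Caffe or of the original program, and the sketch deliberately leaves out the Datum and LMDB calls.

```cpp
#include <cassert>
#include <cstdlib>
#include <string>
#include <vector>

// Hypothetical helper type, not part of Caffe or the original converter.
struct ParsedRow {
  int label;                          // stays 0 when the row has no label
  std::vector<unsigned char> pixels;  // one byte per pixel
};

// Splits one comma-separated line into a label and pixel bytes.
// has_label == true:  the first field is the label (train.csv layout).
// has_label == false: every field is a pixel (test.csv layout).
// Returns false when the number of pixel fields is not expected_pixels.
bool ParseCsvRow(const std::string& line, bool has_label,
                 size_t expected_pixels, ParsedRow* row) {
  assert(row != NULL);
  row->label = 0;
  row->pixels.clear();
  size_t beg = 0;
  bool first = true;
  while (beg <= line.size()) {
    size_t end = line.find(',', beg);
    if (end == std::string::npos) end = line.size();  // last field
    int value = std::atoi(line.substr(beg, end - beg).c_str());
    if (first && has_label) {
      row->label = value;
    } else {
      row->pixels.push_back(static_cast<unsigned char>(value));
    }
    first = false;
    beg = end + 1;
  }
  return row->pixels.size() == expected_pixels;
}
```

With a helper along these lines, the main loop could call `ParseCsvRow(str, has_label, rows * cols, &row)` and pass `row.label` and `row.pixels` to the Datum, where `has_label` might come from an extra command-line flag. That wiring is an assumption and is not shown in the original post.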
Procedure: save the code above as examples/newmnist/convert_data_to_lmdb.cpp inside your own Caffe root directory (note that all paths here are relative to the Caffe root), then run make all from the Caffe root; this compiles the .cpp file directly. Several build outputs whose names start with convert_data_to_lmdb will be generated under build/examples/newmnist. Then, still in the Caffe root directory, run:
./build/examples/newmnist/convert_data_to_lmdb.bin examples/newmnist/test.csv examples/newmnist/mnist_test_lmdb --backend=lmdb
After the converter has been run on both train.csv and test.csv, examples/newmnist will contain two directories, mnist_train_lmdb and mnist_test_lmdb.
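As a quick sanity check around the conversion, it can help to confirm the shape of each CSV file: with 28x28 images, train.csv should have 785 columns (the label plus 784 pixels) and test.csv 784, and the number of data rows should match the final "Processed N files." count. The sketch below is one way to count these. `CsvShape`, `CountCsvShape`, and `CountCsvShapeOfText` are hypothetical helpers written for this illustration, not Caffe functions.

```cpp
#include <cassert>
#include <istream>
#include <sstream>
#include <string>

// Hypothetical helper type for a sanity check; not part of Caffe.
struct CsvShape {
  size_t data_rows;  // non-empty lines after the header
  size_t columns;    // number of fields in the header line
};

// Reads a CSV stream and reports its column count (taken from the header)
// and the number of data rows that follow it.
CsvShape CountCsvShape(std::istream& in) {
  CsvShape shape = {0, 0};
  std::string line;
  if (!std::getline(in, line)) return shape;  // empty input
  shape.columns = 1;
  for (size_t i = 0; i < line.size(); ++i) {
    if (line[i] == ',') ++shape.columns;
  }
  while (std::getline(in, line)) {
    // Skip blank lines, including ones holding only a carriage return.
    if (!line.empty() && line != "\r") ++shape.data_rows;
  }
  return shape;
}

// Convenience wrapper for in-memory text, handy for small experiments.
CsvShape CountCsvShapeOfText(const std::string& text) {
  std::istringstream in(text);
  return CountCsvShape(in);
}
```

For a real file, one would open it with `std::ifstream` and pass the stream to `CountCsvShape`, then compare `columns` against 785 or 784 before running the converter.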