Caffemodel資料結構解析與Protocol Buffer技術詳解(C++例項)
Caffe中,資料的讀取、運算、儲存都是採用Google Protocol Buffer來進行的,所以首先來較為詳細的介紹下Protocol Buffer(PB)。
PB是一種輕便、高效的結構化資料儲存格式,可以用於結構化資料序列化,很適合做資料儲存或 RPC 資料交換格式。它可用於通訊協議、資料儲存等領域的語言無關、平臺無關、可擴充套件的序列化結構資料格式。是一種效率和相容性都很優秀的二進位制資料傳輸格式,目前提供了 C++、Java、Python 三種語言的 API。Caffe採用的是C++和Python的API。
接下來,我用一個簡單的例子來說明一下。
使用PB和 C++ 編寫一個十分簡單的例子程式。該程式由兩部分組成。第一部分被稱為Writer,第二部分叫做Reader。Writer 負責將一些結構化的資料寫入一個磁碟檔案,Reader則負責從該磁碟檔案中讀取結構化資料並列印到螢幕上。準備用於演示的結構化資料是HelloWorld,它包含兩個基本資料:
ID,為一個整數型別的資料;
Str,這是一個字串。
首先我們需要編寫一個proto檔案,定義我們程式中需要處理的結構化資料,Caffe是定義在caffe.proto檔案中。在PB的術語中,結構化資料被稱為 Message。proto檔案非常類似java或C語言的資料定義。程式碼清單 1 顯示了例子應用中的proto檔案內容。
清單 1. proto 檔案
1 package lm; 2 3 message helloworld 4 5 { 6 7 required int32 id = 1; // ID 8 9 required stringView Codestr = 2; // str 10 11 optional int32 opt = 3; // optional field 12 13 }
一個比較好的習慣是認真對待proto檔案的檔名。比如將命名規則定於如下: packageName.MessageName.proto
在上例中,package名字叫做 lm,定義了一個訊息helloworld,該訊息有三個成員,型別為int32的id,另一個為型別為string的成員str。optional是一個可選的成員,即訊息中可以不包含該成員,required表明是必須包含該成員。一般在定義中會出現如下三個欄位屬性:
對於required的欄位而言,初值是必須要提供的,否則欄位的便是未初始化的。 在Debug模式的buffer庫下編譯的話,序列化話的時候可能會失敗,而且在反序列化的時候對於該欄位的解析會總是失敗的。所以,對於修飾符為required的欄位,請在序列化的時候務必給予初始化。
對於optional的欄位而言,如果未進行初始化,那麼一個預設值將賦予該欄位,當然也可以指定預設值。
對於repeated的欄位而言,該欄位可以重複多個,谷歌提供的這個 addressbook例子便有個很好的該修飾符的應用場景,即每個人可能有多個電話號碼。在高階語言裡面,我們可以通過陣列來實現,而在proto定義檔案中可以使用repeated來修飾,從而達到相同目的。當然,出現0次也是包含在內的。
寫好proto檔案之後就可以用PB編譯器(protoc)將該檔案編譯成目標語言了。本例中我們將使用C++。假設proto檔案存放在 $SRC_DIR 下面,您也想把生成的檔案放在同一個目錄下,則可以使用如下命令:
1 protoc -I=$SRC_DIR --cpp_out=$DST_DIR $SRC_DIR/addressbook.proto
View Code
命令將生成兩個檔案:
lm.helloworld.pb.h, 定義了C++ 類的標頭檔案;
lm.helloworld.pb.cc,C++類的實現檔案。
在生成的標頭檔案中,定義了一個 C++ 類 helloworld,後面的 Writer 和 Reader 將使用這個類來對訊息進行操作。諸如對訊息的成員進行賦值,將訊息序列化等等都有相應的方法。
如前所述,Writer將把一個結構化資料寫入磁碟,以便其他人來讀取。假如我們不使用 PB,其實也有許多的選擇。一個可能的方法是將資料轉換為字串,然後將字串寫入磁碟。轉換為字串的方法可以使用 sprintf(),這非常簡單。數字 123 可以變成字串”123”。這樣做似乎沒有什麼不妥,但是仔細考慮一下就會發現,這樣的做法對寫Reader的那個人的要求比較高,Reader 的作者必須瞭解Writer 的細節。比如”123”可以是單個數字 123,但也可以是三個數字 1、2 和 3等等。這麼說來,我們還必須讓Writer定義一種分隔符一樣的字元,以便Reader可以正確讀取。但分隔符也許還會引起其他的什麼問題。最後我們發現一個簡單的Helloworld 也需要寫許多處理訊息格式的程式碼。
如果使用 PB,那麼這些細節就可以不需要應用程式來考慮了。使用PB,Writer 的工作很簡單,需要處理的結構化資料由 .proto 檔案描述,經過上一節中的編譯過程後,該資料化結構對應了一個 C++ 的類,並定義在 lm.helloworld.pb.h 中。對於本例,類名為lm::helloworld。
Writer 需要include該標頭檔案,然後便可以使用這個類了。現在,在Writer程式碼中,將要存入磁碟的結構化資料由一個lm::helloworld類的物件表示,它提供了一系列的 get/set 函式用來修改和讀取結構化資料中的資料成員,或者叫field。
當我們需要將該結構化資料儲存到磁碟上時,類 lm::helloworld 已經提供相應的方法來把一個複雜的資料變成一個位元組序列,我們可以將這個位元組序列寫入磁碟。
對於想要讀取這個資料的程式來說,也只需要使用類 lm::helloworld 的相應反序列化方法來將這個位元組序列重新轉換會結構化資料。這同我們開始時那個“123”的想法類似,不過PB想的遠遠比我們那個粗糙的字串轉換要全面,因此,我們可以放心將這類事情交給PB吧。程式清單 2 演示了 Writer 的主要程式碼。
清單 2. Writer 的主要程式碼
1 #include "lm.helloworld.pb.h" 2 3 … 4 5 int main(void) 6 7 { 8 9 lm::helloworld msg1; 10 11 msg1.set_id(101); //設定id 12 13 msg1.set_str(“hello”); //設定str 14 15 // 向磁碟中寫入資料流fstream 16 17 fstream output("./log", ios::out | ios::trunc | ios::binary); 18 19 if (!msg1.SerializeToOstream(&output)) { 20 21 cerr << "Failed to write msg." << endl; 22 23 return -1; 24 25 } 26 27 return 0; 28 29 }View Code
Msg1 是一個helloworld類的物件,set_id()用來設定id的值。SerializeToOstream將物件序列化後寫入一個fstream流。我們可以寫出Reader程式碼,程式清單3列出了 reader 的主要程式碼。
清單 3. Reader的主要程式碼
1 #include "lm.helloworld.pb.h" 2 3 … 4 5 void ListMsg(const lm::helloworld & msg) { 6 7 cout << msg.id() << endl; 8 9 cout << msg.str() << endl; 10 11 } 12 13 int main(int argc, char* argv[]) { 14 15 lm::helloworld msg1; 16 17 { 18 19 fstream input("./log", ios::in | ios::binary); 20 21 if (!msg1.ParseFromIstream(&input)) { 22 23 cerr << "Failed to parse address book." << endl; 24 25 return -1; 26 27 } 28 29 } 30 31 ListMsg(msg1); 32 33 … 34 35 }View Code
同樣,Reader 宣告類helloworld的物件msg1,然後利用ParseFromIstream從一個fstream流中讀取資訊並反序列化。此後,ListMsg中採用get方法讀取訊息的內部資訊,並進行列印輸出操作。
執行Writer和Reader的結果如下:
>writer
>reader
101
Hello
Reader 讀取檔案 log 中的序列化資訊並列印到螢幕上。這個例子本身並無意義,但只要稍加修改就可以將它變成更加有用的程式。比如將磁碟替換為網路 socket,那麼就可以實現基於網路的資料交換任務。而儲存和交換正是PB最有效的應用領域。
到這裡為止,我們只給出了一個簡單的沒有任何用處的例子。在實際應用中,人們往往需要定義更加複雜的 Message。我們用“複雜”這個詞,不僅僅是指從個數上說有更多的 fields 或者更多型別的 fields,而是指更加複雜的資料結構:巢狀 Message,Caffe.proto檔案中定義了大量的巢狀Message。使得Message的表達能力增強很多。程式碼清單 4 給出一個巢狀 Message 的例子。
清單 4. 巢狀 Message 的例子
1 message Person { 2 required string name = 1; 3 required int32 id = 2; // Unique ID number for this person. 4 optional string email = 3; 5 enum PhoneType { 6 MOBILE = 0; 7 HOME = 1; 8 WORK = 2; 9 } 10 11 message PhoneNumber { 12 required string number = 1; 13 optional PhoneType type = 2 [default = HOME]; 14 } 15 repeated PhoneNumber phone = 4; 16 }View Code
在 Message Person 中,定義了巢狀訊息 PhoneNumber,並用來定義 Person 訊息中的 phone 域。這使得人們可以定義更加複雜的資料結構。
在Caffe中也是類似於上例中的Writer和Reader去讀寫PB資料的。接下來,具體說明下Caffe中是如何儲存Caffemodel的。在Caffe主目錄下的solver.cpp檔案中的一段程式碼展示了Caffe是如何儲存Caffemodel的,程式碼清單5如下:
清單 5. Caffemodel儲存程式碼
1 template <typename Dtype> 2 3 void Solver<Dtype>::Snapshot() { 4 5 NetParameter net_param; // NetParameter為網路引數類 6 7 // 為了中間結果,也會寫入梯度值 8 9 net_->ToProto(&net_param, param_.snapshot_diff()); 10 11 string filename(param_.snapshot_prefix()); 12 13 string model_filename, snapshot_filename; 14 15 const int kBufferSize = 20; 16 17 char iter_str_buffer[kBufferSize]; 18 19 // 每訓練完1次,iter_就加1 20 21 snprintf(iter_str_buffer, kBufferSize, "_iter_%d", iter_ + 1); 22 23 filename += iter_str_buffer; 24 25 model_filename = filename + ".caffemodel"; //XX_iter_YY.caffemodel 26 27 LOG(INFO) << "Snapshotting to " << model_filename; 28 29 // 向磁碟寫入網路引數 30 31 WriteProtoToBinaryFile(net_param, model_filename.c_str()); 32 33 SolverState state; 34 35 SnapshotSolverState(&state); 36 37 state.set_iter(iter_ + 1); //set 38 39 state.set_learned_net(model_filename); 40 41 state.set_current_step(current_step_); 42 43 snapshot_filename = filename + ".solverstate"; 44 45 LOG(INFO) << "Snapshotting solver state to " << snapshot_filename; 46 47 // 向磁碟寫入網路state 48 49 WriteProtoToBinaryFile(state, snapshot_filename.c_str()); 50 51 }View Code
在清單5程式碼中,我們可以看到,其實Caffemodel儲存的資料也就是網路引數net_param的PB,Caffe可以儲存每一次訓練完成後的網路引數,我們可以通過XX.prototxt檔案來進行引數設定。在這裡的 WriteProtoToBinaryFile函式與之前HelloWorld例子中的Writer函式類似,在這就不在貼出。那麼我們只要弄清楚NetParameter類的組成,也就明白了Caffemodel的具體資料構成。在caffe.proto這個檔案中定義了NetParameter類,如程式碼清單6所示。
清單6. Caffemodel儲存程式碼
1 message NetParameter { 2 3 optional string name = 1; // 網路名稱 4 5 repeated string input = 3; // 網路輸入input blobs 6 7 repeated BlobShape input_shape = 8; // The shape of the input blobs 8 9 // 輸入維度blobs,4維(num, channels, height and width) 10 11 repeated int32 input_dim = 4; 12 13 // 網路是否強制每層進行反饋操作開關 14 15 // 如果設定為False,則會根據網路結構和學習率自動確定是否進行反饋操作 16 17 optional bool force_backward = 5 [default = false]; 18 19 // 網路的state,部分網路層依賴,部分不依賴,需要看具體網路 20 21 optional NetState state = 6; 22 23 // 是否列印debug log 24 25 optional bool debug_info = 7 [default = false]; 26 27 // 網路層引數,Field Number 為100,所以網路層引數在最後 28 29 repeated LayerParameter layer = 100; 30 31 // 棄用: 用 'layer' 代替 32 33 repeated V1LayerParameter layers = 2; 34 35 } 36 37 // Specifies the shape (dimensions) of a Blob. 38 39 message BlobShape { 40 41 repeated int64 dim = 1 [packed = true]; 42 43 } 44 45 message BlobProto { 46 47 optional BlobShape shape = 7; 48 49 repeated float data = 5 [packed = true]; 50 51 repeated float diff = 6 [packed = true]; 52 53 optional int32 num = 1 [default = 0]; 54 55 optional int32 channels = 2 [default = 0]; 56 57 optional int32 height = 3 [default = 0]; 58 59 optional int32 width = 4 [default = 0]; 60 61 } 62 63 64 65 // The BlobProtoVector is simply a way to pass multiple blobproto instances 66 67 around. 68 69 message BlobProtoVector { 70 71 repeated BlobProto blobs = 1; 72 73 } 74 75 message NetState { 76 77 optional Phase phase = 1 [default = TEST]; 78 79 optional int32 level = 2 [default = 0]; 80 81 repeated string stage = 3; 82 83 } 84 85 message LayerParameter { 86 87 optional string name = 1; // the layer name 88 89 optional string type = 2; // the layer type 90 91 repeated string bottom = 3; // the name of each bottom blob 92 93 repeated string top = 4; // the name of each top blob 94 95 // The train/test phase for computation. 96 97 optional Phase phase = 10; 98 99 // Loss weight值:float 100 101 // 每一層為每一個top blob都分配了一個預設值,通常是0或1 102 103 repeated float loss_weight = 5; 104 105 // 指定的學習引數 106 107 repeated ParamSpec param = 6; 108 109 // The blobs containing the numeric parameters of the layer. 110 111 repeated BlobProto blobs = 7; 112 113 // included/excluded. 114 115 repeated NetStateRule include = 8; 116 117 repeated NetStateRule exclude = 9; 118 119 // Parameters for data pre-processing. 120 121 optional TransformationParameter transform_param = 100; 122 123 // Parameters shared by loss layers. 124 125 optional LossParameter loss_param = 101; 126 127 // 各種型別層引數 128 129 optional AccuracyParameter accuracy_param = 102; 130 131 optional ArgMaxParameter argmax_param = 103; 132 133 optional ConcatParameter concat_param = 104; 134 135 optional ContrastiveLossParameter contrastive_loss_param = 105; 136 137 optional ConvolutionParameter convolution_param = 106; 138 139 optional DataParameter data_param = 107; 140 141 optional DropoutParameter dropout_param = 108; 142 143 optional DummyDataParameter dummy_data_param = 109; 144 145 optional EltwiseParameter eltwise_param = 110; 146 147 optional ExpParameter exp_param = 111; 148 149 optional HDF5DataParameter hdf5_data_param = 112; 150 151 optional HDF5OutputParameter hdf5_output_param = 113; 152 153 optional HingeLossParameter hinge_loss_param = 114; 154 155 optional ImageDataParameter image_data_param = 115; 156 157 optional InfogainLossParameter infogain_loss_param = 116; 158 159 optional InnerProductParameter inner_product_param = 117; 160 161 optional LRNParameter lrn_param = 118; 162 163 optional MemoryDataParameter memory_data_param = 119; 164 165 optional MVNParameter mvn_param = 120; 166 167 optional PoolingParameter pooling_param = 121; 168 169 optional PowerParameter power_param = 122; 170 171 optional PythonParameter python_param = 130; 172 173 optional ReLUParameter relu_param = 123; 174 175 optional SigmoidParameter sigmoid_param = 124; 176 177 optional SoftmaxParameter softmax_param = 125; 178 179 optional SliceParameter slice_param = 126; 180 181 optional TanHParameter tanh_param = 127; 182 183 optional ThresholdParameter threshold_param = 128; 184 185 optional WindowDataParameter window_data_param = 129; 186 187 }View Code
那麼接下來的一段程式碼來演示如何解析Caffemodel,我解析用的model為MNIST手寫庫訓練後的model,Lenet_iter_10000.caffemodel。
清單7. Caffemodel解析程式碼
1 #include <stdio.h> 2 #include <string.h> 3 #include <fstream> 4 #include <iostream> 5 #include "proto/caffe.pb.h" 6 7 using namespace std; 8 using namespace caffe; 9 10 int main(int argc, char* argv[]) 11 { 12 13 caffe::NetParameter msg; 14 15 fstream input("lenet_iter_10000.caffemodel", ios::in | ios::binary); 16 if (!msg.ParseFromIstream(&input)) 17 { 18 cerr << "Failed to parse address book." << endl; 19 return -1; 20 } 21 printf("length = %d\n", length); 22 printf("Repeated Size = %d\n", msg.layer_size()); 23 24 ::google::protobuf::RepeatedPtrField< LayerParameter >* layer = msg.mutable_layer(); 25 ::google::protobuf::RepeatedPtrField< LayerParameter >::iterator it = layer->begin(); 26 for (; it != layer->end(); ++it) 27 { 28 cout << it->name() << endl; 29 cout << it->type() << endl; 30 cout << it->convolution_param().weight_filler().max() << endl; 31 } 32 33 return 0; 34 }
View Code
這篇Blog仍然是以Google的官方文件為主線,程式碼例項則完全取自於我們正在開發的一個Demo專案,通過前一段時間的嘗試,感覺這種結合的方式比較有利於培訓和內部的技術交流。還是那句話,沒有最好的,只有