caffe隱藏網路結構的定義,prototxt檔案(待續)
阿新 • • 發佈:2019-01-07
在caffe中,在網路訓練和網路使用的過程中,都會涉及到載入網路結構的問題,即caffe中定義的prototxt檔案。但該檔案包括了幾乎所有的網路資訊,能否做到隱藏該檔案呢?
能夠想到的有兩種方式,
1、加密該檔案,使用時再解密該檔案
2、將網路結構使用c++程式碼實現
其中方式2沒有做太多的研究,應該比較麻煩。
那麼加密解密的方式如何實現呢?採用合適的加密演算法,解密後變成string,將string解析,變成網路結構的引數,進行處理。
這種方式,會儲存有加密檔案,但並不想包含這樣一個檔案,而且需要設計加密演算法,這個如何設計呢?
於是,我想到了另外一種方式,把網路結構的定義,用C++程式碼一行一行的實現,即用C++的方式寫一個prototxt檔案,將該檔案儲存為臨時檔案,使用該臨時檔案後,馬上刪除該檔案。
當然這種方式仍然有一定的風險,就該臨時檔案的存在,那麼能不能把加密的方式和程式碼實現網路結構的方式相結合呢?其實用程式碼實現網路結構只是替代了加密的過程,用程式碼生成網路結構後,不生成prototxt檔案,而是在記憶體中以string的方式存在,把此string當做引數傳入即可
想要隱藏網路結構,對於訓練階段和測試階段還不太一樣,測試階段只有一個deploy.prototxt檔案,而且針對測試階段,前面是自己寫了c++的程式碼讀取deploy.prototxt,載入模型的。改改這裡就好了。
但訓練階段,是呼叫的caffe的可執行檔案,那麼如果要更改傳參方式,就需要修改caffe的原始碼,主要是caffe/tools/caffe.cpp檔案,修改其中讀取train_val.prototxt和solver.prototxt檔案的部分。
首先實現測試階段的過程,其實在caffe原始碼中,讀取deploy.prototxt檔案也是通過解析該檔案,得到NetParameters,那麼通過解析string,得到NetParameters的方式,就也是可行的。
通過caffe原始碼一點點分析,prototxt檔案首先是通過是傳入net,在caffe/cpp_classification/classification中,有這樣一行程式碼
-Cpp 程式碼
1
net_.reset(new Net<float>(model_file, TEST));
其中model_file就是是deploy.prototxt檔案,通過建構函式,構建網路,Net類在caffe/src/caffe/net.cpp中,如下
-Cpp 程式碼
01
template <typename Dtype>
02
Net<Dtype>::Net(const string& param_file, Phase phase,
03
const int level, const vector<string>* stages) {
04
NetParameter param;
05
ReadNetParamsFromTextFileOrDie(param_file, ¶m);
06
// Set phase, stages and level
07
param.mutable_state()->set_phase(phase);
08
if (stages != NULL) {
09
for (int i = 0; i < stages->size(); i++) {
10
param.mutable_state()->add_stage((*stages)[i]);
11
}
12
}
13
param.mutable_state()->set_level(level);
14
Init(param);
15
}
這裡首先構建NetParameter param,在通過
-Cpp 程式碼
1
ReadNetParamsFromTextFileOrDie(param_file, ¶m);
從deploy.prototxt中解析出來網路結構,把網路結構賦值給param。
再進入到
-Cpp 程式碼
1
ReadNetParamsFromTextFileOrDie()
函式中,看下是如何操作的,該函式在caffe/src/caffe/util/upgrade_proto.cpp中,如下:
-Cpp 程式碼
1
void ReadNetParamsFromTextFileOrDie(const string& param_file,
2
NetParameter* param) {
3
CHECK(ReadProtoFromTextFile(param_file, param))
4
<< "Failed to parse NetParameter file: " << param_file;
5
UpgradeNetAsNeeded(param_file, param);
6
}
在該函式中又呼叫
-Cpp 程式碼
1
ReadProtoFromTextFile()
從deploy.prototxt檔案中讀取資料,然後網路結構儲存在param中。
再進入到ReadProtoFromTextFile()函式看下,其中的執行機制,該函式在caffe/src/caffe/util/io.cpp中,如下
-Cpp 程式碼
1
bool ReadProtoFromTextFile(const char* filename, Message* proto) {
2
int fd = open(filename, O_RDONLY);
3
CHECK_NE(fd, -1) << "File not found: " << filename;
4
FileInputStream* input = new FileInputStream(fd);
5
bool success = google::protobuf::TextFormat::Parse(input, proto);
6
delete input;
7
close(fd);
8
return success;
9
}z
在這裡,首先是通過讀檔案的方式,將deploy.prototxt讀出來,然後通過protobuf的parse解析出來即可。
如果需要需要自己實現一個從string中解析網路結構的功能,那麼前面的大部分內容都可以照搬,最後再呼叫protobuf的TextFormat::ParseFromString即可。
前面分析清楚了,那麼下面來試試如何修改caffe的原始碼(在修改的時候,別忘了修改.h檔案),首先是在io.cpp中需要增加一個函式(並在caffe/include/caffe/util/io.h中增加相應的宣告),如下
-Cpp 程式碼
1
bool ReadProtoFromString(const string &str, Message* proto){
2
bool success = google::protobuf::TextFormat::ParseFromString(str, proto);
3
return success;
4
5
}
然後在upgrade_proto.cpp中增加ReadNetParamsFromStringOrDie(同時修改.h檔案),如下
-Cpp 程式碼
1
void ReadNetParamsFromStringOrDie(const string& str, NetParameter* param){
2
3
const string output = "It's a string, not a file.";
4
CHECK(ReadProtoFromString(str, param))
5
<< "Failed to parse NetParameter from string" << output;
6
7
UpgradeNetAsNeeded(output,param);
8
}
這裡新增了一個output字串,主要是為了輸出,之前的param_file是檔案,而目前是字串,不適合輸出str。
另外這裡面的函式UpgradeNetAsNeeded,主要是根據param來判斷,output只是輸出用,故用output也是可以的,不會對output做任何的解析。
之後再在Net.cpp中做相應的修改
-Cpp 程式碼
01
template <typename Dtype>
02
Net<Dtype>::Net(int usage, const string& str, Phase phase, const int level, const vector<string>* stages){
03
NetParameter param;
04
ReadNetParamsFromStringOrDie(str, ¶m);
05
//Set phase, stages and level
06
param.mutable_state()->set_phase(phase);
07
if (stages != NULL){
08
for (int i = 0; i < stages->size(); i++) {
09
param.mutable_state()->add_stage((*stages)[i]);
10
}
11
}
12
param.mutable_state()->set_level(level);
13
Init(param);
14
15
}
第一個引數int,並無實際意義,僅僅用作和前面的建構函式進行區分。
現在一切準備就緒後,就需要開始通過string構建網路結構。
其實用string構建網路很簡單,只需要把deploy.prototxt中的字串提取出來,用一個字串表示即可,原本的回車換行符以"\n"替代即可,另外原始檔案中的雙引號可以用單引號替代
那麼現在就完成了對測試網路的deploy.prototxt檔案的隱藏,如果要對訓練的網路結構進行隱藏,則需要修改其它的程式碼。
首先,明確一下,訓練時用到的程式碼見caffe/tools/caffe.cpp中,編譯後,會在caffe/build/中找到可執行檔案caffe,因此可以用該可執行檔案進行網路的訓練。一般呼叫方式為
-Bash 程式碼
1
./build/caffe train solver.prototxt
這裡只傳入了solver.prototxt,該檔案也是需要隱藏的(另外,在該檔案中,給出了train_val.prototxt檔案的位置,這裡先介紹如何隱藏solver.prototxt檔案,再看看如何隱藏train_val.prototxt檔案)
在caffe.cpp檔案中,使用solver.prototxt檔案的程式碼如下:
-Cpp 程式碼
1
caffe::SolverParameter solver_param;
2
caffe::ReadSolverParamsFromTextFileOrDie(FLAGS_solver, &solver_param);
其中,FLAGS_solver表示solver.prototxt檔案,可見,這裡和前面隱藏deploy.prototxt檔案是一致的。做相應的修改。ReadSolverParamsFromTextFileOrDie()函式在caffe/src/caffe/util/upgrade_proto.cpp中,如下:
-Cpp 程式碼
1
void ReadSolverParamsFromTextFileOrDie(const string& param_file,
2
SolverParameter* param) {
3
CHECK(ReadProtoFromTextFile(param_file, param))
4
<< "Failed to parse SolverParameter file: " << param_file;
5
UpgradeSolverAsNeeded(param_file, param);
6
}
這裡同樣還是呼叫ReadProtoFromTextFile,那麼修改的時候,也是類似的,只需要再這裡增加一個函式ReadSolverParamsFromStringOrDie,在這個函式裡面呼叫的函式是和前面一樣的,這裡已經準備好了。如下:
-Cpp 程式碼
1
void ReadSolverParamsFromStringOrDie(const string& str, SolverParameter* param){
2
const string output = "It's a string, not a file.";
3
CHECK(ReadProtoFromString(str, param))
4
<< "Failed to parse NetParameter from string" << output;
5
UpgradeSolverAsNeeded(output,param);
6
}
再相應的修改.h檔案即可。
在solver.prototxt檔案中包含了train_val.prototxt檔案的位置,從此地址中獲取網路結構,那麼就還需要隱藏該網路結構。在caffe.cpp中從solver.prototxt讀取訓練的引數資訊後,儲存到solver_param物件中,再用該物件構建solver物件,
-Cpp 程式碼
1
caffe::ReadSolverParamsFromTextFileOrDie(FLAGS_solver, &solver_param);
2
shared_ptr<caffe::Solver<float> >
3
solver(caffe::SolverRegistry<float>::CreateSolver(solver_param));
在構建solver物件的過程中,初始化了網路結構(包括訓練和測試的網路結構)
-Cpp 程式碼
1
InitTrainNet();
2
InitTestNets();
先看如何修改訓練的網路,改成從字串中讀取。初始訓練的網路程式碼如下:
-Cpp 程式碼
01
template <typename Dtype>
02
void Solver<Dtype>::InitTrainNet() {
03
const int num_train_nets = param_.has_net() + param_.has_net_param() +
04
param_.has_train_net() + param_.has_train_net_param();
05
const string& field_names = "net, net_param, train_net, train_net_param";
06
CHECK_GE(num_train_nets, 1) << "SolverParameter must specify a train net "
07
<< "using one of these fields: " << field_names;
08
CHECK_LE(num_train_nets, 1) << "SolverParameter must not contain more than "
09
<< "one of these fields specifying a train_net: " << field_names;
10
NetParameter net_param;
11
if (param_.has_train_net_param()) {
12
LOG_IF(INFO, Caffe::root_solver())
13
<< "Creating training net specified in train_net_param.";
14
net_param.CopyFrom(param_.train_net_param());
15
} else if (param_.has_train_net()) {
16
LOG_IF(INFO, Caffe::root_solver())
17
<< "Creating training net from train_net file: " << param_.train_net();
18
ReadNetParamsFromTextFileOrDie(param_.train_net(), &net_param);
19
}
20
if (param_.has_net_param()) {
21
LOG_IF(INFO, Caffe::root_solver())
22
<< "Creating training net specified in net_param.";
23
net_param.CopyFrom(param_.net_param());
24
}
25
if (param_.has_net()) {
26
LOG_IF(INFO, Caffe::root_solver())
27
<< "Creating training net from net file: " << param_.net();
28
ReadNetParamsFromTextFileOrDie(param_.net(), &net_param);
29
}
30
// Set the correct NetState. We start with the solver defaults (lowest
31
// precedence); then, merge in any NetState specified by the net_param itself;
32
// finally, merge in any NetState specified by the train_state (highest
33
// precedence).
34
NetState net_state;
35
net_state.set_phase(TRAIN);
36
net_state.MergeFrom(net_param.state());
37
net_state.MergeFrom(param_.train_state());
38
net_param.mutable_state()->CopyFrom(net_state);
39
net_.reset(new Net<Dtype>(net_param));
40
}
這裡會執行到if(param_.has_net())的條件裡,在這裡利用這個條件,但選擇從string中讀取網路結構。
首先在caffe/src/caffe/proto/caffe.proto檔案中的SolverParameter中增加bool型別的變量表明是從字串中讀取網路結構:
如下:
-Cpp 程式碼
1
//parse net from string
2
optional bool net_from_str = 101 [default = false];<br>
並在solver.prototxt對應的字串中增加對應的欄位,令其為true。
那麼在sovler.cpp中做一些修改
-Cpp 程式碼
01
if (param_.has_net()) {
02
LOG_IF(INFO, Caffe::root_solver())
03
<< "Creating training net from net file: " << param_.net();
04
if (param_.net_from_str()){
05
AppClsfyNet app_clsfy_net(301);
06
string str_net = app_clsfy_net.getTrainNetStr();
07
ReadNetParamsFromStringOrDie(str_net, &net_param);
08
std::cout<<str_net<<std::endl;
09
}else{
10
ReadNetParamsFromTextFileOrDie(param_.net(), &net_param);
11
}
12
}
需要在solver中先include自己寫的獲取train net的.h檔案,並從中獲取train net的字串。那麼到這裡,train_net搞定。那麼接下來看看如何初始化TestNet.
對於test net而言,也是一樣的,對於其中的程式碼做修改如下:
-Cpp 程式碼
01
if (has_net_file) {
02
for (int i = 0; i < remaining_test_nets; ++i, ++test_net_id) {
03
sources[test_net_id] = "net file: " + param_.net();
04
if(param_.net_from_str()){
05
AppClsfyNet app_clsfy_net(301);
06
string str_net = app_clsfy_net.getTrainNetStr();
07
ReadNetParamsFromStringOrDie(str_net, &net_params[test_net_id]);
08
}else{
09
ReadNetParamsFromTextFileOrDie(param_.net(), &net_params[test_net_id]);
10
}
11
}
12
}
好了,到此為止,基本是搞定了,網路結構都可以隱藏了,不管是訓練階段還是測試階段的網路結構都隱藏了。
能夠想到的有兩種方式,
1、加密該檔案,使用時再解密該檔案
2、將網路結構使用c++程式碼實現
其中方式2沒有做太多的研究,應該比較麻煩。
那麼加密解密的方式如何實現呢?採用合適的加密演算法,解密後變成string,將string解析,變成網路結構的引數,進行處理。
這種方式,會儲存有加密檔案,但並不想包含這樣一個檔案,而且需要設計加密演算法,這個如何設計呢?
於是,我想到了另外一種方式,把網路結構的定義,用C++程式碼一行一行的實現,即用C++的方式寫一個prototxt檔案,將該檔案儲存為臨時檔案,使用該臨時檔案後,馬上刪除該檔案。
當然這種方式仍然有一定的風險,就該臨時檔案的存在,那麼能不能把加密的方式和程式碼實現網路結構的方式相結合呢?其實用程式碼實現網路結構只是替代了加密的過程,用程式碼生成網路結構後,不生成prototxt檔案,而是在記憶體中以string的方式存在,把此string當做引數傳入即可
想要隱藏網路結構,對於訓練階段和測試階段還不太一樣,測試階段只有一個deploy.prototxt檔案,而且針對測試階段,前面是自己寫了c++的程式碼讀取deploy.prototxt,載入模型的。改改這裡就好了。
但訓練階段,是呼叫的caffe的可執行檔案,那麼如果要更改傳參方式,就需要修改caffe的原始碼,主要是caffe/tools/caffe.cpp檔案,修改其中讀取train_val.prototxt和solver.prototxt檔案的部分。
首先實現測試階段的過程,其實在caffe原始碼中,讀取deploy.prototxt檔案也是通過解析該檔案,得到NetParameters,那麼通過解析string,得到NetParameters的方式,就也是可行的。
通過caffe原始碼一點點分析,prototxt檔案首先是通過是傳入net,在caffe/cpp_classification/classification中,有這樣一行程式碼
-Cpp 程式碼
1
net_.reset(new Net<float>(model_file, TEST));
其中model_file就是是deploy.prototxt檔案,通過建構函式,構建網路,Net類在caffe/src/caffe/net.cpp中,如下
-Cpp 程式碼
01
template <typename Dtype>
02
Net<Dtype>::Net(const string& param_file, Phase phase,
03
const int level, const vector<string>* stages) {
04
NetParameter param;
05
ReadNetParamsFromTextFileOrDie(param_file, ¶m);
06
// Set phase, stages and level
07
param.mutable_state()->set_phase(phase);
08
if (stages != NULL) {
09
for (int i = 0; i < stages->size(); i++) {
10
param.mutable_state()->add_stage((*stages)[i]);
11
}
12
}
13
param.mutable_state()->set_level(level);
14
Init(param);
15
}
這裡首先構建NetParameter param,在通過
-Cpp 程式碼
1
ReadNetParamsFromTextFileOrDie(param_file, ¶m);
從deploy.prototxt中解析出來網路結構,把網路結構賦值給param。
再進入到
-Cpp 程式碼
1
ReadNetParamsFromTextFileOrDie()
函式中,看下是如何操作的,該函式在caffe/src/caffe/util/upgrade_proto.cpp中,如下:
-Cpp 程式碼
1
void ReadNetParamsFromTextFileOrDie(const string& param_file,
2
NetParameter* param) {
3
CHECK(ReadProtoFromTextFile(param_file, param))
4
<< "Failed to parse NetParameter file: " << param_file;
5
UpgradeNetAsNeeded(param_file, param);
6
}
在該函式中又呼叫
-Cpp 程式碼
1
ReadProtoFromTextFile()
從deploy.prototxt檔案中讀取資料,然後網路結構儲存在param中。
再進入到ReadProtoFromTextFile()函式看下,其中的執行機制,該函式在caffe/src/caffe/util/io.cpp中,如下
-Cpp 程式碼
1
bool ReadProtoFromTextFile(const char* filename, Message* proto) {
2
int fd = open(filename, O_RDONLY);
3
CHECK_NE(fd, -1) << "File not found: " << filename;
4
FileInputStream* input = new FileInputStream(fd);
5
bool success = google::protobuf::TextFormat::Parse(input, proto);
6
delete input;
7
close(fd);
8
return success;
9
}z
在這裡,首先是通過讀檔案的方式,將deploy.prototxt讀出來,然後通過protobuf的parse解析出來即可。
如果需要需要自己實現一個從string中解析網路結構的功能,那麼前面的大部分內容都可以照搬,最後再呼叫protobuf的TextFormat::ParseFromString即可。
前面分析清楚了,那麼下面來試試如何修改caffe的原始碼(在修改的時候,別忘了修改.h檔案),首先是在io.cpp中需要增加一個函式(並在caffe/include/caffe/util/io.h中增加相應的宣告),如下
-Cpp 程式碼
1
bool ReadProtoFromString(const string &str, Message* proto){
2
bool success = google::protobuf::TextFormat::ParseFromString(str, proto);
3
return success;
4
5
}
然後在upgrade_proto.cpp中增加ReadNetParamsFromStringOrDie(同時修改.h檔案),如下
-Cpp 程式碼
1
void ReadNetParamsFromStringOrDie(const string& str, NetParameter* param){
2
3
const string output = "It's a string, not a file.";
4
CHECK(ReadProtoFromString(str, param))
5
<< "Failed to parse NetParameter from string" << output;
6
7
UpgradeNetAsNeeded(output,param);
8
}
這裡新增了一個output字串,主要是為了輸出,之前的param_file是檔案,而目前是字串,不適合輸出str。
另外這裡面的函式UpgradeNetAsNeeded,主要是根據param來判斷,output只是輸出用,故用output也是可以的,不會對output做任何的解析。
之後再在Net.cpp中做相應的修改
-Cpp 程式碼
01
template <typename Dtype>
02
Net<Dtype>::Net(int usage, const string& str, Phase phase, const int level, const vector<string>* stages){
03
NetParameter param;
04
ReadNetParamsFromStringOrDie(str, ¶m);
05
//Set phase, stages and level
06
param.mutable_state()->set_phase(phase);
07
if (stages != NULL){
08
for (int i = 0; i < stages->size(); i++) {
09
param.mutable_state()->add_stage((*stages)[i]);
10
}
11
}
12
param.mutable_state()->set_level(level);
13
Init(param);
14
15
}
第一個引數int,並無實際意義,僅僅用作和前面的建構函式進行區分。
現在一切準備就緒後,就需要開始通過string構建網路結構。
其實用string構建網路很簡單,只需要把deploy.prototxt中的字串提取出來,用一個字串表示即可,原本的回車換行符以"\n"替代即可,另外原始檔案中的雙引號可以用單引號替代
那麼現在就完成了對測試網路的deploy.prototxt檔案的隱藏,如果要對訓練的網路結構進行隱藏,則需要修改其它的程式碼。
首先,明確一下,訓練時用到的程式碼見caffe/tools/caffe.cpp中,編譯後,會在caffe/build/中找到可執行檔案caffe,因此可以用該可執行檔案進行網路的訓練。一般呼叫方式為
-Bash 程式碼
1
./build/caffe train solver.prototxt
這裡只傳入了solver.prototxt,該檔案也是需要隱藏的(另外,在該檔案中,給出了train_val.prototxt檔案的位置,這裡先介紹如何隱藏solver.prototxt檔案,再看看如何隱藏train_val.prototxt檔案)
在caffe.cpp檔案中,使用solver.prototxt檔案的程式碼如下:
-Cpp 程式碼
1
caffe::SolverParameter solver_param;
2
caffe::ReadSolverParamsFromTextFileOrDie(FLAGS_solver, &solver_param);
其中,FLAGS_solver表示solver.prototxt檔案,可見,這裡和前面隱藏deploy.prototxt檔案是一致的。做相應的修改。ReadSolverParamsFromTextFileOrDie()函式在caffe/src/caffe/util/upgrade_proto.cpp中,如下:
-Cpp 程式碼
1
void ReadSolverParamsFromTextFileOrDie(const string& param_file,
2
SolverParameter* param) {
3
CHECK(ReadProtoFromTextFile(param_file, param))
4
<< "Failed to parse SolverParameter file: " << param_file;
5
UpgradeSolverAsNeeded(param_file, param);
6
}
這裡同樣還是呼叫ReadProtoFromTextFile,那麼修改的時候,也是類似的,只需要再這裡增加一個函式ReadSolverParamsFromStringOrDie,在這個函式裡面呼叫的函式是和前面一樣的,這裡已經準備好了。如下:
-Cpp 程式碼
1
void ReadSolverParamsFromStringOrDie(const string& str, SolverParameter* param){
2
const string output = "It's a string, not a file.";
3
CHECK(ReadProtoFromString(str, param))
4
<< "Failed to parse NetParameter from string" << output;
5
UpgradeSolverAsNeeded(output,param);
6
}
再相應的修改.h檔案即可。
在solver.prototxt檔案中包含了train_val.prototxt檔案的位置,從此地址中獲取網路結構,那麼就還需要隱藏該網路結構。在caffe.cpp中從solver.prototxt讀取訓練的引數資訊後,儲存到solver_param物件中,再用該物件構建solver物件,
-Cpp 程式碼
1
caffe::ReadSolverParamsFromTextFileOrDie(FLAGS_solver, &solver_param);
2
shared_ptr<caffe::Solver<float> >
3
solver(caffe::SolverRegistry<float>::CreateSolver(solver_param));
在構建solver物件的過程中,初始化了網路結構(包括訓練和測試的網路結構)
-Cpp 程式碼
1
InitTrainNet();
2
InitTestNets();
先看如何修改訓練的網路,改成從字串中讀取。初始訓練的網路程式碼如下:
-Cpp 程式碼
01
template <typename Dtype>
02
void Solver<Dtype>::InitTrainNet() {
03
const int num_train_nets = param_.has_net() + param_.has_net_param() +
04
param_.has_train_net() + param_.has_train_net_param();
05
const string& field_names = "net, net_param, train_net, train_net_param";
06
CHECK_GE(num_train_nets, 1) << "SolverParameter must specify a train net "
07
<< "using one of these fields: " << field_names;
08
CHECK_LE(num_train_nets, 1) << "SolverParameter must not contain more than "
09
<< "one of these fields specifying a train_net: " << field_names;
10
NetParameter net_param;
11
if (param_.has_train_net_param()) {
12
LOG_IF(INFO, Caffe::root_solver())
13
<< "Creating training net specified in train_net_param.";
14
net_param.CopyFrom(param_.train_net_param());
15
} else if (param_.has_train_net()) {
16
LOG_IF(INFO, Caffe::root_solver())
17
<< "Creating training net from train_net file: " << param_.train_net();
18
ReadNetParamsFromTextFileOrDie(param_.train_net(), &net_param);
19
}
20
if (param_.has_net_param()) {
21
LOG_IF(INFO, Caffe::root_solver())
22
<< "Creating training net specified in net_param.";
23
net_param.CopyFrom(param_.net_param());
24
}
25
if (param_.has_net()) {
26
LOG_IF(INFO, Caffe::root_solver())
27
<< "Creating training net from net file: " << param_.net();
28
ReadNetParamsFromTextFileOrDie(param_.net(), &net_param);
29
}
30
// Set the correct NetState. We start with the solver defaults (lowest
31
// precedence); then, merge in any NetState specified by the net_param itself;
32
// finally, merge in any NetState specified by the train_state (highest
33
// precedence).
34
NetState net_state;
35
net_state.set_phase(TRAIN);
36
net_state.MergeFrom(net_param.state());
37
net_state.MergeFrom(param_.train_state());
38
net_param.mutable_state()->CopyFrom(net_state);
39
net_.reset(new Net<Dtype>(net_param));
40
}
這裡會執行到if(param_.has_net())的條件裡,在這裡利用這個條件,但選擇從string中讀取網路結構。
首先在caffe/src/caffe/proto/caffe.proto檔案中的SolverParameter中增加bool型別的變量表明是從字串中讀取網路結構:
如下:
-Cpp 程式碼
1
//parse net from string
2
optional bool net_from_str = 101 [default = false];<br>
並在solver.prototxt對應的字串中增加對應的欄位,令其為true。
那麼在sovler.cpp中做一些修改
-Cpp 程式碼
01
if (param_.has_net()) {
02
LOG_IF(INFO, Caffe::root_solver())
03
<< "Creating training net from net file: " << param_.net();
04
if (param_.net_from_str()){
05
AppClsfyNet app_clsfy_net(301);
06
string str_net = app_clsfy_net.getTrainNetStr();
07
ReadNetParamsFromStringOrDie(str_net, &net_param);
08
std::cout<<str_net<<std::endl;
09
}else{
10
ReadNetParamsFromTextFileOrDie(param_.net(), &net_param);
11
}
12
}
需要在solver中先include自己寫的獲取train net的.h檔案,並從中獲取train net的字串。那麼到這裡,train_net搞定。那麼接下來看看如何初始化TestNet.
對於test net而言,也是一樣的,對於其中的程式碼做修改如下:
-Cpp 程式碼
01
if (has_net_file) {
02
for (int i = 0; i < remaining_test_nets; ++i, ++test_net_id) {
03
sources[test_net_id] = "net file: " + param_.net();
04
if(param_.net_from_str()){
05
AppClsfyNet app_clsfy_net(301);
06
string str_net = app_clsfy_net.getTrainNetStr();
07
ReadNetParamsFromStringOrDie(str_net, &net_params[test_net_id]);
08
}else{
09
ReadNetParamsFromTextFileOrDie(param_.net(), &net_params[test_net_id]);
10
}
11
}
12
}
好了,到此為止,基本是搞定了,網路結構都可以隱藏了,不管是訓練階段還是測試階段的網路結構都隱藏了。