1. 程式人生 > >Caffe 原始碼閱讀筆記 [基本模組] Solver

Caffe 原始碼閱讀筆記 [基本模組] Solver

概述

前面介紹了網路Net模組來描述一個神經網路,Solver負責生成訓練網路和測試網路並按照演算法對網路進行引數優化。跟Layer一樣,Caffe把Solver實現成一個介面,使得開發者可以開發自己的Solver,其中Solver的子類需要實現ApplyUpdate函式來根據當前網路的狀態對網路引數進行更新。我們可以通過SolverParameter給Solver配置一個訓練網路和若干個測試網路。測試網路的作用是我們每進行若干次訓練和引數更新之後,我們可以把當前的網路引數用在測試網路上測試效果。Solver也提供了介面使得客戶端可以發出指令要求Solver終止訓練(如果發現在測試集上效果變差)或者做快照(以便下次可以從快照開始繼續訓練而不用重新開始)。

Solver成員變數

  SolverParameter param_; // Solver引數
  int iter_; // 第幾次迭代
  int current_step_; // 
  shared_ptr<Net<Dtype> > net_; // 要優化的網路
  vector<shared_ptr<Net<Dtype> > > test_nets_; // 用於測試的網路
  vector<Dtype> losses_; // 儲存最後average_loss次迭代的loss值。
  Dtype smoothed_loss_; // 當前的最後average_loss次迭代的loss的平均值
const Solver* const root_solver_; // 它包含了root_net(有共享Layer的網路) ActionCallback action_request_function_; // 客戶端可以通過這個callback來要求Solver做一個snapshot或者退出 vector<Callback*> callbacks_;

Solver初始化函式Init

void Init(const SolverParameter& param) {
  param_ = param;
  InitTrainNet(); // 初始化訓練網路
if (Caffe::root_solver()) { InitTestNets(); // 初始化測試網路 } iter_ = 0; current_step_ = 0; } void InitTrainNet() { NetParameter net_param; // 從Solver的引數中拷貝網路引數,也可以從檔案param_.train_net()、param_.net_param()和param_.net()裡讀出來 net_param.CopyFrom(param_.train_net_param()); NetState net_state; // 從優先順序由低到高設定網路的狀態 net_state.set_phase(TRAIN); net_state.MergeFrom(net_param.state()); net_state.MergeFrom(param_.train_state()); net_param.mutable_state()->CopyFrom(net_state); if (Caffe::root_solver()) { // 構造一個新的root網路 net_.reset(new Net<Dtype>(net_param)); } else { // 構造一個新的non-root網路,有一部分網路是基於給定的root網路的 net_.reset(new Net<Dtype>(net_param, root_solver_->net_.get())); } } void Solver<Dtype>::InitTestNets() { int test_net_id = 0; vector<string> sources(num_test_net_instances); vector<NetParameter> net_params(num_test_net_instances); // 從test_net_param得到所有Test網路的設定,也可以從net_param, net檔案,test_net檔案裡讀出來 for (int i = 0; i < num_test_net_params; ++i, ++test_net_id) { sources[test_net_id] = "test_net_param"; net_params[test_net_id].CopyFrom(param_.test_net_param(i)); } test_nets_.resize(num_test_net_instances); for (int i = 0; i < num_test_net_instances; ++i) { // 從優先順序由低到高設定網路的狀態 NetState net_state; net_state.set_phase(TEST); net_state.MergeFrom(net_params[i].state()); if (param_.test_state_size()) { net_state.MergeFrom(param_.test_state(i)); } net_params[i].mutable_state()->CopyFrom(net_state); if (Caffe::root_solver()) { // 構造新的root網路 test_nets_[i].reset(new Net<Dtype>(net_params[i])); } else { // 構造新的non-root網路 test_nets_[i].reset(new Net<Dtype>(net_params[i], root_solver_->test_nets_[i].get())); } test_nets_[i]->set_debug_info(param_.debug_info()); } }

Solver的Solve函式

void Solver<Dtype>::Solve(const char* resume_file) {
  CHECK(Caffe::root_solver()); // 只有root_solver才能優化網路
  requested_early_exit_ = false; // 如果被設為true,則退出
  if (resume_file) {
    // 從儲存好的snapshot開始,而不是從頭開始
    Restore(resume_file);
  }
  int start_iter = iter_;
  // 迭代到最大允許次數
  Step(param_.max_iter() - iter_);
  if (requested_early_exit_) { return; }
  if (param_.display() && iter_ % param_.display() == 0) {
    // average_loss控制我們計算loss值是最後average_loss次迭代的平均值
    int average_loss = this->param_.average_loss();
    Dtype loss;
    net_->Forward(&loss); // 前向傳播計算loss值
    UpdateSmoothedLoss(loss, start_iter, average_loss); // 通過當前smoothed_loss_和loss_陣列計算最後average_loss次迭代loss值的平均值
  }
  if (param_.test_interval() && iter_ % param_.test_interval() == 0) {
    // 每test_interval次迭代計算一下在測試集的效果。
    TestAll();
  }
}

// 迭代iters次
void Solver<Dtype>::Step(int iters) {
  const int start_iter = iter_;
  const int stop_iter = iter_ + iters;
  int average_loss = this->param_.average_loss();
  losses_.clear();
  smoothed_loss_ = 0;

  while (iter_ < stop_iter) {
    // 把引數清空
    net_->ClearParamDiffs();
    // 呼叫callback
    for (int i = 0; i < callbacks_.size(); ++i) {
      callbacks_[i]->on_start();
    }
    Dtype loss = 0;
    for (int i = 0; i < param_.iter_size(); ++i) {
      // 對網路先進行前向傳播,再進行反向傳播。然後計算loss的總和以求平均值
      loss += net_->ForwardBackward();
    }
    loss /= param_.iter_size(); //求平均值
    // 計算最後average_loss次迭代的平均loss值smoothed_loss_
    UpdateSmoothedLoss(loss, start_iter, average_loss);
    if (display) {
      // 列印網路輸出blob的值,略
    }
    // 呼叫callback
    for (int i = 0; i < callbacks_.size(); ++i) {
      callbacks_[i]->on_gradients_ready(); 
    }
    ApplyUpdate(); //由Solver的子類實現來更新網路引數
    ++iter_;
    // 獲得SolverAction,根據客戶端要求可以做Snapshot或者提前退出
    SolverAction::Enum request = GetRequestedAction();
  }
}

Solver的Test函式

// 測試所有的資料
void TestAll() {
  for (int test_net_id = 0;
       test_net_id < test_nets_.size() && !requested_early_exit_;
       ++test_net_id) {
    Test(test_net_id);
  }
}
// 測試一個數據集
void Solver<Dtype>::Test(const int test_net_id) {
  vector<Dtype> test_score; // 儲存網路輸出blob的所有值
  vector<int> test_score_output_id; // 儲存test_score[i]對應的top blob的id
  const shared_ptr<Net<Dtype> >& test_net = test_nets_[test_net_id];
  Dtype loss = 0;
  // 做test_iter次計算,取平均loss值
  for (int i = 0; i < param_.test_iter(test_net_id); ++i) {
    // 根據SolverAction決定是否做snapshot或者退出程式,略
    SolverAction::Enum request = GetRequestedAction();
    Dtype iter_loss;
    // 做前向傳播,計算loss值
    const vector<Blob<Dtype>*>& result = test_net->Forward(&iter_loss);
    // 如果計算loss,那麼累加每次迭代的loss值
    if (param_.test_compute_loss()) {
      loss += iter_loss;
    }
    if (i == 0) {
      for (int j = 0; j < result.size(); ++j) {
        const Dtype* result_vec = result[j]->cpu_data();
        for (int k = 0; k < result[j]->count(); ++k) {
          // 把所有top blob打平到一維陣列test_score上, 並在test_score_output_id記錄對應的blob id
          test_score.push_back(result_vec[k]);
          test_score_output_id.push_back(j);
        }
      }
    } else {
      int idx = 0;
      for (int j = 0; j < result.size(); ++j) {
        const Dtype* result_vec = result[j]->cpu_data();
        for (int k = 0; k < result[j]->count(); ++k) {
          // 累加所有top blob對應的元素,最後是要再除以test_iter以求平均值的
          test_score[idx++] += result_vec[k];
        }
      }
    }
  }
  if (param_.test_compute_loss()) {
    // 計算平均值並列印
    loss /= param_.test_iter(test_net_id);
  }
  for (int i = 0; i < test_score.size(); ++i) {
    // 對每個test_score,得到它對應的top blob的名字和loss權重
    const int output_blob_index =
        test_net->output_blob_indices()[test_score_output_id[i]];
    const string& output_name = test_net->blob_names()[output_blob_index];
    const Dtype loss_weight = test_net->blob_loss_weights()[output_blob_index];
    ostringstream loss_msg_stream;
    // 求出平均的test_score
    const Dtype mean_score = test_score[i] / param_.test_iter(test_net_id);
    // 如果loss_weight不為0,則計算加權的mean score。
    if (loss_weight) {
      loss_msg_stream << " (* " << loss_weight
                      << " = " << loss_weight * mean_score << " loss)";
    }
  }
}