How Weights Are Updated in Caffe

Source: http://blog.csdn.net/mounty_fsc/article/details/51588773

(Caffe, LeNet) Weight Update (Part 7)

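In Solver::ApplyUpdate(), the network weights are updated from the partial derivatives of the loss with respect to the weights, which were computed in the backward pass, according to the configured learning policy. This completes the current training iteration.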

1 Model Optimization

1.1 Loss Function

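The loss function $L(W)$ is the empirical loss plus a regularization term, as shown below, where $X^{(i)}$ is an input sample; $f_W$ is the loss on a single sample; $N$ is the number of samples in the mini-batch; and $r(W)$ is the regularization term, weighted by $\lambda$.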

$$L(W) \approx \frac{1}{N} \sum_{i=1}^{N} f_W\left(X^{(i)}\right) + \lambda\, r(W)$$
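To make the formula concrete, here is a minimal sketch of how the mini-batch objective could be evaluated. `RegularizedLoss` is a hypothetical helper, not a Caffe function, and it assumes an L2 regularizer $r(W) = \frac{1}{2}\lVert W \rVert^2$, which is only one possible choice of $r$.

```cpp
#include <cassert>
#include <vector>

// Hypothetical helper (not part of Caffe): evaluates
//   L(W) = (1/N) * sum_i f_W(X^(i)) + lambda * r(W)
// given the per-sample losses f_W(X^(i)) of one mini-batch.
// Assumption: r(W) is the L2 penalty 0.5 * ||W||^2.
double RegularizedLoss(const std::vector<double>& sample_losses,
                       const std::vector<double>& weights,
                       double lambda) {
  assert(!sample_losses.empty());  // N must be at least 1

  // Empirical loss: mean of the per-sample losses.
  double data_loss = 0.0;
  for (double loss : sample_losses) data_loss += loss;
  data_loss /= static_cast<double>(sample_losses.size());

  // Regularization term r(W) under the L2 assumption.
  double l2 = 0.0;
  for (double w : weights) l2 += 0.5 * w * w;

  return data_loss + lambda * l2;
}
```

For example, per-sample losses {1, 3}, a single weight 2, and λ = 0.5 give a data term of 2 and a regularization term of 0.5 · 2 = 1, so the objective is 3.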

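In Caffe, training can be divided into three stages:

  1. Forward pass: computes $f_W$
  2. Backward pass: computes $\nabla f_W$
  3. Weight update: computes $\Delta W$ from $\nabla f_W$, $\nabla r(W)$, etc., and uses it to update $W$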
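The three stages can be sketched on a toy problem. The `ToyModel` struct below is purely illustrative and is not how Caffe organizes its layers: it assumes a single weight, a squared-error loss, an L2 regularizer, and a plain gradient step with no momentum.

```cpp
#include <cassert>

// Hypothetical one-weight model y_hat = w * x (illustration only).
struct ToyModel {
  double w;     // the single learnable weight W
  double grad;  // gradient of f_W with respect to w, set by Backward()

  // Stage 1, forward pass: f_W(x) = 0.5 * (w * x - y)^2.
  double Forward(double x, double y) const {
    const double diff = w * x - y;
    return 0.5 * diff * diff;
  }

  // Stage 2, backward pass: d f_W / d w = (w * x - y) * x.
  void Backward(double x, double y) { grad = (w * x - y) * x; }

  // Stage 3, weight update: with r(W) = 0.5 * w^2 (so grad r = w),
  // delta_w = -alpha * (grad f_W + lambda * grad r).
  void Update(double alpha, double lambda) {
    assert(alpha >= 0.0);
    const double delta_w = -alpha * (grad + lambda * w);
    w += delta_w;
  }
};
```

With w = 1, x = 2, y = 4, a learning rate of 0.25 and λ = 0, the forward pass gives a loss of 2, the backward pass a gradient of -4, and the update moves w to 2, at which point the loss is 0.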

1.2 Stochastic Gradient Descent

In LeNet, the solver type is SGD (stochastic gradient descent).

SGD updates the weights with the following formulas:

$$W_{t+1} = W_t + V_{t+1}$$
$$V_{t+1} = \mu V_t - \alpha \nabla L(W_t)$$

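Here $W_{t+1}$ is the weights at iteration $t+1$; $V_{t+1}$ is the update at iteration $t+1$ (also written $\Delta W_{t+1}$); $\mu$ is the weight given to the previous update (the momentum); $\alpha$ is the learning rate; and $\nabla L(W_t)$ is the gradient of the loss with respect to the weights.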
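A minimal sketch of one such update step, applied element-wise to a weight vector, might look as follows. `MomentumSgdStep` is a hypothetical free function written for illustration; it follows the two formulas literally and is not Caffe's SGDSolver code.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical helper (not Caffe's implementation): one SGD step with momentum.
//   V_{t+1} = mu * V_t - alpha * grad L(W_t)
//   W_{t+1} = W_t + V_{t+1}
void MomentumSgdStep(std::vector<double>& weights,   // W_t, updated in place
                     std::vector<double>& velocity,  // V_t, updated in place
                     const std::vector<double>& grad,  // grad L(W_t)
                     double alpha,                   // learning rate
                     double mu) {                    // momentum
  assert(weights.size() == velocity.size());
  assert(weights.size() == grad.size());
  for (std::size_t i = 0; i < weights.size(); ++i) {
    velocity[i] = mu * velocity[i] - alpha * grad[i];
    weights[i] += velocity[i];
  }
}
```

Starting from W = 1 and V = 0 with a constant gradient of 2, α = 0.5 and μ = 0.5, the first step gives V = -1 and W = 0, and the second gives V = -1.5 and W = -1.5. The second step is larger than the first because the previous update is carried over through μ.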

2 Code Analysis

2.1 ApplyUpdate

<code class="language-c++ hljs lasso has-numbering" style="display: block; padding: 0px; color: inherit; box-sizing: border-box; font-family: 'Source Code Pro', monospace;font-size:undefined; white-space: pre; border-radius: 0px; word-wrap: normal; background: transparent;">template <typename Dtype>
void SGDSolver<Dtype>::ApplyUpdate() {
  // Get the learning rate for this iteration
  Dtype rate = GetLearningRate();

  // Update the weights of every layer in the network.
  // In LeNet only four layers have parameters: `conv1`, `conv2`, `ip1`, `ip2`.
  // Each of them has two parameter blobs, the weights and the bias,
  // so the size of `learnable_params_` is 8.
  for (int param_id = 0; param_id < this->net_->learnable_params().size();
       ++param_id) {
    // Normalize: not needed when iter_size is 1, so LeNet does not need it
    Normalize(param_id);
    // Regularize
    Regularize(param_id);
    // Compute the update value \delta w
    ComputeUpdateValue(param_id, rate);
  }
  // Update the weights
  this->net_->Update();
}
</code><ul class="pre-numbering" style="box-sizing: border-box; position: absolute; width: 50px; top: 0px; left: 0px; margin: 0px; padding: 6px 0px 40px; border-right-width: 1px; border-right-style: solid; border-right-color: rgb(221, 221, 221); list-style: none; text-align: right; background-color: rgb(238, 238, 238);"><li style="box-sizing: border-box; padding: 0px 5px;">1</li><li style="box-sizing: border-box; padding: 0px 5px;">2</li><li style="box-sizing: border-box; padding: 0px 5px;">3</li><li style="box-sizing: border-box; padding: 0px 5px;">4</li><li style="box-sizing: border-box; padding: 0px 5px;">5</li><li style="box-sizing: border-box; padding: 0px 5px;">6</li><li style="box-sizing: border-box; padding: 0px 5px;">7</li><li style="box-sizing: border-box; padding: 0px 5px;">8</li><li style="box-sizing: border-box; padding: 0px 5px;">9</li><li style="box-sizing: border-box; padding: 0px 5px;">10</li><li style="box-sizing: border-box; padding: 0px 5px;">11</li><li style="box-sizing: border-box; padding: 0px 5px;">12</li><li style="box-sizing: border-box; padding: 0px 5px;">13</li><li style="box-sizing: border-box; padding: 0px 5px;">14</li><li style="box-sizing: border-box; padding: 0px 5px;">15</li><li style="box-sizing: border-box; padding: 0px 5px;">16</li><li style="box-sizing: border-box; padding: 0px 5px;">17</li><li style="box-sizing: border-box; padding: 0px 5px;">18</li><li style="box-sizing: border-box; padding: 0px 5px;">19</li><li style="box-sizing: border-box; padding: 0px 5px;">20</li><li style="box-sizing: border-box; padding: 0px 5px;">21</li><li style="box-sizing: border-box; padding: 0px 5px;">22<