Caffe Softmax 層的實現原理【細節補充】
本文是看了知乎的這篇文章以後覺得作者寫的很好,但是有些細節講解得不夠詳細,回覆裡面大家也多有疑問,特加以補充:
為了對原作者表示尊重和感謝,先註明原作出處:
作者:John Wang
連結:https://www.zhihu.com/question/28927103/answer/78810153
作者原文和我的補充
====================================
設 z 是 softmax loss 層的輸入,f(z)是 softmax 的輸出,即
y 是輸入樣本 z 對應的類別,y=0,1,...,N
對於 z ,其損失函式定義為
展開上式:
對上式求導,有
梯度下降方向即為
====================================
增加關於 softmax 層的反向傳播說明
設 softmax 的輸出為 a ,輸入為 z ,損失函式為 l
則
其中
在 caffe 中是 top_diff,a 為 caffe 中得 top_data,需要計算的是
if i!=k
if i==k
【我的補充】
----------------------------------------------------------------
當 i!=k 時,
當 i==k 時,
----------------------------------------------------------------
於是
【我的補充】
----------------------------------------------------------------
把負號提出去,改為點乘,即得到上式。注意,這裡的 n 表示 channels,這裡的 k 和 caffe 原始碼中的 k 含義不同。
----------------------------------------------------------------
整理一下得到
其中表示將標量擴充套件為 n 維向量,表示向量按元素相乘
【我的補充】
----------------------------------------------------------------
這邊作者講解得有誤,因為對照程式碼可以發現,點乘後其實得到的是 1*inner_num 大小的向量,所以為了對應通道相減,需要將其擴充套件為 channels*inner_num 的矩陣,而不是 n 維向量。
最後矩陣再按元素進行相乘。
對照 caffe 原始碼
// top_diff : l 對 a 向量求偏導
// top_data :a 向量
// 將 top_diff 拷貝到 bottom_diff
// dim = channels * inner_num_
// inner_num_ = height * width
caffe_copy(top[0]->count(), top_diff, bottom_diff);
// 遍歷一個 batch 中的樣本
for (int i = 0; i < outer_num_; ++i) {
// compute dot(top_diff, top_data) and subtract them from the bottom diff
// 此處計算兩個向量的點積,注意 top_diff 已經拷貝到 bottom_diff 當中
// 步長為 inner_num_(跨通道)構造一個長度為 channels (類別個數)的向量,進行點乘
for (int k = 0; k < inner_num_; ++k) {
scale_data[k] = caffe_cpu_strided_dot<Dtype>(channels,
bottom_diff + i * dim + k, inner_num_,
top_data + i * dim + k, inner_num_);
}
// subtraction
// 此處計算大括號內的減法(即負號)
// 將 scale_data 擴充套件為 channels 個通道(多少個類別),再和 bottom_diff 對應的通道相減
caffe_cpu_gemm<Dtype>(CblasNoTrans, CblasNoTrans, channels, inner_num_, 1,
-1., sum_multiplier_.cpu_data(), scale_data, 1., bottom_diff + i * dim);
}
// elementwise multiplication
// 元素級的乘法
// 此處計算大括號外和 a 向量的乘法
caffe_mul(top[0]->count(), bottom_diff, top_data, bottom_diff)