多分類器softmax——絕對簡單易懂的梯度推導
損失函式的計算
首先說明本文解決的是softmax的多分類器的梯度求導,以下先給出損失函式的計算方式: 這裡將最終的loss分為4步進行計算,如下所示,當然,這裡不解釋為什麼是這樣的計算方式。 注意到,本文並不限制訓練樣本的數量,訓練樣本的特徵數,以及最後分為幾類。 這裡x表示輸入,w表示權重引數。 說明:這裡的x和w的下標表示x的某一行和w的某一列相乘在逐項相加得到s。 然後再根據s計算每一個類的概率,如下公式(2) 這裡採用的下標和公式(1)不相同,其中,n表示樣本的個數,y表示樣本為n時的正確分類標號。k表示有多少分類。這個公式就是先將s進行e次方計算,然後歸一化,求得該樣本正確分類下的概率p. 根據p可以計算出每一個樣本的損失,如公式(3):
基本求導法則
所謂梯度,就是求損失函式對引數w的導數,將其用在更新引數w上,達到優化的目的。 我們知道,梯度計算遵循著鏈式法則,而基本求導公式也是需要的,防止有人忘記,我先給出這裡將會用到的基本求導公式。知道的請跳過這一節,直接看下一節。
以下開始正式求梯度
計算整個損失函式對w(下標為ij)的導數。 根據鏈式法則,考慮到總損失為每個樣本損失的平均數,且每個樣本的損失都與wij相關,這個說明很有必要,假如某個損失與wij無關,我們就不用對它進行求導了。 有公式(5) 這裡Ln表示樣本為n時的損失函式。 不失一般性,這裡對最後一項進行繼續推導,然後將其相加。 同樣的,由於pny是和wij的函式,有公式(6): 結合公式(2),前一部分有有公式(7): 後一個部分我們繼續來考慮,pny的上下兩項是否都是wij的函式?肯定的回答是,這不一定,由公式(2)和(1)可知,如果公式2中分子的下標y不是j,那麼實際上這裡公式2的分子就不是wij的函式。 我們細說一下,由公式1,ij是公式1中的下標,當sij與wij有關係建立在這個j相等的情況,但是公式2的分子並不一定就滿足這個關係的,什麼情況滿足呢?那就是樣本n的正確分類的下標j和wij中的下標j相等時;否則這就沒有關係。
綜合兩個情況
情況二比情況一多減去一項。 一般情況下,我們直接使用pnj * xni即可。 而當wij中j是當前樣本n的正確分類時要多減去xni。
以上既是多分類器softmax的梯度求導公式。
後話
其實個人感覺梯度的計算還是挺難的,而且本文只是推導公式,還沒有真正的程式設計計算。 實際上,我們通常為了保證我們的程式正確,會寫一個數值求導,正確情況下兩者不會相差很多。 本文的理論推導,將會在下一篇部落格中寫明如何進行計算。