1. 程式人生 > >多分類器softmax——絕對簡單易懂的梯度推導

多分類器softmax——絕對簡單易懂的梯度推導

損失函式的計算

首先說明本文解決的是softmax的多分類器的梯度求導,以下先給出損失函式的計算方式: 這裡將最終的loss分為4步進行計算,如下所示,當然,這裡不解釋為什麼是這樣的計算方式。 注意到,本文並不限制訓練樣本的數量,訓練樣本的特徵數,以及最後分為幾類。 公式(1) 這裡x表示輸入,w表示權重引數。 說明:這裡的x和w的下標表示x的某一行和w的某一列相乘在逐項相加得到s。 然後再根據s計算每一個類的概率,如下公式(2) 公式(2) 這裡採用的下標和公式(1)不相同,其中,n表示樣本的個數,y表示樣本為n時的正確分類標號。k表示有多少分類。這個公式就是先將s進行e次方計算,然後歸一化,求得該樣本正確分類下的概率p. 根據p可以計算出每一個樣本的損失,如公式(3): 公式(3)

這個公式說明,每一個樣本的損失僅僅是正確分類對應的概率值的log函式,這裡準確說應該是ln函式,也就是以自然對數為底的,這樣計算導數更方便,後面會以ln為版本進行計算。 最後,根據公式(4)計算所有樣本的損失: 公式(4) 也就是將所有樣本的損失求平均數。 注意:以上下標是獨立系統,與下面的推導過程沒有必然關係,這裡特別指ij,其他字母的含義基本相同。

基本求導法則

所謂梯度,就是求損失函式對引數w的導數,將其用在更新引數w上,達到優化的目的。 我們知道,梯度計算遵循著鏈式法則,而基本求導公式也是需要的,防止有人忘記,我先給出這裡將會用到的基本求導公式。知道的請跳過這一節,直接看下一節。 在這裡插入圖片描述 在這裡插入圖片描述 在這裡插入圖片描述 在這裡插入圖片描述 在這裡插入圖片描述 在這裡插入圖片描述

以下開始正式求梯度

計算整個損失函式對w(下標為ij)的導數。 根據鏈式法則,考慮到總損失為每個樣本損失的平均數,且每個樣本的損失都與wij相關,這個說明很有必要,假如某個損失與wij無關,我們就不用對它進行求導了。 有公式(5) 公式(5) 這裡Ln表示樣本為n時的損失函式。 不失一般性,這裡對最後一項進行繼續推導,然後將其相加。 同樣的,由於pny是和wij的函式,有公式(6): 公式(6) 結合公式(2),前一部分有有公式(7): 在這裡插入圖片描述 後一個部分我們繼續來考慮,pny的上下兩項是否都是wij的函式?肯定的回答是,這不一定,由公式(2)和(1)可知,如果公式2中分子的下標y不是j,那麼實際上這裡公式2的分子就不是wij的函式。 我們細說一下,由公式1,ij是公式1中的下標,當sij與wij有關係建立在這個j相等的情況,但是公式2的分子並不一定就滿足這個關係的,什麼情況滿足呢?那就是樣本n的正確分類的下標j和wij中的下標j相等時;否則這就沒有關係。

因此,我們需要分為兩種情況來進一步計算公式(6)的後半部分。 (實際上,我們也可以先認為他們相關,然後進一步處理,這裡我先不這麼做) 情況一,公式(2)中的分子與wij無關:也就是以下公式中y與j不相等 公式(2)中分母必然與wij有關,且只有一個與wij有關。那就是公式(2)中分母的下標k與wij的就相等時,而其他都與wij無關。 進一步考慮到e的s次方,s與wij的關係,因此針對情況一,有公式(8) 公式8 繼續對第二項展開有公式(9): 在這裡插入圖片描述 這裡還是細細說一下,這個過程,始終記住一點,那就是中間變數與wij是什麼關係,可以根據公式看出來。根據公式(1),當且僅當s的下標中是ij時才會與wij有關,而對sij對wij求導時得到的就是xii,(兩個i不一樣的含義)只需要把公式(1)中的x和w的下標中的點號換成i即可。也就是說,s對w求導時,x的第一個下標是s的第一個下標,x的第二個下標是w的第一個下標。當然,這裡我們需要再將s的下標i換成n,這樣才能滿足以上的推導。 我們將公式(9)根據公式(2)化簡一下,再帶入公式(6),可以得到公式(10),也就是情況一下的最終一個樣本的梯度: 公式(10) 其中,用了一個簡寫,也就是求和的項簡寫了,請留意。 寫成pnj是因為我們計算過程中會產生這個數,而且這樣寫起來也更整齊。 情況二,公式(2)中的分子是wij的函式: 注意到這裡,公式(2)中pny的下標y和wij的下標j是相等的,也就是y=j。 情況2比情況1複雜在公式(2)的分母上,其餘相同,因此,對其求導過程如下: 這裡先使用ynj(nj是下標)表示樣本為n時第j個分類的真實值,要麼是0,要麼是1,1表示真實分類就是這個j. 情況一根據(1\u)'求導,情況二則根據(v/u)'來求導,因此有一點差別。 以下一步一步的寫: 在這裡插入圖片描述 根據公式(2)將後面展開可得: 在這裡插入圖片描述 化簡一下可以得到: 在這裡插入圖片描述 根據公式(2)繼續化簡: 在這裡插入圖片描述 對上式去括號操作: 在這裡插入圖片描述 繼續求導並且根據公式2化簡得公式(11) 在這裡插入圖片描述 可以看出,這與上面的情況一相差在最後一項上,而前面一項是相等的。 接下來我們一起探討一下怎麼求後面的一項,畢竟這還無法完全理解清楚,因為這還是一個導數,也不是輸入或者中間求到的某個數。 前面我們已經說到,情況二下公式(2)中的y和wij的j是相等的。 這時候計算知道: 在這裡插入圖片描述 所以公式(11)進一步計算可得最終的求導公式:公式(12) 在這裡插入圖片描述

綜合兩個情況

情況二比情況一多減去一項。 一般情況下,我們直接使用pnj * xni即可。 而當wij中j是當前樣本n的正確分類時要多減去xni。

以上既是多分類器softmax的梯度求導公式。

後話

其實個人感覺梯度的計算還是挺難的,而且本文只是推導公式,還沒有真正的程式設計計算。 實際上,我們通常為了保證我們的程式正確,會寫一個數值求導,正確情況下兩者不會相差很多。 本文的理論推導,將會在下一篇部落格中寫明如何進行計算。