kaldi中CD-DNN-HMM網路引數更新公式手寫推導
在基於DNN-HMM的語音識別中,DNN的作用跟GMM是一樣的,即它是取代GMM的,具體作用是算特徵值對每個三音素狀態的概率,算出來哪個最大這個特徵值就對應哪個狀態。只不過以前是用GMM算的,現在用DNN算了。這是典型的多分類問題,所以輸出層用的啟用函式是softmax,損失函式用的是cross entropy(交叉熵)。不用均方差做損失函式的原因是在分類問題上它是非凸函式,不能保證全域性最優解(只有凸函式才能保證全域性最優解)。Kaldi中也支援DNN-HMM,它還依賴於上下文(context dependent, CD),所以叫CD-DNN-HMM。在kaldi的nnet1中,特徵提取用filterbank,每幀40維資料,預設取當前幀前後5幀加上當前幀共11幀作為輸入,所以輸入層維數是440(440 = 40*11)。同時預設有4個隱藏層,每層1024個網元,啟用函式是sigmoid。今天我們看看網路的各種引數是怎麼得到的(手寫推導)。由於真正的網路比較複雜,為了推導方便這裡對其進行了簡化,只有一個隱藏層,每層的網元均為3,同時只有weight沒有bias。這樣網路如下圖:
上圖中輸入層3個網元為i1/i2/i3(i表示input),隱藏層3個網元為h1/h2/h3(h表示hidden),輸出層3個網元為o1/o2/o3(o表示output)。隱藏層h1的輸入為 (q11等表示輸入層和隱藏層之間的權值),輸出為 。輸出層o1的輸入為 (w11等表示隱藏層和輸出層之間的權值),輸出為 。其他可類似推出。損失函式用交叉熵。今天我們看看網路引數(以隱藏層和輸出層之間的w11以及輸入層和隱藏層之間的q11為例)在每次迭代訓練後是怎麼更新的。先看隱藏層和輸出層之間的w11。
1,隱藏層和輸出層之間的w11的更新
先分別求三個導數的值:
所以最終的w11更新公式如下圖:
2,輸入層和隱藏層之間的q11的更新
先分別求三個導數的值:
所以最終的q11更新公式如下圖:
以上的公式推導中如有錯誤,煩請指出,非常感