1. 程式人生 > >lstm+ctc訓練端對端的模型(34)---《深度學習》

lstm+ctc訓練端對端的模型(34)---《深度學習》

我們在前面瞭解了CNN(卷積神經網路),也瞭解了RNN(遞迴神經網路),也在前面進行了基於CNN的BP和基於RNN的BPTT公式的推導,主要都利用了我們所定義的誤差δ(預期正確輸出和輸出之間的誤差),進行誤差的反向傳播,進而修改不同權重的梯度,然後是的網路朝著好的方向不斷訓練!
然而針對有的問題,這些模型卻並不適合,例如針對聲音轉文字的預測,不定長驗證碼的破解等等,CNN不適合因為其無法滿足時序特徵,RNN也不適合是因為雖然它滿足時序特徵,但是卻要求每一幀所對應的label進行訓練,因此也並不適合這種問題,因此針對這種問題:1)sequence to sequence(序列到序列問題));2)輸入序列的維度遠大於輸出序列。

我們需要找到一種新的方法,很舒服的是我們找到了,解決方法有兩種:1)RNN+CTC(一般採用lstm+ctc實現,因為RNN可能會出現梯度消失或者梯度爆炸的問題,由於lstm中進行了相應的處理,因此不會出現梯度消失或者梯度爆炸的問題);2)attention機制。本篇部落格我們主要講解lstm+ctc,attention機制以後有時間的話在進行一個講解吧!
ctc全稱為connectionist temporal classifier(連線時序分類器),主要用於len(輸入序列)>len(輸出序列)的這種問題,在這兒我們先來講4個概念:
這裡寫圖片描述
1)y(k,t):其中,k為下標,t為上標,因為CSDN暫時還不支援這種那種表達,所以暫時寫成這種,表示的是t時刻,y輸出的類別為k的可能性;
2)p(π|x):表示輸入為x的時候,預測的路徑為path(π)的概率,其中π為對應的輸出路徑,p(π|x)=∏(t=1..T) y(k,t),注意寫法哈,表示y(k,t)(t=1..T)的連乘,因為其每一個輸出都是無關的,因此可以寫成連乘;
3)p(l|x):表示輸入為x的時候,預測為標籤為lable的概率,其中l(標籤序列)一般和多個path(π)對應,如(aa-a–bb-)和(a–aaab-)和(a-aaa–b-)這幾個不同的路徑對應的卻是一個對應的lable序列(aab),因此p(l|x)=∑p(π|x),其中∑表明多條路徑π到同一個label標籤序列的對映;
下面如何如何從所有可能的標籤序列中挑選出一個最可能的標籤序列呢,即對其進行解碼呢?如果想要群舉的話可能真得類似,因為數量級太大太大,因此,我們需要找到一個合適的演算法進行解碼。
解碼:I(x):l(x)=argmax (l) p(l|x),表明從所所有的lable序列中選出概率最大的那個lable序列作為輸出。

Alex的論文中給了兩種很不錯的方法:
1)最佳路徑解碼(best path decoding):
h(x)≈B(π*)
where π=argmaxp(π|x),π∈N^t,這種解碼方式非常容易計算,π就是每個時間步驟的最可能的輸出連線,但這種方法的最大不足就在於它並不能保證最有可能的標記,因為它每次都是區域性最優,可能忽略掉全域性最優!
2)字首查詢解碼(prefix search decoding):
這裡寫圖片描述
非常適合解決輸入序列和輸出序列的長度相差很大的情況,希望大家可以學的開心哈!
The CTC Forward-Backward Algorithm
p(l|x)怎麼求解是個問題,因為對應的所有的路徑非常多,有T!種選擇,顯然我們如果想要求出所有的路徑不太現實,因為這些路徑中包括很多種不適合的,因此我們並沒有求出所有的路徑,我們選擇是ctc前向後向演算法,如下所示:
這裡寫圖片描述

這種方法選擇在所有的label兩兩之間加上了空格,並且在開頭和結尾之間也加上了空格,假如label的長度為l,則新的l’的長度為2l+1,然後我們從其中選擇最可能的標籤序列,即求出不同的p(l|x),最後再利用我們上面所介紹的解碼方式求出對應的標籤,這裡介紹的比較粗泛一些,詳細的大家可以看一下底下的部落格,講的非常細緻,對最大似然的loss函式也介紹的很仔細,包括它的偏導等等都很詳細!
這是一個使用lstm+ctc針對不定長(3-5)個變長驗證碼進行端對端識別的程式,希望可以幫助大家理解lstm+ctc的原理!https://github.com/Tangzixia/tensorflow_captcha_lstm_ctc_loss
如果大家對與ctc的loss函式的有興趣的話,可以參考這些部落格,資料很詳細,希望大家可以獲益匪淺!
CTC學習筆記(二) 訓練和公式推導
論文筆記:Connectionist Temporal Classification: Labelling Unsegmented Sequence
[論文]CTC——Connectionist Temporal Classification: Labelling Unsegmented Sequence Data with Recurrent Neural Networks
祝大家排雷順利!