1. 程式人生 > >CTC(Connectionist Temporal Classification)介紹

CTC(Connectionist Temporal Classification)介紹

CTC解決什麼問題

CTC,Connectionist Temporal Classification,用來解決輸入序列和輸出序列難以一一對應的問題。

舉例來說,在語音識別中,我們希望音訊中的音素和翻譯後的字元可以一一對應,這是訓練時一個很天然的想法。但是要對齊是一件很困難的事,如下圖所示(圖源見參考資料[1]),有人說話塊,有人說話慢,每個人說話快慢不同,不可能手動地對音素和字元對齊,這樣太耗時。

再比如,在OCR中使用RNN時,RNN的每一個輸出要對應到字元影象中的每一個位置,要手工做這樣的標記工作量太大,而且影象中的字元數量不同,字型樣式不同,大小不同,導致輸出不一定能和每個字元一一對應。

CTC基本概述

考慮一個LSTM,用w表示LSTM的引數,則LSTM可以表示為一個函式:\(y = N_w(x)\)

定義輸入x的時間步為T,每個時間步上的特徵維度記作m,表示m維特徵。 \[ x = (x^1, x^2, ..., x^T) \\x^t = (x_1^t, x_2^t, ..., x_m^t)\]

輸出時間步也為T,和輸入可以一一對應,每個時間步的輸出維度作為n,表示n維輸出,實際上是n個概率。 \[ y = (y^1, y^2, ..., y^T) \\y^t = (y_1^t, y_2^t, ..., y_n^t)\]

假設要對26個英文字元進行識別,考慮到有些位置沒有字元,定義一個-作為空白符加入到字元集合\({L}'= \{a,b,c,...,x,y,z\} \cup \{-\}= L \cup \{-\} = \{a,b,c,...,x,y,z,-\}\)

,那麼對於LSTM而言每個時間步的輸出維度n就是27,表示27個字元在這個時間步上輸出的概率。

如果根據這些概率進行選取,每個時間步選取一個元素,就可以得到輸出序列,其輸出空間可以記為\({L}'^T\)

定義一個B變換,對LSTM的輸出序列(比如下例中的4個\(\pi\))進行變換,變換成真實輸出(比如下例中的state),把連續的相同字元刪減為1個並刪去空白符。舉例說明,當T=12時: \[\begin{split} B(\pi^1) &= B(--stta-t---e) = state \\ B(\pi^2) &= B(sst-aaa-tee-) = state \\ B(\pi^3) &= B(--sttaa-tee-) = state \\ B(\pi^4) &= B(sst-aa-t---e) = state \end{split}\]

其中\(\pi\)表示LSTM的一種輸出序列。當我們優化LSTM時,只需要最大化以下概率,即給定輸入x的情況下,輸出為l的概率,\(l\)表示真實輸出。對下式取負號,就可以使用梯度下降對其求最小。 \[p(l \ | \ x) = \sum_{B(\pi)=l} p(\pi | x)\]

假設時間步之間的輸出獨立,那麼對於任意一個輸出序列\(\pi\)的概率計算式子如下, \[p(\pi | x) = \prod_{t=1}^{T} y_{\pi_t}^t\]

其中下標\(\pi_t\)表示的是,輸出序列在t時間步選取的元素對應的索引,比如該序列在第一個時間步選取的元素是a,那麼得到的值就是1。選取的是z,那麼得到的值就是26。選取的是空白符,那麼得到的值就是27。為了方便觀測,也用對應的字元表示,其實是一個意思,如下式所示。 \[ \pi = (--stta-t---e) \\ p(\pi | x) = {y^1}_{-} \cdot {y^2}_{-} \cdot y^3_{s} \cdot y^4_{t} \cdot y^5_{t} \cdot y^6_{a} \cdot {y^7}_{-} \cdot {y^8}_{t} \cdot {y^9}_{-} \cdot {y^{10}}_{-} \cdot {y^{11}}_{-} \cdot y^{12}_{e} \]

但是對於某一個真實輸出,比如上述的state,有多個LSTM的輸出序列可以通過B轉換得到。這些序列都是我們要的結果,我們要使給定x,這些輸出序列的概率加起來最大。如果逐條遍歷來求得,時間複雜度是指數級的,因為有T個位置,每個位置有n種選擇(字元集合的大小),那麼就有\(n^T\)種可能。因此CTC借用了HMM中的“前向-後向演算法”(forward-backward algorithm)來計算。

CTC中的前向後向演算法

由於真實輸出\(l\)是一個序列,序列可以通過一個路徑圖中的一條路徑來表示,我們也稱輸出序列\(l\)為路徑\(l\)。定義路徑\({l}'\)為“在路徑\(l\)每兩個元素之間以及頭尾插入空白符”,如: \[l = state \\ {l}' = -s-t-a-t-e-\]

對某個時間步的某個字元求導(這裡用k表示字元集合中的某個字元或字元索引),恰好是與概率\(y_k^t\)相關的路徑。 \[ \frac{\partial \, p(l \ |x)}{\partial \, y_k^t} = \frac{\partial \, \sum_{B(\pi)=l, \pi_t=k} p(\pi | x)}{\partial \, y_k^t} \]

以前面的\(\pi^1, \pi^2, \pi^3, \pi^4\)為例子,畫出兩條路徑(還有兩條沒畫出來),如下圖所示(圖源見參考資料[1])。

4條路徑都在t=6時經過了字元a,觀察4條路徑,可以得到如下式子。 \[ \begin{split} \pi^1 &= b = b_{1:5} + a_6 + b_{7:12} \\ \pi^2 &= r = r_{1:5} + a_6 + r_{7:12} \\ \pi^3 &= b_{1:5} + a_6 + r_{7:12} \\ \pi^4 &= r_{1:5} + a_6 + b_{7:12} \end{split}\]

\[\begin{split} p( \pi^1, \pi^2, \pi^3, \pi^4 | x) &= {y^1}_{-} \cdot {y^2}_{-} \cdot {y^3}_{s} \cdot {y^4}_{t} \cdot {y^5}_{t} \cdot {y^6}_{a} \cdot {y^7}_{-} \cdot {y^8}_{t} \cdot {y^9}_{-} \cdot {y^{10}}_{-} \cdot {y^{11}}_{-} \cdot {y^{12}}_{e} \\ &+ {y^1}_{s} \cdot {y^2}_{s} \cdot {y^3}_{t} \cdot {y^4}_{-} \cdot {y^5}_{a} \cdot {y^6}_{a} \cdot {y^7}_{a} \cdot {y^8}_{-} \cdot {y^9}_{t} \cdot {y^{10}}_{e} \cdot {y^{11}}_{e} \cdot {y^{12}}_{-} \\ &+ {y^1}_{-} \cdot {y^2}_{-} \cdot {y^3}_{s} \cdot {y^4}_{t} \cdot {y^5}_{t} \cdot {y^6}_{a} \cdot {y^7}_{a} \cdot {y^8}_{-} \cdot {y^9}_{t} \cdot {y^{10}}_{e} \cdot {y^{11}}_{e} \cdot {y^{12}}_{-} \\ &+ {y^1}_{s} \cdot {y^2}_{s} \cdot {y^3}_{t} \cdot {y^4}_{-} \cdot {y^5}_{a} \cdot {y^6}_{a} \cdot {y^7}_{-} \cdot {y^8}_{t} \cdot {y^9}_{-} \cdot {y^{10}}_{-} \cdot {y^{11}}_{-} \cdot {y^{12}}_{e}\end{split}\]

令: \[\begin{split} forward &= p(b_{1:5} + r_{1:5} | x) = {y^1}_{-} \cdot {y^2}_{-} \cdot {y^3}_{s} \cdot {y^4}_{t} \cdot {y^5}_{t} + {y^1}_{s} \cdot {y^2}_{s} \cdot {y^3}_{t} \cdot {y^4}_{-} \cdot {y^5}_{a} \\ backward &= p(b_{7:12} + r_{7:12} | x) = {y^7}_{-} \cdot {y^8}_{t} \cdot {y^9}_{-} \cdot {y^{10}}_{-} \cdot {y^{11}}_{-} \cdot {y^{12}}_{e} + {y^7}_{a} \cdot {y^8}_{-} \cdot {y^9}_{t} \cdot {y^{10}}_{e} \cdot {y^{11}}_{e} \cdot {y^{12}}_{-} \end{split}\] 那麼可以做如下表示: \[p( \pi^1, \pi^2, \pi^3, \pi^4 | x) = forward \cdot y_a^t \cdot backward\] 上述的forward和backward只包含了4條路徑,如果推廣一下forward和backward的含義,考慮所有路徑,可做如下表示: \[\sum_{B(\pi)=l, \pi_6=a} p(\pi | x) = forward \cdot y_a^t \cdot backward\]

定義forward\(\alpha_t({l}'_k)\),表示t時刻經過\({l}'_k\)字元的路徑概率中1-t的概率之和,式子定義如下。 \[\alpha_t({l}'_k) = \sum_{B(\pi)=l, \ \pi_t = {l}'_k} \ \prod_{{t}'=1}^{t} {y_{\pi_{{t}'}}^{{t}'}}\]

t=1時,符號只能是空白符或\(l_1\),可以得到以下初始條件: \[\alpha_1(-) = {y^1}_{-} \\ \alpha_1(l_1) = y^1_{l_1} \\ \alpha_1(l_t) = 0, \forall t > 1 \]

觀察上圖((圖源見參考資料[1])可以發現,如果t=6時字元是a,那麼t=5時只能是字元a,t,空白符三選一,否則經過B變換後無法得到state。 可以得到以下遞推關係: \[\alpha_6(a) = (\alpha_5(a) + \alpha_5(t) + \alpha_5(-)) \cdot y_a^6\]

更一般地,可以得到如下遞推關係: \[\alpha_t({l}'_k) = (\alpha_{t-1}({l}'_k) + \alpha_{t-1}({l}'_{k-1}) + \alpha_{t-1}(-)) \cdot y_{{l}'_k}^t\]

定義backward為為\(\beta_t(s)\),表示t時刻經過\({l}'_k\)字元的路徑概率中t-T的概率之和,式子定義如下。 \[\beta_t({l}'_k) = \sum_{B(\pi)=l, \ \pi_t = {l}'_k} \ \prod_{{t}'=t}^{T} {y_{\pi_{{t}'}}^{{t}'}}\]

t=T時,符號只能是空白符或\(l_{|{l}'|-1}\),可以得到以下初始條件: \[\beta_T(-) = {y^T}_- \\ \beta_T({l}'_{|{l}'|-1}) = y^T_{l_{|{l}'|-1}} \\ \beta_T(l_{|{l}'|-i}) = 0, \forall i > 1 \]

同理,可以得到如下遞推關係: \[\beta_t({l}'_k) = (\beta_{t+1}({l}'_k) + \beta_{t+1}({l}'_{k+1}) + \beta_{t+1}(-)) \cdot y_{{l}'_k}^t\]

根據forward和backward的式子定義,它們相乘可以得到: \[\alpha_t({l}'_k) \beta_t({l}'_k) = \sum_{B(\pi)=l, \ \pi_t = {l}'_k} \ y_{{l}'_k}^t \prod_{t=1}^{T} {y_{\pi_t}^t}\]

又因為\(p(l \ | \ x)\)\({l}'_k\)求導時,只跟那些\(\pi_t = {l}'_k\)的路徑有關,那麼求導時(注意是求導時)可以簡寫如下式子: \[p(l \ | \ x) = \sum_{B(\pi)=l, \ \pi_t = {l}'_k} p(\pi | x) = \sum_{B(\pi)=l, \ \pi_t = {l}'_k} \ \prod_{t=1}^{T} y_{\pi_t}^t\]

結合上面兩式,得到: \[p(l \ | \ x) = \sum_{B(\pi)=l, \ \pi_t = {l}'_k} \ \frac{\alpha_t({l}'_k) \beta_t({l}'_k) }{y_{{l}'_k}^t}\]

最後可以得到求導式(這裡用k來表示字元,和\({l}'_k\)的含義相同): \[\frac{\partial \, p(l \ |x)}{\partial \, y_k^t} = \frac{\partial \ \sum_{B(\pi)=l, \ \pi_t = k} \frac{\alpha_t(k) \beta_t(k) }{y_k^t}}{\partial \, y_k^t}\]

求導式裡的forward和backward可以用前面的dp遞推式計算出來,時間複雜度是\(nT\),相比於前面的指數複雜度\(n^T\)大大減小了計算量。

這樣對LSTM的輸出y求導之後,再根據y對LSTM裡面的權重引數w進行鏈式求導,就可以使用梯度下降的方法來更新引數了。

CTC的預測

一種方法是Best Path search。計算概率最大的一條輸出序列(假設時間步獨立,那麼直接在每個時間步取概率最大的字元輸出即可),但是這樣沒有考慮多個輸出序列對應一個真實輸出這件事,舉個例子,[s,s,-]和[s,s,s]的概率比[s,t,a]低,但是它們的概率之和會高於[s,t,a]。

第二種方法是Beam Search。假設指定B=3,預測過程如下圖所示(圖源見參考資料[2])。在第一個時間步選取概率最大的三個字元,然後在第二個時間步也選取概率最大的三個字元,兩兩組合(概率相乘)可以組合成9個序列,這些序列在B轉換之後會得到一些相同輸出,把具有相同輸出的序列進行合併,比如有3個序列都可以轉換成a,把它們合併(概率加在一起),計算出概率最大的三個序列,然後繼續和下一個時間步的字元進行同樣的合併。

有一點需要注意的是合併相同字元時,比如我們看上圖T=3的時候,第一個字首序列a,在跟相同字元a合併的時候,除了產生a之外,還會產生一個aa的有效輸出。這是因為這個字首序列a在T=2的時候曾經是把空白符合並掉了,實際上這個字首序列a後面是跟著一個空白符的,所以它在跟相同字元a合併的時候中間是有一個隱藏的空白符,合併之後得到的應該是兩個a。

因此在合併相同字元時,如果要合併成aa,需要統計在這之前以空白符結尾的那些序列的概率,如果要合併成a,計算的是不以空白符結尾的那些序列的概率。出於這個事實,我們需要跟蹤前兩處輸出,以便於後續的合併計算,見下圖所示(圖源見參考資料[2])。

CTC的幾個性質

第一個是條件獨立性。CTC做了一個假設就是不同時間步的輸出之間是獨立的。這個假設對於很多序列問題來說並不成立,輸出序列之間往往存在聯絡。

第二個是單調對齊。CTC只允許單調對齊,在語音識別中可能是有效的,但是在機器翻譯中,比如目標語句中的一些比較後的詞,可能與源語句中前面的一些詞對應,這個CTC是沒法做到的。

第三個是多對一對映。CTC的輸入和輸出是多對一的關係。這意味著輸出長度不能超過輸入長度,這在手寫字型識別或者語音中不是什麼問題,因為通常輸入都會大於輸出,但是對於輸出長度大於輸入長度的問題CTC就無法處理了。

參考資料

[2] Distill上一篇關於CTC的介紹(作者Hannun Awni):Sequence Modeling With CTC