1. 程式人生 > >【Learning Notes】CTC 原理及實現

【Learning Notes】CTC 原理及實現

CTC( Connectionist Temporal Classification,連線時序分類)是一種用於序列建模的工具,其核心是定義了特殊的目標函式/優化準則[1]。

jupyter notebook 版見 repo.

1. 演算法

這裡大體根據 Alex Graves 的開山之作[1],討論 CTC 的演算法原理,並基於 numpy 從零實現 CTC 的推理及訓練演算法。

1.1 序列問題形式化。

序列問題可以形式化為如下函式:

Nw:(Rm)T(Rn)T
其中,序列目標為字串(詞表大小為 n),即 Nw 輸出為 n
維多項概率分佈(e.g. 經過 softmax 處理)。

網路輸出為:y=Nw,其中,ykt t 表示時刻第 k 項的概率。


圖1. 序列建模【src

雖然並沒為限定 Nw 具體形式,下面為假設其了某種神經網路(e.g. RNN)。
下面程式碼示例 toy Nw

import numpy as np

np.random.seed(1111)

T, V = 12, 5
m, n = 6, V

x = np.random.random([T, m])  # T x m
w = np.random.random([m, n])  # weights, m x n
def softmax(logits): max_value = np.max(logits, axis=1, keepdims=True) exp = np.exp(logits - max_value) exp_sum = np.sum(exp, axis=1, keepdims=True) dist = exp / exp_sum return dist def toy_nw(x): y = np.matmul(x, w) # T x n y = softmax(y) return y y = toy_nw(x) print(y) print(y.sum(1
, keepdims=True))
[[ 0.24654511  0.18837589  0.16937668  0.16757465  0.22812766]
 [ 0.25443629  0.14992236  0.22945293  0.17240658  0.19378184]
 [ 0.24134404  0.17179604  0.23572466  0.12994237  0.22119288]
 [ 0.27216255  0.13054313  0.2679252   0.14184499  0.18752413]
 [ 0.32558002  0.13485564  0.25228604  0.09743785  0.18984045]
 [ 0.23855586  0.14800386  0.23100255  0.17158135  0.21085638]
 [ 0.38534786  0.11524603  0.18220093  0.14617864  0.17102655]
 [ 0.21867406  0.18511892  0.21305488  0.16472572  0.21842642]
 [ 0.29856607  0.13646801  0.27196606  0.11562552  0.17737434]
 [ 0.242347    0.14102063  0.21716951  0.2355229   0.16393996]
 [ 0.26597326  0.10009752  0.23362892  0.24560198  0.15469832]
 [ 0.23337289  0.11918746  0.28540761  0.20197928  0.16005275]]
[[ 1.]
 [ 1.]
 [ 1.]
 [ 1.]
 [ 1.]
 [ 1.]
 [ 1.]
 [ 1.]
 [ 1.]
 [ 1.]
 [ 1.]
 [ 1.]]

1.2 align-free 變長對映

上面的形式是輸入和輸出的一對一的對映。序列學習任務一般而言是多對多的對映關係(如語音識別中,上百幀輸出可能僅對應若干音節或字元,並且每個輸入和輸出之間,也沒有清楚的對應關係)。CTC 通過引入一個特殊的 blank 字元(用 % 表示),解決多對一對映問題。

擴充套件原始詞表 LL=L{blank}。對輸出字串,定義操作 B:1)合併連續的相同符號;2)去掉 blank 字元。

例如,對於 “aa%bb%%cc”,應用 B,則實際上代表的是字串 “abc”。同理“%a%b%cc%” 也同樣代表 “abc”。

B(aa%bb%%cc)=B(%a%b%cc%)=abc

通過引入blank 及 B,可以實現了變長的對映。

LTLT
因為這個原因,CTC 只能建模輸出長度小於輸入長度的序列問題。

1.3 似然計算

和大多數有監督學習一樣,CTC 使用最大似然標準進行訓練。

給定輸入 x,輸出 l 的條件概率為:

p(l|x)=πB1(l)p(π|x)

其中,B1(l) 表示了長度為 T 且示經過 B 結果為 l 字串的集合。

CTC 假設輸出的概率是(相對於輸入)條件獨立的,因此有:

p(π|x)=yπtt,πLT

然而,直接按上式我們沒有辦理有效的計算似然值。下面用動態規劃解決似然的計算及梯度計算, 涉及前向演算法和後向演算法。

1.4 前向演算法

在前向及後向計算中,CTC 需要將輸出字串進行擴充套件。具體的,(a1,,am) 每個字元之間及首尾分別插入 blank,即擴充套件為 (%,a1,%,a2,%,,%,am,%)。下面的 l 為原始字串,