【Learning Notes】CTC 原理及實現
CTC( Connectionist Temporal Classification,連線時序分類)是一種用於序列建模的工具,其核心是定義了特殊的目標函式/優化準則[1]。
jupyter notebook 版見 repo.
1. 演算法
這裡大體根據 Alex Graves 的開山之作[1],討論 CTC 的演算法原理,並基於 numpy 從零實現 CTC 的推理及訓練演算法。
1.1 序列問題形式化。
序列問題可以形式化為如下函式:
其中,序列目標為字串(詞表大小為 ),即 輸出為 維多項概率分佈(e.g. 經過 softmax 處理)。
網路輸出為:,其中, 表示時刻第 項的概率。
圖1. 序列建模【src】
雖然並沒為限定 具體形式,下面為假設其了某種神經網路(e.g. RNN)。
下面程式碼示例 toy :
import numpy as np
np.random.seed(1111)
T, V = 12, 5
m, n = 6, V
x = np.random.random([T, m]) # T x m
w = np.random.random([m, n]) # weights, m x n
def softmax(logits):
max_value = np.max(logits, axis=1, keepdims=True)
exp = np.exp(logits - max_value)
exp_sum = np.sum(exp, axis=1, keepdims=True)
dist = exp / exp_sum
return dist
def toy_nw(x):
y = np.matmul(x, w) # T x n
y = softmax(y)
return y
y = toy_nw(x)
print(y)
print(y.sum(1 , keepdims=True))
[[ 0.24654511 0.18837589 0.16937668 0.16757465 0.22812766]
[ 0.25443629 0.14992236 0.22945293 0.17240658 0.19378184]
[ 0.24134404 0.17179604 0.23572466 0.12994237 0.22119288]
[ 0.27216255 0.13054313 0.2679252 0.14184499 0.18752413]
[ 0.32558002 0.13485564 0.25228604 0.09743785 0.18984045]
[ 0.23855586 0.14800386 0.23100255 0.17158135 0.21085638]
[ 0.38534786 0.11524603 0.18220093 0.14617864 0.17102655]
[ 0.21867406 0.18511892 0.21305488 0.16472572 0.21842642]
[ 0.29856607 0.13646801 0.27196606 0.11562552 0.17737434]
[ 0.242347 0.14102063 0.21716951 0.2355229 0.16393996]
[ 0.26597326 0.10009752 0.23362892 0.24560198 0.15469832]
[ 0.23337289 0.11918746 0.28540761 0.20197928 0.16005275]]
[[ 1.]
[ 1.]
[ 1.]
[ 1.]
[ 1.]
[ 1.]
[ 1.]
[ 1.]
[ 1.]
[ 1.]
[ 1.]
[ 1.]]
1.2 align-free 變長對映
上面的形式是輸入和輸出的一對一的對映。序列學習任務一般而言是多對多的對映關係(如語音識別中,上百幀輸出可能僅對應若干音節或字元,並且每個輸入和輸出之間,也沒有清楚的對應關係)。CTC 通過引入一個特殊的 blank 字元(用 % 表示),解決多對一對映問題。
擴充套件原始詞表 為 。對輸出字串,定義操作 :1)合併連續的相同符號;2)去掉 blank 字元。
例如,對於 “aa%bb%%cc”,應用 ,則實際上代表的是字串 “abc”。同理“%a%b%cc%” 也同樣代表 “abc”。
通過引入blank 及 ,可以實現了變長的對映。
因為這個原因,CTC 只能建模輸出長度小於輸入長度的序列問題。
1.3 似然計算
和大多數有監督學習一樣,CTC 使用最大似然標準進行訓練。
給定輸入 ,輸出 的條件概率為:
其中, 表示了長度為 且示經過 結果為 字串的集合。
CTC 假設輸出的概率是(相對於輸入)條件獨立的,因此有:
然而,直接按上式我們沒有辦理有效的計算似然值。下面用動態規劃解決似然的計算及梯度計算, 涉及前向演算法和後向演算法。
1.4 前向演算法
在前向及後向計算中,CTC 需要將輸出字串進行擴充套件。具體的, 每個字元之間及首尾分別插入 blank,即擴充套件為 。下面的 為原始字串,