1. 程式人生 > >lstm+ctc 實現ocr識別

lstm+ctc 實現ocr識別

轉載地址:

https://zhuanlan.zhihu.com/p/21344595

OCR是一個古老的研究領域,簡單說就是把圖片上的文字轉化為文字的過程。在最近幾年隨著大資料的發展,廣大爬蟲工程師在對抗驗證碼時也得用上OCR。所以,這篇文章主要說的OCR其實就是圖片驗證碼的識別。OCR並不是我的研究方向,我研究這個問題是因為OCR是一個可以同時用CNN,RNN兩種演算法都可以很好解決的問題,所以用這個問題來熟悉一個深度學習框架是非常適合的。我主要通過研究這個問題來了解mxnet

驗證碼識別的思路非常暴力,大概就是這樣:

  1. 去噪+二值化
  2. 字元分割
  3. 每個字元識別

驗證碼的難度在這3步上都有反應。比如

  1. 噪聲:加一條貫穿全圖的曲線,比如網格線,還有圖的一半是白底黑字,另一半是黑底白字。
  2. 分割:字元粘連,7和4粘在一起。
  3. 識別:字元各種扭曲,各種旋轉。

但相對而言,難度最大的是第2步,分割。所以就有人想,我能不能不做分割,就把驗證碼給識別了。深度學習擅長做端到端的學習,因此這個不分割就想識別的事情交給深度學習是最合適的。

基於CNN的驗證碼識別

基於CNN去識別驗證碼,其實就是一個圖片的多標籤學習問題。比如考慮一個4個數字組成的驗證碼,那麼相當於每張圖就有4個標籤。那麼我們把原始圖片作為輸入,4個標籤作為輸出,扔進CNN裡,看看能不能收斂就行了。

下面這段程式碼定義了mxnet上的一個DataIter,我們用了python-captcha這個庫來自動生成訓練樣本,所以可以假設訓練樣本是無窮多的。

class OCRIter(mx.io.DataIter):
def __init__(self, count, batch_size, num_label, height, width):
    super(OCRIter, self).__init__()
    self.captcha = ImageCaptcha(fonts=['./data/OpenSans-Regular.ttf'])
    self.batch_size = batch_size
    self.count = count
    self.height = height
    self.width =
width self.provide_data = [('data', (batch_size, 3, height, width))] self.provide_label = [('softmax_label', (self.batch_size, num_label))] def __iter__(self): for k in range(self.count / self.batch_size): data = [] label = [] for i in range(self.batch_size): # 生成一個四位數字的隨機字串 num = gen_rand() # 生成隨機字串對應的驗證碼圖片 img = self.captcha.generate(num) img = np.fromstring(img.getvalue(), dtype='uint8') img = cv2.imdecode(img, cv2.IMREAD_COLOR) img = cv2.resize(img, (self.width, self.height)) cv2.imwrite("./tmp" + str(i % 10) + ".png", img) img = np.multiply(img, 1/255.0) img = img.transpose(2, 0, 1) data.append(img) label.append(get_label(num)) data_all = [mx.nd.array(data)] label_all = [mx.nd.array(label)] data_names = ['data'] label_names = ['softmax_label'] data_batch = OCRBatch(data_names, data_all, label_names, label_all) yield data_batch def reset(self): pass

下面這段程式碼是網路結構:

def get_ocrnet():
    data = mx.symbol.Variable('data')
    label = mx.symbol.Variable('softmax_label')
    conv1 = mx.symbol.Convolution(data=data, kernel=(5,5), num_filter=32)
    pool1 = mx.symbol.Pooling(data=conv1, pool_type="max", kernel=(2,2), stride=(1, 1))
    relu1 = mx.symbol.Activation(data=pool1, act_type="relu")

    conv2 = mx.symbol.Convolution(data=relu1, kernel=(5,5), num_filter=32)
    pool2 = mx.symbol.Pooling(data=conv2, pool_type="avg", kernel=(2,2), stride=(1, 1))
    relu2 = mx.symbol.Activation(data=pool2, act_type="relu")

    conv3 = mx.symbol.Convolution(data=relu2, kernel=(3,3), num_filter=32)
    pool3 = mx.symbol.Pooling(data=conv3, pool_type="avg", kernel=(2,2), stride=(1, 1))
    relu3 = mx.symbol.Activation(data=pool3, act_type="relu")

    flatten = mx.symbol.Flatten(data = relu3)
    fc1 = mx.symbol.FullyConnected(data = flatten, num_hidden = 512)
    fc21 = mx.symbol.FullyConnected(data = fc1, num_hidden = 10)
    fc22 = mx.symbol.FullyConnected(data = fc1, num_hidden = 10)
    fc23 = mx.symbol.FullyConnected(data = fc1, num_hidden = 10)
    fc24 = mx.symbol.FullyConnected(data = fc1, num_hidden = 10)
    fc2 = mx.symbol.Concat(*[fc21, fc22, fc23, fc24], dim = 0)
    label = mx.symbol.transpose(data = label)
    label = mx.symbol.Reshape(data = label, target_shape = (0, ))
    return mx.symbol.SoftmaxOutput(data = fc2, label = label, name = "softmax")

上面這個網路要稍微解釋一下。因為這個問題是一個有順序的多label的圖片分類問題。我們在fc1的層上面接了4個Full Connect層(fc21,fc22,fc23,fc24),用來對應不同位置的4個數字label。然後將它們Concat在一起。然後同時學習這4個label。目前用上面的網路訓練,4位數字全部預測正確的精度可以達到95%左右(因為是無窮多的訓練樣本,所以只要能不斷訓練下去,精度還是可以提高的,只是我訓練到95%左右就停止訓練了)。

用CNN解決驗證碼識別有個問題,就是必須針對固定長度的驗證碼去做。如果長度不固定,或者是手寫一行字的識別這種長度肯定不固定的問題,CNN就沒辦法了。這個時候就需要引入序列學習的模型了。

基於LSTM+CTC的驗證碼識別

LSTM+CTC被廣泛的用在語音識別領域把音訊解碼成漢字,從這個角度說,OCR其實就是把圖片解碼成漢字,並沒有太本質的區別。而且在整個過程中,不需要提前知道究竟要解碼成幾個字。

這個演算法的思路是這樣的。假設要識別的圖片是80x30的圖片,裡面是一個長度為k的數字驗證碼。那麼我們可以沿著x軸對圖片進行切分,切成n個圖片,作為LSTM的n個輸入。在最極端的例子裡,n=80。那麼就是把圖片的每一列都作為輸入。LSTM有n個輸入就會有n個輸出,而這n個輸出可以通過CTC計算和k個驗證碼標籤之間的Loss,然後進行反向傳播。

我們同樣用python-captcha自動生成驗證碼作為訓練樣本,用如下的程式碼來定義網路結構:
def lstm_unroll(num_lstm_layer, seq_len,
                num_hidden, num_label):
    param_cells = []
    last_states = []
    for i in range(num_lstm_layer):
        state = LSTMState(c=mx.sym.Variable("l%d_init_c" % i),
                          h=mx.sym.Variable("l%d_init_h" % i))
        last_states.append(state)
    assert(len(last_states) == num_lstm_layer)

    # embeding layer
    data = mx.sym.Variable('data')
    label = mx.sym.Variable('label')
    wordvec = mx.sym.SliceChannel(data=data, num_outputs=seq_len, squeeze_axis=1)

    hidden_all = []
    for seqidx in range(seq_len):
        hidden = wordvec[seqidx]
        for i in range(num_lstm_layer):
            next_state = lstm(num_hidden, indata=hidden,
                              prev_state=last_states[i],
                              param=param_cells[i],
                              seqidx=seqidx, layeridx=i)
            hidden = next_state.h
            last_states[i] = next_state
        hidden_all.append(hidden)

    hidden_concat = mx.sym.Concat(*hidden_all, dim=0)
    pred = mx.sym.FullyConnected(data=hidden_concat, num_hidden=11)

    label = mx.sym.Reshape(data=label, target_shape=(0,))
    label = mx.sym.Cast(data = label, dtype = 'int32')
    sm = mx.sym.WarpCTC(data=pred, label=label, label_length = num_label, input_length = seq_len)
    return sm

這裡有2點需要注意的:

  1. 在一般的mxnet的lstm實現中,label需要轉置,但是在warpctc的實現中不需要。
  2. label需要是int32的格式,需要cast。

關於CTC Loss的重要性,我試過不用CTC的兩個不同想法:

  1. 用encode-decode模式。用80個輸入做encode,然後decode成4個輸出。實測效果很差。
  2. 4個label每個copy20遍,從而變成80個label。實測也很差。

用ctc loss的體會就是,如果input的長度遠遠大於label的長度,比如我這裡是80和4的關係。那麼一開始的收斂會比較慢。在其中有一段時間cost幾乎不變。此刻一定要有耐心,最終一定會收斂的。在ocr識別的這個例子上最終可以收斂到95%的精度。