1. 程式人生 > >Python建立CRNN訓練用的LMDB資料庫檔案

Python建立CRNN訓練用的LMDB資料庫檔案

CRNN簡介

CRNN由 Baoguang Shi, Xiang Bai, Cong Yao提出,2015年7月發表論文:“An End-to-End Trainable Neural Network for Image-based Sequence Recognition and Its Application to Scene Text Recognition”,連結地址:https://arxiv.org/abs/1507.05717v1


CRNN(卷積迴圈神經網路)集成了卷積神經網路(CNN)和迴圈神經網路(RNN)的優點。CRNN可以直接從序列標籤(例如單詞,句子)中學習,不需要詳細的單個分別標註,並且對影象序列物件的長度無限定,只需要在訓練和測試階段對影象高度做一下歸一化。於現有技術相比,CRNN在場景文字識別上表現良好。

CRNN中訓練資料的格式是LMDB,儲存了兩種資料,一種是圖片資料,一種是標籤資料,它們各有其key,如下所示:



準備CRNN訓練資料集

資料集圖片是若干帶有文字的圖片,文字的高度約佔圖片高度的80%~90%,資料集標籤是txt文字格式,文字內容是圖片上的文字,文字名字要跟圖片名字一致,如123.jpg對應標籤需要是123.txt。

例如有 01.jpg 和 02.jpg 兩個樣本,標籤檔案是 01.txt 和 02.txt :




建立用於CRNN訓練的LMDB資料

# -*- coding: utf-8 -*-
import os
import lmdb # install lmdb by "pip install lmdb"
import cv2
import numpy as np
#from genLineText import GenTextImage

def checkImageIsValid(imageBin):
    if imageBin is None:
        return False
    imageBuf = np.fromstring(imageBin, dtype=np.uint8)
    img = cv2.imdecode(imageBuf, cv2.IMREAD_GRAYSCALE)
    if img is None:
        return False
    imgH, imgW = img.shape[0], img.shape[1]
    if imgH * imgW == 0:
        return False
    return True


def writeCache(env, cache):
    with env.begin(write=True) as txn:
        for k, v in cache.iteritems():
            txn.put(k, v)


def createDataset(outputPath, imagePathList, labelList, lexiconList=None, checkValid=True):
    """
    Create LMDB dataset for CRNN training.

    ARGS:
        outputPath    : LMDB output path
        imagePathList : list of image path
        labelList     : list of corresponding groundtruth texts
        lexiconList   : (optional) list of lexicon lists
        checkValid    : if true, check the validity of every image
    """
    #print (len(imagePathList) , len(labelList))
    assert(len(imagePathList) == len(labelList))
    nSamples = len(imagePathList)
    print '...................'
    # map_size=1099511627776 定義最大空間是1TB
    env = lmdb.open(outputPath, map_size=1099511627776)
    
    cache = {}
    cnt = 1
    for i in xrange(nSamples):
        imagePath = imagePathList[i]
        label = labelList[i]
        if not os.path.exists(imagePath):
            print('%s does not exist' % imagePath)
            continue
        with open(imagePath, 'r') as f:
            imageBin = f.read()
        if checkValid:
            if not checkImageIsValid(imageBin):
                print('%s is not a valid image' % imagePath)
                continue


        ########## .mdb資料庫檔案儲存了兩種資料,一種是圖片資料,一種是標籤資料,它們各有其key
        imageKey = 'image-%09d' % cnt
        labelKey = 'label-%09d' % cnt
        cache[imageKey] = imageBin
        cache[labelKey] = label
        ##########
        if lexiconList:
            lexiconKey = 'lexicon-%09d' % cnt
            cache[lexiconKey] = ' '.join(lexiconList[i])
        if cnt % 1000 == 0:
            writeCache(env, cache)
            cache = {}
            print('Written %d / %d' % (cnt, nSamples))
        cnt += 1
    nSamples = cnt-1
    cache['num-samples'] = str(nSamples)
    writeCache(env, cache)
    print('Created dataset with %d samples' % nSamples)


def read_text(path):
    
    with open(path) as f:
        text = f.read()
    text = text.strip()
    
    return text


import glob
if __name__ == '__main__':
    
    #lmdb 輸出目錄
    outputPath = '../data/lmdb/trainMy'

    # 訓練圖片路徑,標籤是txt格式,名字跟圖片名字要一致,如123.jpg對應標籤需要是123.txt
    path = '../data/dataline/*.jpg'

    imagePathList = glob.glob(path)
    print '------------',len(imagePathList),'------------'
    imgLabelLists = []
    for p in imagePathList:
        try:
           imgLabelLists.append((p,read_text(p.replace('.jpg','.txt'))))
        except:
            continue
            
    #imgLabelList = [ (p,read_text(p.replace('.jpg','.txt'))) for p in imagePathList]
    ##sort by lebelList
    imgLabelList = sorted(imgLabelLists,key = lambda x:len(x[1]))
    imgPaths = [ p[0] for p in imgLabelList]
    txtLists = [ p[1] for p in imgLabelList]
    
    createDataset(outputPath, imgPaths, txtLists, lexiconList=None, checkValid=True)

讀取LMDB資料集中圖片

# -*- coding: utf-8 -*-
import numpy as np
import lmdb
import cv2

with lmdb.open("../data/lmdb/train") as env:
    txn = env.begin()
    for key, value in txn.cursor():
        print (key,value)
        imageBuf = np.fromstring(value, dtype=np.uint8)
        img = cv2.imdecode(imageBuf, cv2.IMREAD_GRAYSCALE)
        if img is not None:
            cv2.imshow('image', img)
            cv2.waitKey()
        else:
            print 'This is a label: {}'.format(value)