1. 程式人生 > >【OCR技術系列之三】大批量生成文字訓練集

【OCR技術系列之三】大批量生成文字訓練集

9.png false per store else value 隨機 %d alt

放假了,終於可以繼續可以靜下心寫一寫OCR方面的東西。上次談到文字的切割,今天打算總結一下我們怎麽得到用於訓練的文字數據集。如果是想訓練一個手寫體識別的模型,用一些前人收集好的手寫文字集就好了,比如中科院的這些數據集。但是如果我們只是想要訓練一個專門用於識別印刷漢字的模型,那麽我們就需要各種印刷字體的訓練集,那怎麽獲取呢?借助強大的圖像庫,自己生成就行了!

先捋一捋思路,生成文字集需要什麽步驟:

  1. 確定你要生成多少字體,生成一個記錄著漢字與label的對應表。
  2. 確定和收集需要用到的字體文件。
  3. 生成字體圖像,存儲在規定的目錄下。
  4. 適當的數據增強。

第三步的生成字體圖像最為重要,如果僅僅是生成很正規的文字,那麽用這個正規文字集去訓練模型,第一圖像數目有點少,第二模型泛化能力比較差,所以我們需要對字體圖像做大量的圖像處理工作,以增大我們的印刷體文字數據集。

我總結了一下,我們可以做的一些圖像增強工作有這些:

  1. 文字扭曲
  2. 背景噪聲(椒鹽)
  3. 文字位置(設置文字的中心點)
  4. 筆畫粘連(膨脹來模擬)
  5. 筆畫斷裂(腐蝕來模擬)
  6. 文字傾斜(文字旋轉)
  7. 多種字體

做完以上增強後,我們得到的數據集已經非常龐大了。

現在開始一步一步生成我們的3755個漢字的印刷體文字數據集。

一、生成漢字與label的對應表

這裏的漢字、label映射表的生成我使用了pickel模塊,借助它生成一個id:漢字的映射文件存儲下來。
這裏舉個小例子說明怎麽生成這個“漢字:id”映射表。

首先在一個txt文件裏寫入你想要的漢字,如果對漢字對應的ID沒有要求的話,我們不妨使用該漢字的排位作為其ID,比如“一二三四五”中,五的ID就是00005。如此類推,把漢字讀入內存,建立一個字典,把這個關系記錄下來,再使用pickle.dump存入文件保存。

二、收集字體文件

字體文件上網收集就好了,但是值得註意的是,不是每一種字體都支持漢字,所以我們需要篩選出真正適合漢字生成的字體文件才可以。我一共使用了十三種漢字字體作為我們接下來漢字數據集用到的字體,具體如下圖:

技術分享圖片

當然,如果需要進一步擴大數據集來增強訓練得到的模型的泛化能力,可以花更多的時間去收集各類漢字字體,那麽模型在面對各種字體時也能從容應對,給出準確的預測。

三、文字圖像生成

首先是定義好輸入參數,其中包括輸出目錄、字體目錄、測試集大小、圖像尺寸、圖像旋轉幅度等等。

def args_parse():
    #解析輸入參數
    parser = argparse.ArgumentParser(
        description=description, formatter_class=RawTextHelpFormatter)
    parser.add_argument(‘--out_dir‘, dest=‘out_dir‘,
                        default=None, required=True,
                        help=‘write a caffe dir‘)
    parser.add_argument(‘--font_dir‘, dest=‘font_dir‘,
                        default=None, required=True,
                        help=‘font dir to to produce images‘)
    parser.add_argument(‘--test_ratio‘, dest=‘test_ratio‘,
                        default=0.2, required=False,
                        help=‘test dataset size‘)
    parser.add_argument(‘--width‘, dest=‘width‘,
                        default=None, required=True,
                        help=‘width‘)
    parser.add_argument(‘--height‘, dest=‘height‘,
                        default=None, required=True,
                        help=‘height‘)
    parser.add_argument(‘--no_crop‘, dest=‘no_crop‘,
                        default=True, required=False,
                        help=‘‘, action=‘store_true‘)
    parser.add_argument(‘--margin‘, dest=‘margin‘,
                        default=0, required=False,
                        help=‘‘, )
    parser.add_argument(‘--rotate‘, dest=‘rotate‘,
                        default=0, required=False,
                        help=‘max rotate degree 0-45‘)
    parser.add_argument(‘--rotate_step‘, dest=‘rotate_step‘,
                        default=0, required=False,
                        help=‘rotate step for the rotate angle‘)
    parser.add_argument(‘--need_aug‘, dest=‘need_aug‘,
                        default=False, required=False,
                        help=‘need data augmentation‘, action=‘store_true‘)   
    args = vars(parser.parse_args()) 
    return args

接下來需要將我們第一步得到的對應表讀入內存,因為這個表示ID到漢字的映射,我們在做一下轉換,改成漢字到ID的映射,用於後面的字體生成。

#將漢字的label讀入,得到(ID:漢字)的映射表label_dict
label_dict = get_label_dict()

char_list=[]  # 漢字列表
value_list=[] # label列表
for (value,chars) in label_dict.items():
    print (value,chars)
    char_list.append(chars)
    value_list.append(value)

# 合並成新的映射關系表:(漢字:ID)
lang_chars = dict(zip(char_list,value_list)) 
font_check = FontCheck(lang_chars) 

我們對旋轉的角度存儲到列表中,旋轉角度的範圍是[-rotate,rotate].

if rotate < 0:
    roate = - rotate

if rotate > 0 and rotate <= 45:
    all_rotate_angles = []
    for i in range(0, rotate+1, rotate_step):  
        all_rotate_angles.append(i)
    for i in range(-rotate, 0, rotate_step):
        all_rotate_angles.append(i)
    #print(all_rotate_angles)

現在說一下字體圖像是怎麽生成的,首先我們使用的工具是PIL。PIL裏面有很好用的漢字生成函數,我們用這個函數再結合我們提供的字體文件,就可以生成我們想要的數字化的漢字了。我們先設定好我們生成的字體顏色為黑底白色,字體尺寸由輸入參數來動態設定。

技術分享圖片

# 生成字體圖像
class Font2Image(object):

    def __init__(self,
                 width, height,
                 need_crop, margin):
        self.width = width
        self.height = height
        self.need_crop = need_crop
        self.margin = margin

    def do(self, font_path, char, rotate=0):
        find_image_bbox = FindImageBBox()
        # 黑色背景
        img = Image.new("RGB", (self.width, self.height), "black")
        draw = ImageDraw.Draw(img)
        font = ImageFont.truetype(font_path, int(self.width * 0.7),)
        # 白色字體
        draw.text((0, 0), char, (255, 255, 255),
                  font=font)
        if rotate != 0:
            img = img.rotate(rotate)
        data = list(img.getdata())
        sum_val = 0
        for i_data in data:
            sum_val += sum(i_data)
        if sum_val > 2:
            np_img = np.asarray(data, dtype=‘uint8‘)
            np_img = np_img[:, 0]
            np_img = np_img.reshape((self.height, self.width))
            cropped_box = find_image_bbox.do(np_img)
            left, upper, right, lower = cropped_box
            np_img = np_img[upper: lower + 1, left: right + 1]
            if not self.need_crop:
                preprocess_resize_keep_ratio_fill_bg =                     PreprocessResizeKeepRatioFillBG(self.width, self.height,
                                                    fill_bg=False,
                                                    margin=self.margin)
                np_img = preprocess_resize_keep_ratio_fill_bg.do(
                    np_img)
            # cv2.imwrite(path_img, np_img)
            return np_img
        else:
            print("img doesn‘t exist.")

我們寫兩個循環,外層循環是漢字列表,內層循環是字體列表,對於每個漢字會得到一個image_list列表,裏面存儲著這個漢字的所有圖像。

for (char, value) in lang_chars.items():  # 外層循環是字
    image_list = []
    print (char,value)
    #char_dir = os.path.join(images_dir, "%0.5d" % value)
    for j, verified_font_path in enumerate(verified_font_paths):    # 內層循環是字體   
        if rotate == 0:
            image = font2image.do(verified_font_path, char)
            image_list.append(image)
        else:
            for k in all_rotate_angles: 
                image = font2image.do(verified_font_path, char, rotate=k)
                image_list.append(image)

我們將image_list中圖像按照比例分為訓練集和測試集存儲。

        test_num = len(image_list) * test_ratio
        random.shuffle(image_list)  # 圖像列表打亂
        count = 0
        for i in range(len(image_list)):
            img = image_list[i]
            #print(img.shape)
            if count < test_num :
                char_dir = os.path.join(test_images_dir, "%0.5d" % value)
            else:
                char_dir = os.path.join(train_images_dir, "%0.5d" % value)

            if not os.path.isdir(char_dir):
                os.makedirs(char_dir)

            path_image = os.path.join(char_dir,"%d.png" % count)
            cv2.imwrite(path_image,img)
            count += 1

寫好代碼後,我們執行如下指令,開始生成印刷體文字漢字集。

 python gen_printed_char.py --out_dir ./dataset --font_dir ./chinese_fonts --width 30 --height 30 --margin 4 --rotate 30 --rotate_step 1

解析一下上述指令的附屬參數:

  1. --out_dir 表示生成的漢字圖像的存儲目錄
  2. --font_dir 表示放置漢字字體文件的路徑
  3. --width --height 表示生成圖像的高度和寬度
  4. --margin 表示字體與邊緣的間隔
  5. --rotate 表示字體旋轉的範圍,[-rotate,rotate]
  6. --rotate_step 表示每次旋轉的間隔

生成這麽一個3755個漢字的數據集的所需的時間還是很久的,估計接近一個小時。其實這個生成過程可以用多線程、多進程並行加速,但是考慮到這種文字數據集只需生成一次就好,所以就沒做這方面的優化了。數據集生成完我們可以發現,在dataset文件夾下得到train和test兩個文件夾,train和test文件夾下都有3755個子文件夾,分別存儲著生成的3755個漢字對應的圖像,每個子文件的名字就是該漢字對應的id。隨便選擇一個train文件夾下的一個子文件夾打開,可以看到所獲得的漢字圖像,一共634個。

dataset下自動生成測試集和訓練集

技術分享圖片

測試集和訓練集下都有3755個子文件夾,用於存儲每個漢字的圖像。

技術分享圖片

生成出來的漢字圖像

技術分享圖片

額外的圖像增強

第三步生成的漢字圖像是最基本的數據集,它所做的圖像處理僅有旋轉這麽一項,如果我們想在數據增強上再做多點東西,想必我們最終訓練出來的OCR模型的性能會更加優秀。我們使用opencv來完成我們定制的漢字圖像增強任務。

因為生成的圖像比較小,僅僅是30*30,如果對這麽小的圖像加噪聲或者形態學處理,得到的字體圖像會很糟糕,所以我們在做數據增強時,把圖片尺寸適當增加,比如設置為100×100,再進行相應的數據增強,效果會更好。

噪點增加

def add_noise(cls,img):
    for i in range(20): #添加點噪聲
        temp_x = np.random.randint(0,img.shape[0])
        temp_y = np.random.randint(0,img.shape[1])
        img[temp_x][temp_y] = 255
    return img

適當腐蝕

def add_erode(cls,img):
    kernel = cv2.getStructuringElement(cv2.MORPH_RECT,(3, 3))    
    img = cv2.erode(img,kernel) 
    return img

適當膨脹

def add_dilate(cls,img):
    kernel = cv2.getStructuringElement(cv2.MORPH_RECT,(3, 3))    
    img = cv2.dilate(img,kernel) 
    return img

然後做隨機擾動

def do(self,img_list=[]):
    aug_list= copy.deepcopy(img_list)
    for i in range(len(img_list)):
        im = img_list[i]
        if self.noise and random.random()<0.5:
            im = self.add_noise(im)
        if self.dilate and random.random()<0.25:
            im = self.add_dilate(im)
        if self.erode and random.random()<0.25:
            im = self.add_erode(im)    
        aug_list.append(im)
    return aug_list

輸入指令

python gen_printed_char.py --out_dir ./dataset2 --font_dir ./chinese_fonts --width 100 --height 100 --margin 10 --rotate 30 --rotate_step 1 --need_aug

使用這種生成的圖像如下圖所示,第一數據集擴大了兩倍,第二圖像的豐富性進一步提高,效果還是明顯的。當然,如果要獲得最好的效果,還需要調一下裏面的參數,這裏就不再詳細說明了。

技術分享圖片

至此,我們所需的印刷體漢字數據集已經成功生成完畢,下一步要做的就是利用這些數據集設計一個卷積神經網絡做文字識別了!完整的代碼可以在我的github獲取。

【OCR技術系列之三】大批量生成文字訓練集