1. 程式人生 > 實用技巧 >文字檢測網路Psenet學習(三)

文字檢測網路Psenet學習(三)

  現有的文字檢測方法主要有兩大類,一種是基於迴歸框的檢測方法(基於物體檢測的方法),如CTPN,EAST,這類方法很難檢測任意形狀的文字(曲線文字), 一種是基於畫素的分割檢測器(基於例項分割的方法),這類方法很難將彼此非常接近的文字例項分開。Psenet文字檢測方法是基於分割的方法,在2019年的論文Shape Robust Text Detection with Progressive Scale Expansion Network中提出,優化了近距離文字例項的分離。

  對於Psenet的學習,主要在於四方面:網路結構的設計,kernel的生成,漸進尺度擴充套件演算法(progressive scale expansion),loss函式

1. 網路結構的設計

  Psenet網路採用了resnet+fpn的架構,通過resnet提取特徵,取不同層的特徵送入fpn進行特徵融合,其結構如下圖所示:

  上圖中給出了訓練過程中網路資料流,總結如下:

  1. 1*3*640*640的圖片輸入網路,經過Resnet網路,將layer1,layer2,layer3,layer4的特徵圖p1(1*256*160*160), p2(1*512*80*80), p3(1*1024*40*40), p4(1*2048*20*20)送入fpn

  2. 以此對應p1, p2, p3, p4, fpn網路輸出特徵c1(1*256*160*160), c2(1*256*80*80), c3(1*256*40*40), c4(1*256*20*20)

  3. c2, c3, c4分別上取樣2,4,8倍後和c1進行concat得到特徵1*1024*160*160,再經過兩個卷積輸出1*7*160*160,上取樣4倍得到網路最終的輸出1*7*640*640。

  4.網路最後輸出了7個640*640的預測圖(map),分別表示預測的text_predict,和6個kernel_predict

  另外,上述採用resnet50的典型結構如下:

  

2. kernel的產生

  上面網路結構中提到模型最後輸出7個640*640的預測圖, 分別是預測的text,和6個kernel,因此在訓練時也需要通過標註資料產生7個640*640的map供網路學習,即text_gt和6個kernel_gt。其中text_gt就是一張二值圖,白色部分表示img中含有文字的區域,黑色部分表示背景區域,kernel_gt就是在text_gt的基礎上,將白色區域按一定的比例縮小。如下圖所示,根據r計算出d,表示該kernel的白色區域邊緣部分相對於text_gt的白色區域向內部移動了d個畫素。

3.漸進尺度擴充套件演算法(progressive scale expansion)

  在進行推理時,需要從網路輸出的6個kernel中得到需要的box,作者採用了pse(progressive scale exoansion)演算法。假設有kernel1,kernel2, kernel3, kernel4,kernel5,kernel6,先從文字區域最小的kernel6開始,遍歷其白色區域的畫素點,採用廣度優先法向四周擴充套件,依次合併kernel2,kernel3,kernel4,kernel5,kernel6, 最後合併得到一個kernel,整個合併演算法看程式碼比較好理解。取合併後kernel白色區域的矩形框或輪廓線即得到文字檢測框。論文中示意圖如下:

  參考python程式碼如下:

import numpy as np
import cv2
# import Queue
from queue import Queue

def pse(kernals, min_area):
    kernal_num = len(kernals)
    pred = np.zeros(kernals[0].shape, dtype='int32')
    
    label_num, label = cv2.connectedComponents(kernals[kernal_num - 1], connectivity=4)
    
    for label_idx in range(1, label_num):
        if np.sum(label == label_idx) < min_area:
            label[label == label_idx] = 0

    queue = Queue.Queue(maxsize = 0)
    next_queue = Queue.Queue(maxsize = 0)
    points = np.array(np.where(label > 0)).transpose((1, 0))
    
    for point_idx in range(points.shape[0]):
        x, y = points[point_idx, 0], points[point_idx, 1]
        l = label[x, y]
        queue.put((x, y, l))
        pred[x, y] = l

    dx = [-1, 1, 0, 0]
    dy = [0, 0, -1, 1]
    for kernal_idx in range(kernal_num - 2, -1, -1):
        kernal = kernals[kernal_idx].copy()
        while not queue.empty():
            (x, y, l) = queue.get()

            is_edge = True
            for j in range(4):
                tmpx = x + dx[j]
                tmpy = y + dy[j]
                if tmpx < 0 or tmpx >= kernal.shape[0] or tmpy < 0 or tmpy >= kernal.shape[1]:
                    continue
                if kernal[tmpx, tmpy] == 0 or pred[tmpx, tmpy] > 0:
                    continue

                queue.put((tmpx, tmpy, l))
                pred[tmpx, tmpy] = l
                is_edge = False
            if is_edge:
                next_queue.put((x, y, l))
        
        # kernal[pred > 0] = 0
        queue, next_queue = next_queue, queue
        
        # points = np.array(np.where(pred > 0)).transpose((1, 0))
        # for point_idx in range(points.shape[0]):
        #     x, y = points[point_idx, 0], points[point_idx, 1]
        #     l = pred[x, y]
        #     queue.put((x, y, l))

    return pred
pse演算法

4. loss函式理解

  psenet的loss包括兩部分,gt_text和kernel的loss,都採用dice loss計算損失值。總的loss計算如公司如下,權重係數一般取λ=0.7

  dice loss的計算公式如下,參見程式碼比較好理解

  dice loss 參考程式碼:

def dice_loss(input, target, mask):
    #input為預測的map
    #target為標註的map
    input = torch.sigmoid(input)

    input = input.contiguous().view(input.size()[0], -1)
    target = target.contiguous().view(target.size()[0], -1)
    mask = mask.contiguous().view(mask.size()[0], -1)

    input = input * mask
    target = target * mask

    a = torch.sum(input * target, 1)
    b = torch.sum(input * input, 1) + 0.001
    c = torch.sum(target * target, 1) + 0.001
    d = (2 * a) / (b + c)
    dice_loss = torch.mean(d)
    return 1 - dice_loss
dice loss示意程式碼

參考:

  https://github.com/whai362/PSENet

  https://github.com/WenmuZhou/PSENet.pytorch