文字檢測網路Psenet學習(三)
現有的文字檢測方法主要有兩大類,一種是基於迴歸框的檢測方法(基於物體檢測的方法),如CTPN,EAST,這類方法很難檢測任意形狀的文字(曲線文字), 一種是基於畫素的分割檢測器(基於例項分割的方法),這類方法很難將彼此非常接近的文字例項分開。Psenet文字檢測方法是基於分割的方法,在2019年的論文Shape Robust Text Detection with Progressive Scale Expansion Network中提出,優化了近距離文字例項的分離。
對於Psenet的學習,主要在於四方面:網路結構的設計,kernel的生成,漸進尺度擴充套件演算法(progressive scale expansion),loss函式
1. 網路結構的設計
Psenet網路採用了resnet+fpn的架構,通過resnet提取特徵,取不同層的特徵送入fpn進行特徵融合,其結構如下圖所示:
上圖中給出了訓練過程中網路資料流,總結如下:
1. 1*3*640*640的圖片輸入網路,經過Resnet網路,將layer1,layer2,layer3,layer4的特徵圖p1(1*256*160*160), p2(1*512*80*80), p3(1*1024*40*40), p4(1*2048*20*20)送入fpn
2. 以此對應p1, p2, p3, p4, fpn網路輸出特徵c1(1*256*160*160), c2(1*256*80*80), c3(1*256*40*40), c4(1*256*20*20)
3. c2, c3, c4分別上取樣2,4,8倍後和c1進行concat得到特徵1*1024*160*160,再經過兩個卷積輸出1*7*160*160,上取樣4倍得到網路最終的輸出1*7*640*640。
4.網路最後輸出了7個640*640的預測圖(map),分別表示預測的text_predict,和6個kernel_predict
另外,上述採用resnet50的典型結構如下:
2. kernel的產生
上面網路結構中提到模型最後輸出7個640*640的預測圖, 分別是預測的text,和6個kernel,因此在訓練時也需要通過標註資料產生7個640*640的map供網路學習,即text_gt和6個kernel_gt。其中text_gt就是一張二值圖,白色部分表示img中含有文字的區域,黑色部分表示背景區域,kernel_gt就是在text_gt的基礎上,將白色區域按一定的比例縮小。如下圖所示,根據r計算出d,表示該kernel的白色區域邊緣部分相對於text_gt的白色區域向內部移動了d個畫素。
3.漸進尺度擴充套件演算法(progressive scale expansion)
在進行推理時,需要從網路輸出的6個kernel中得到需要的box,作者採用了pse(progressive scale exoansion)演算法。假設有kernel1,kernel2, kernel3, kernel4,kernel5,kernel6,先從文字區域最小的kernel6開始,遍歷其白色區域的畫素點,採用廣度優先法向四周擴充套件,依次合併kernel2,kernel3,kernel4,kernel5,kernel6, 最後合併得到一個kernel,整個合併演算法看程式碼比較好理解。取合併後kernel白色區域的矩形框或輪廓線即得到文字檢測框。論文中示意圖如下:
參考python程式碼如下:
import numpy as np import cv2 # import Queue from queue import Queue def pse(kernals, min_area): kernal_num = len(kernals) pred = np.zeros(kernals[0].shape, dtype='int32') label_num, label = cv2.connectedComponents(kernals[kernal_num - 1], connectivity=4) for label_idx in range(1, label_num): if np.sum(label == label_idx) < min_area: label[label == label_idx] = 0 queue = Queue.Queue(maxsize = 0) next_queue = Queue.Queue(maxsize = 0) points = np.array(np.where(label > 0)).transpose((1, 0)) for point_idx in range(points.shape[0]): x, y = points[point_idx, 0], points[point_idx, 1] l = label[x, y] queue.put((x, y, l)) pred[x, y] = l dx = [-1, 1, 0, 0] dy = [0, 0, -1, 1] for kernal_idx in range(kernal_num - 2, -1, -1): kernal = kernals[kernal_idx].copy() while not queue.empty(): (x, y, l) = queue.get() is_edge = True for j in range(4): tmpx = x + dx[j] tmpy = y + dy[j] if tmpx < 0 or tmpx >= kernal.shape[0] or tmpy < 0 or tmpy >= kernal.shape[1]: continue if kernal[tmpx, tmpy] == 0 or pred[tmpx, tmpy] > 0: continue queue.put((tmpx, tmpy, l)) pred[tmpx, tmpy] = l is_edge = False if is_edge: next_queue.put((x, y, l)) # kernal[pred > 0] = 0 queue, next_queue = next_queue, queue # points = np.array(np.where(pred > 0)).transpose((1, 0)) # for point_idx in range(points.shape[0]): # x, y = points[point_idx, 0], points[point_idx, 1] # l = pred[x, y] # queue.put((x, y, l)) return predpse演算法
4. loss函式理解
psenet的loss包括兩部分,gt_text和kernel的loss,都採用dice loss計算損失值。總的loss計算如公司如下,權重係數一般取λ=0.7
dice loss的計算公式如下,參見程式碼比較好理解
dice loss 參考程式碼:
def dice_loss(input, target, mask): #input為預測的map #target為標註的map input = torch.sigmoid(input) input = input.contiguous().view(input.size()[0], -1) target = target.contiguous().view(target.size()[0], -1) mask = mask.contiguous().view(mask.size()[0], -1) input = input * mask target = target * mask a = torch.sum(input * target, 1) b = torch.sum(input * input, 1) + 0.001 c = torch.sum(target * target, 1) + 0.001 d = (2 * a) / (b + c) dice_loss = torch.mean(d) return 1 - dice_lossdice loss示意程式碼
參考:
https://github.com/whai362/PSENet
https://github.com/WenmuZhou/PSENet.pytorch