1. 程式人生 > >移動端unet人像分割模型--3

移動端unet人像分割模型--3

  前兩篇文章已經完成基本從mxnet到ncnn的unet模型訓練和轉換。不過還存在幾個問題,1. 模型比較大,2. 單幀處理需要15秒左右的時間(MAC PRO,ncnn沒有使用openmp的情況),3. 得到的mask結果不是特別理想。針對這三個問題,本文將對網路結構進行調整。

  1. 模型比較大

  採取將網路卷積核數量減少4倍的方式,模型大小下降到2M,粗略用圖片測試,效果也還可以。為了提高準確率,採取將樣本翻轉、crop、旋轉等方式進行擴充。同時把之前用0值填充圖片的方式,改成用邊界值填充,因為測試的時候發現之前的方式總在填充的邊界往往會出現檢測錯誤。原先還做過試驗,如果2M還不夠小,可以把U型下降段改成mobilenet的方式進一步壓縮模型大小。

#!/usr/bin/env python
# coding=utf8

import os
import sys
import random
import cv2
import mxnet as mx
import numpy as np
from mxnet.io import DataIter, DataBatch

sys.path.append('../')

def padding_and_resize(img, dstwidth, dstheight):
    height = img.shape[0]
    width = img.shape[1]

    top = 0
    bottom = 0
    left = 0
    right = 0

    if width > height:
        top = int((width - height) / 2)
        bottom = int((width - height) - top)
    else:
        left = int((height - width) / 2)
        right = int((height - width) - left)

    tmp = cv2.copyMakeBorder(img, top, bottom, left, right, cv2.BORDER_REPLICATE)

    return cv2.resize(img, (dstwidth, dstheight))

def rotate_image(image, angle):
    # grab the dimensions of the image and then determine the
    # center
    (h, w) = image.shape[:2]
    (cX, cY) = (w // 2, h // 2)

    # grab the rotation matrix (applying the negative of the
    # angle to rotate clockwise), then grab the sine and cosine
    # (i.e., the rotation components of the matrix)
    M = cv2.getRotationMatrix2D((cX, cY), -angle, 1.0)
    cos = np.abs(M[0, 0])
    sin = np.abs(M[0, 1])

    # compute the new bounding dimensions of the image
    nW = int((h * sin) + (w * cos))
    nH = int((h * cos) + (w * sin))

    # adjust the rotation matrix to take into account translation
    M[0, 2] += (nW / 2) - cX
    M[1, 2] += (nH / 2) - cY

    # perform the actual rotation and return the image
    return cv2.warpAffine(image, M, (nW, nH), flags=cv2.INTER_LINEAR, borderMode=cv2.BORDER_REPLICATE)

def get_batch(items, root_path, nClasses, height, width):
    x = []
    y = []

    for item in items:
        flipped = False
        cropped = False
        rotated = False
        rotated_neg = False

        image_path = root_path + item.split(' ')[0]
        label_path = root_path + item.split(' ')[-1].strip()

        if image_path.find('_flipped.') >= 0:
          image_path = image_path.replace('_flipped', '')
          flipped = True
        elif image_path.find('_cropped.') >= 0:
          image_path = image_path.replace('_cropped', '')
          cropped = True
        elif image_path.find('_rotated.') >= 0:
          image_path = image_path.replace('_rotated', '')
          rotated = True
        elif image_path.find('_rotated_neg.') >= 0:
          image_path = image_path.replace('_rotated_neg', '')
          rotated_neg = True

        im = cv2.imread(image_path, 1)
        lim = cv2.imread(label_path, 1)

        if cropped:
          tmp_height = im.shape[0]
          im = im[:,tmp_height//5:tmp_height*4//5]
          tmp_height = lim.shape[0]
          lim = lim[:,tmp_height//5:tmp_height*4//5]

        if flipped:
          im = cv2.flip(im, 1)
          lim = cv2.flip(lim, 1)

        if rotated:
          im = rotate_image(im, 13)
          lim = rotate_image(lim, 13)

        if rotated_neg:
          im = rotate_image(im, -13)
          lim = rotate_image(lim, -13)

        im = padding_and_resize(im, width, height)
        lim = padding_and_resize(lim, width, height)

        im = np.float32(im) / 255.0

        lim = lim[:, :, 0]
        seg_labels = np.zeros((height, width, nClasses))

        for c in range(nClasses):
            seg_labels[:, :, c] = (lim == c).astype(int)

        seg_labels = np.reshape(seg_labels, (width * height, nClasses))

        x.append(im.transpose((2,0,1)))
        y.append(seg_labels.transpose((1,0)))

    return mx.nd.array(x), mx.nd.array(y)

  2. 單幀處理需要15秒左右的時間

  按照第一步處理之後,基本上一張圖片只要1秒鐘就處理完成,如果用上了openmp多開幾個執行緒,1秒應該可以處理好幾張。有個想法是,如果把神經網路的每一層搞成一個執行緒負責,用流水線的方式,也許可以做到實時處理視訊幀。

  3. 得到的mask結果不是特別理想

  通過擴充樣本,修改網路concat方式混合訓練,比如up6 = mx.sym.concat(*[trans_conv6, conv5], dim=1, name='concat6')的conv5換成第一次卷積(原先是第二次)的結果,訓練幾個epoch再換回原來的網路。

  附幾張效果圖