阿新 • • 發佈:2018-11-09
前兩篇文章已經基本完成了unet模型從mxnet訓練到ncnn轉換的流程。不過還存在幾個問題:1. 模型比較大;2. 單幀處理需要15秒左右的時間(Mac Pro,ncnn沒有使用openmp的情況下);3. 得到的mask結果不是特別理想。針對這三個問題,本文將對網路結構進行調整。
1. 模型比較大
#!/usr/bin/env python
# coding=utf8
import os
import sys
import random

import cv2
import mxnet as mx
import numpy as np
from mxnet.io import DataIter, DataBatch

sys.path.append('../')


def padding_and_resize(img, dstwidth, dstheight):
    """Pad an image to a square with replicated borders, then resize it.

    The shorter dimension is padded on both sides (``cv2.BORDER_REPLICATE``)
    so the image becomes square before it is scaled, which keeps the
    content's aspect ratio when the target size is square.

    Args:
        img: Image array whose first two axes are (height, width).
        dstwidth: Width of the returned image.
        dstheight: Height of the returned image.

    Returns:
        The padded image resized to ``(dstheight, dstwidth)``.
    """
    height, width = img.shape[:2]
    top = bottom = left = right = 0
    if width > height:
        # Landscape: add rows above and below.
        top = (width - height) // 2
        bottom = (width - height) - top
    else:
        # Portrait (or already square): add columns left and right.
        left = (height - width) // 2
        right = (height - width) - left
    padded = cv2.copyMakeBorder(img, top, bottom, left, right,
                                cv2.BORDER_REPLICATE)
    # Resize the padded image; resizing ``img`` here would silently discard
    # the padding computed above and stretch non-square inputs.
    return cv2.resize(padded, (dstwidth, dstheight))


def rotate_image(image, angle):
    """Rotate an image about its centre without clipping its corners.

    The output canvas is enlarged to the bounding box of the rotated image
    and the uncovered area is filled by replicating the border pixels.

    Args:
        image: Image array whose first two axes are (height, width).
        angle: Rotation angle in degrees; the negated angle is handed to
            ``cv2.getRotationMatrix2D`` so positive values rotate clockwise.

    Returns:
        The rotated image on the enlarged canvas.
    """
    # Grab the dimensions of the image and determine its centre.
    (h, w) = image.shape[:2]
    (cX, cY) = (w // 2, h // 2)

    # Build the rotation matrix, then read the absolute cosine and sine
    # (the rotation components of the matrix).
    M = cv2.getRotationMatrix2D((cX, cY), -angle, 1.0)
    cos = np.abs(M[0, 0])
    sin = np.abs(M[0, 1])

    # Compute the new bounding dimensions of the image.
    nW = int((h * sin) + (w * cos))
    nH = int((h * cos) + (w * sin))

    # Shift the transform so the rotated image is centred in the new canvas.
    M[0, 2] += (nW / 2) - cX
    M[1, 2] += (nH / 2) - cY

    # Perform the actual rotation and return the image.
    return cv2.warpAffine(image, M, (nW, nH), flags=cv2.INTER_LINEAR,
                          borderMode=cv2.BORDER_REPLICATE)


def get_batch(items, root_path, nClasses, height, width):
    """Load, augment and encode a batch of image / label pairs.

    Each entry of ``items`` is a line of the form ``"<image> <label>"`` with
    paths relative to ``root_path``. An image name containing one of the
    markers ``_flipped.``, ``_cropped.``, ``_rotated.`` or ``_rotated_neg.``
    is not read from disk under that name: the marker is stripped to get the
    real file, and the matching augmentation is applied to both the image
    and its label after loading.

    Args:
        items: Iterable of ``"<image> <label>"`` strings.
        root_path: Prefix prepended (by plain concatenation) to both paths.
        nClasses: Number of label classes to one-hot encode.
        height: Target image height.
        width: Target image width.

    Returns:
        A tuple ``(x, y)`` of ``mx.nd.array``: ``x`` holds the images scaled
        to [0, 1] in channel-first layout, ``y`` holds the one-hot labels
        with shape ``(nClasses, width * height)`` per sample.

    Raises:
        IOError: If an image or label file cannot be read by ``cv2.imread``.
    """
    x = []
    y = []
    for item in items:
        fields = item.split(' ')
        image_path = root_path + fields[0]
        label_path = root_path + fields[-1].strip()

        # Detect the augmentation marker (at most one applies) and map the
        # name back to the file that actually exists on disk.
        flipped = cropped = rotated = rotated_neg = False
        if '_flipped.' in image_path:
            image_path = image_path.replace('_flipped', '')
            flipped = True
        elif '_cropped.' in image_path:
            image_path = image_path.replace('_cropped', '')
            cropped = True
        elif '_rotated.' in image_path:
            image_path = image_path.replace('_rotated', '')
            rotated = True
        elif '_rotated_neg.' in image_path:
            image_path = image_path.replace('_rotated_neg', '')
            rotated_neg = True

        im = cv2.imread(image_path, 1)
        if im is None:
            # cv2.imread signals failure by returning None instead of raising.
            raise IOError('cannot read image: {}'.format(image_path))
        lim = cv2.imread(label_path, 1)
        if lim is None:
            raise IOError('cannot read label: {}'.format(label_path))

        if cropped:
            # NOTE(review): the column range is derived from the image
            # *height* (shape[0]); kept as in the original - confirm intent.
            tmp_height = im.shape[0]
            im = im[:, tmp_height // 5:tmp_height * 4 // 5]
            tmp_height = lim.shape[0]
            lim = lim[:, tmp_height // 5:tmp_height * 4 // 5]
        if flipped:
            im = cv2.flip(im, 1)
            lim = cv2.flip(lim, 1)
        if rotated:
            im = rotate_image(im, 13)
            lim = rotate_image(lim, 13)
        if rotated_neg:
            im = rotate_image(im, -13)
            lim = rotate_image(lim, -13)

        im = padding_and_resize(im, width, height)
        lim = padding_and_resize(lim, width, height)

        im = np.float32(im) / 255.0
        # Only the first channel of the label image carries the class index.
        lim = lim[:, :, 0]

        # One-hot encode: channel c is 1 where the label equals c, else 0.
        seg_labels = (lim[:, :, np.newaxis] == np.arange(nClasses)).astype(np.float64)
        seg_labels = np.reshape(seg_labels, (width * height, nClasses))

        x.append(im.transpose((2, 0, 1)))
        y.append(seg_labels.transpose((1, 0)))
    return mx.nd.array(x), mx.nd.array(y)
2. 單幀處理需要15秒左右的時間
3. 得到的mask結果不是特別理想
可以通過擴充樣本,並修改網路的concat方式進行混合訓練來改善。比如,將up6 = mx.sym.concat(*[trans_conv6, conv5], dim=1, name='concat6')中的conv5由原先第二次卷積的結果換成第一次卷積的結果,訓練幾個epoch之後再換回原來的網路。