移動端unet人像分割模型--1
個人對移動端神經網路開發一直饒有興致。去年騰訊開源了NCNN框架之後,一直都在關注。近期成功利用別人訓練好的mtcnn和mobilefacenet模型,製作了一個iOS人臉識別demo(Swift版本)。下一步希望把maskrcnn移植到ncnn,在手機端實現一些有趣的應用。因為unet模型比較簡單,乾脆就先從這個入手。
基本的網路基於keras版本: https://github.com/TianzhongSong/Person-Segmentation-Keras
不過keras沒辦法直接轉成ncnn模型,研究過通過onnx模型做中間跳板,採用了一些開源的轉換工具,也是一堆問題。NCNN支援幾個神經網路訓練框架:caffe/mxnet/pytorch,在ncnn的github有一篇issue裡nihui推薦採用mxnet,因此mxnet也成為了我的首選。
利用Person-Segmentation-Keras專案的資料集,同時基於https://github.com/milesial/Pytorch-UNet/tree/master/unet這個專案搗鼓了幾段程式碼。訓練完成,用來測試ncnn轉換基本可用。
轉換過程發現許多問題:一是呼叫ncnn extract會crash,經過調查,發現mxnet2ncnn工具有bug,blob個數算錯;二是input層的one_blob_only標誌,我的理解應該是false,但不知道什麼原因,轉換過來的模型這邊是true,導致forward_layer函式裡訪問bottoms變數時出現異常。後來一層層extract出來、列印輸出的channel/width/height調查後又發現,我把unet.py裡name應為pool5的層寫成了pool4(文章中的code已經糾正),前面的crash可能也跟這個致命錯誤有關。只好重新訓練模型,要漫長等待幾個小時,剩下部分下週再寫。部分程式碼已經更新,請參考:
unetdataiter.py
#!/usr/bin/env python
# coding=utf8
"""Batch iterator that letterboxes image/label pairs for U-Net training."""
import os
import sys
import random

import cv2
import mxnet as mx
import numpy as np
from mxnet.io import DataIter, DataBatch

sys.path.append('../')


def _letterbox_geometry(src_height, src_width, height, width):
    """Return (new_height, new_width, top, left) for an aspect-preserving fit.

    The source is scaled so that its limiting side fills the target exactly
    and is then centred; the other side is truncated to an int and clamped
    to [1, target] so cv2.resize never receives a zero-sized request.
    """
    if src_height * width >= src_width * height:
        # Height is the limiting side (for a square target: src is portrait).
        new_height = height
        new_width = min(width, max(1, int(src_width / (src_height / height))))
    else:
        new_width = width
        new_height = min(height, max(1, int(src_height / (src_width / width))))
    top = (height - new_height) // 2
    left = (width - new_width) // 2
    return new_height, new_width, top, left


def get_batch(items, root_path, nClasses, height, width):
    """Load, letterbox and encode one batch of image/label pairs.

    Args:
        items: lines of the list file; the first whitespace-separated field
            is the image path and the last one is the label path, both
            relative to ``root_path``.
        root_path: prefix concatenated in front of every path.
        nClasses: number of classes; channel 0 of the label image is compared
            against ``0 .. nClasses - 1`` to build the one-hot target.
        height: target number of rows.
        width: target number of columns.

    Returns:
        A pair ``(data, label)`` of ``mx.nd`` arrays. ``data`` holds one
        ``(3, height, width)`` image per item scaled to ``[-1, 1]``;
        ``label`` holds one ``(nClasses, height * width)`` one-hot map per
        item.

    Raises:
        ValueError: if a list entry is blank.
        IOError: if an image or label file cannot be read by OpenCV.
    """
    images = []
    labels = []
    for item in items:
        fields = item.split()
        if not fields:
            raise ValueError('blank entry in the image list')
        image_path = root_path + fields[0]
        label_path = root_path + fields[-1]

        img = cv2.imread(image_path, 1)
        if img is None:
            raise IOError('cannot read image: ' + image_path)
        label_img = cv2.imread(label_path, 1)
        if label_img is None:
            raise IOError('cannot read label: ' + label_path)

        new_height, new_width, top, left = _letterbox_geometry(
            img.shape[0], img.shape[1], height, width)

        # Canvases are (rows, cols, channels); the image is padded with mid
        # grey (128) and the label with class 0.
        im = np.full((height, width, 3), 128, dtype='uint8')
        lim = np.zeros((height, width, 3), dtype='uint8')
        im[top:top + new_height, left:left + new_width, :] = cv2.resize(
            img, (new_width, new_height))
        # Nearest-neighbour keeps label values valid class ids; interpolating
        # would blend ids at object borders.
        lim[top:top + new_height, left:left + new_width, :] = cv2.resize(
            label_img, (new_width, new_height),
            interpolation=cv2.INTER_NEAREST)

        class_ids = lim[:, :, 0]
        one_hot = np.zeros((nClasses, height * width))
        for c in range(nClasses):
            one_hot[c, :] = (class_ids == c).reshape(-1)

        im = np.float32(im) / 127.5 - 1
        images.append(im.transpose((2, 0, 1)))
        labels.append(one_hot)
    return mx.nd.array(np.array(images)), mx.nd.array(np.array(labels))


class UnetDataIter(mx.io.DataIter):
    """MXNet data iterator over an "image_path label_path" list file.

    Every epoch yields ``len(items) // batch_size`` full batches; a trailing
    partial batch is dropped. When ``train`` is true the item order is
    reshuffled on every ``reset()``.
    """

    def __init__(self, root_path, path_file, batch_size, n_classes,
                 input_width, input_height, train=True):
        super().__init__(batch_size)
        with open(path_file, 'r') as f:
            # Skip blank lines so they cannot turn into empty file names.
            self.items = [line for line in f if line.strip()]

        # Shapes follow what get_batch actually produces: (N, C, rows, cols).
        self._provide_data = [
            ['data', (batch_size, 3, input_height, input_width)]]
        self._provide_label = [
            ['softmax_label',
             (batch_size, n_classes, input_width * input_height)]]

        self.root_path = root_path
        self.batch_size = batch_size
        self.num_batches = len(self.items) // batch_size
        self.n_classes = n_classes
        self.input_height = input_height
        self.input_width = input_width
        self.train = train

        self.reset()

    def __iter__(self):
        return self

    def reset(self):
        """Rewind to the first batch, reshuffling when in training mode."""
        self.cur_batch = 0
        order = list(range(len(self.items)))
        if self.train:
            random.shuffle(order)
        self.shuffled_items = [self.items[i] for i in order]

    def __next__(self):
        return self.next()

    @property
    def provide_data(self):
        return self._provide_data

    @property
    def provide_label(self):
        return self._provide_label

    def next(self):
        """Return the next DataBatch or raise StopIteration at epoch end."""
        if self.cur_batch == 0:
            print("")

        # Single-line progress indicator, rewritten in place via '\r'.
        print("\r\033[k" + ("Training " if self.train else "Validating ")
              + str(self.cur_batch) + "/" + str(self.num_batches), end=' ')

        if self.cur_batch >= self.num_batches:
            raise StopIteration

        start = self.cur_batch * self.batch_size
        stop = start + self.batch_size
        data, label = get_batch(self.shuffled_items[start:stop],
                                self.root_path, self.n_classes,
                                self.input_height, self.input_width)
        self.cur_batch += 1
        return DataBatch([data], [label])


if __name__ == '__main__':
    root_path = '/datasets/'
    train_file = './data/seg_train.txt'
    batch_size = 16
    n_classes = 2
    img_width = 256
    img_height = 256

    trainiter = UnetDataIter(root_path, train_file, batch_size, n_classes,
                             img_width, img_height, True)
    # Smoke test: walk one epoch and stop cleanly when it is exhausted.
    for _ in trainiter:
        pass
unet.py
import os
os.environ["MXNET_BACKWARD_DO_MIRROR"] = "1"
os.environ["MXNET_CUDNN_AUTOTUNE_DEFAULT"] = "0"
import mxnet as mx
from mxnet import ndarray as F
from skimage.transform import resize
from skimage.io import imsave
import numpy as np
from unetdataiter import UnetDataIter
import matplotlib.pyplot as plt
def dice_coef(y_true, y_pred):
    """Build the smoothed Dice coefficient symbol for two symbols.

    Both inputs are summed over axes (1, 2, 3); a constant 1 is added to the
    numerator and the denominator as smoothing.

    Returns:
        Symbol for (2 * sum(y_true * y_pred) + 1) / (sum(y_true) + sum(y_pred) + 1).
    """
    reduce_axes = (1, 2, 3)
    overlap = mx.sym.broadcast_mul(y_true, y_pred)
    intersection = mx.sym.sum(overlap, axis=reduce_axes)
    numerator = 2. * intersection + 1.
    true_total = mx.sym.sum(y_true, axis=reduce_axes)
    pred_total = mx.sym.sum(y_pred, axis=reduce_axes)
    denominator = true_total + pred_total + 1.
    return mx.sym.broadcast_div(numerator, denominator)
def dice_coef_loss(y_true, y_pred):
    """Build the negated, smoothed Dice symbol used as the training loss.

    Both inputs are reduced along axis 1 only, so every remaining position
    keeps its own loss value.

    NOTE(review): build_unet feeds tensors reshaped to
    (batch, classes, width * height), which makes axis 1 the class axis
    rather than the pixel axis -- confirm this per-pixel formulation is
    intended.

    Returns:
        Symbol for -(2 * sum(y_true * y_pred) + 1) / (sum(y_true) + sum(y_pred) + 1).
    """
    overlap = mx.sym.broadcast_mul(y_true, y_pred)
    intersection = mx.sym.sum(overlap, axis=1)
    numerator = 2. * intersection + 1.
    true_total = mx.sym.sum(y_true, axis=1)
    pred_total = mx.sym.sum(y_pred, axis=1)
    denominator = mx.sym.broadcast_add(true_total, pred_total) + 1.
    return -mx.sym.broadcast_div(numerator, denominator)
def _conv_bn_relu(data, num_filter, suffix):
    """3x3 same-padding convolution followed by BatchNorm and ReLU."""
    conv = mx.sym.Convolution(data, num_filter=num_filter, kernel=(3, 3),
                              pad=(1, 1), name='conv' + suffix)
    norm = mx.sym.BatchNorm(conv, name='bn' + suffix)
    return mx.sym.Activation(norm, act_type='relu', name='relu' + suffix)


def _double_conv(data, num_filter, stage):
    """Two stacked conv/bn/relu units named ``<op><stage>_1`` and ``<op><stage>_2``."""
    first = _conv_bn_relu(data, num_filter, '%d_1' % stage)
    return _conv_bn_relu(first, num_filter, '%d_2' % stage)


def build_unet(batch_size, input_width, input_height, train=True, n_classes=2):
    """Build the U-Net symbol for person segmentation.

    The encoder has five double-conv stages (64, 128, 256, 256, 256 filters),
    each followed by a 2x2 max pool; the decoder mirrors it with five 2x2
    transposed convolutions, each concatenated with the matching encoder
    output and followed by a double conv. Layer names (``conv1_1`` ...
    ``conv11_1``, ``pool1`` ... ``pool5``, ``trans_conv6`` ...
    ``trans_conv10``, ``softmax``) are kept stable because checkpoints and
    exported models refer to them.

    Pools and transposed convolutions use an explicit stride of 2. MXNet's
    ``Pooling`` defaults to stride 1, which would shrink the map by a single
    pixel per stage instead of halving it.

    Args:
        batch_size: kept for backward compatibility; the loss reshape now
            copies the batch dimension from its input instead.
        input_width: input width in pixels, must be a multiple of 32.
        input_height: input height in pixels, must be a multiple of 32.
        train: when true, return ``Group([loss, mask])`` where ``loss`` is the
            Dice loss against the ``softmax_label`` variable and ``mask`` is
            the gradient-blocked sigmoid output; otherwise return a group
            containing only the sigmoid output.
        n_classes: number of output channels of the final 1x1 convolution.

    Returns:
        An ``mx.sym.Group`` as described for ``train``.

    Raises:
        ValueError: if the input size cannot be halved five times.
    """
    encoder_filters = (64, 128, 256, 256, 256)
    decoder_filters = (256, 256, 256, 128, 64)

    # Five stride-2 pools require both sides to be divisible by 2 ** 5;
    # otherwise the skip concatenations would fail with a shape mismatch.
    factor = 2 ** len(encoder_filters)
    if input_width % factor or input_height % factor:
        raise ValueError(
            'input size must be a multiple of %d, got %dx%d'
            % (factor, input_width, input_height))

    net = mx.sym.Variable(name='data')

    # encode
    skips = []
    for stage, num_filter in enumerate(encoder_filters, start=1):
        net = _double_conv(net, num_filter, stage)
        skips.append(net)
        net = mx.sym.Pooling(net, kernel=(2, 2), stride=(2, 2),
                             pool_type='max', name='pool%d' % stage)

    # decode
    first_decoder_stage = len(encoder_filters) + 1
    for stage, num_filter in enumerate(decoder_filters,
                                       start=first_decoder_stage):
        net = mx.sym.Deconvolution(net, num_filter=num_filter, kernel=(2, 2),
                                   stride=(2, 2), no_bias=True,
                                   name='trans_conv%d' % stage)
        # Pair each upsampled map with the encoder output of equal resolution.
        net = mx.sym.concat(net, skips.pop(), dim=1, name='concat%d' % stage)
        net = _double_conv(net, num_filter, stage)

    scores = mx.sym.Convolution(net, num_filter=n_classes, kernel=(1, 1),
                                name='conv11_1')
    # The node is named 'softmax' although the activation is a sigmoid; the
    # name is preserved for existing checkpoints.
    probs = mx.sym.sigmoid(scores, name='softmax')

    if not train:
        return mx.sym.Group([probs])

    label = mx.sym.Variable(name='softmax_label')
    # (N, C, H, W) -> (N, C, H * W): 0 copies a dimension from the input and
    # -1 infers the remainder, so the symbol is not tied to one batch size.
    flat = mx.sym.Reshape(probs, shape=(0, 0, -1))
    loss = mx.sym.MakeLoss(dice_coef_loss(label, flat), normalization='batch')
    mask_output = mx.sym.BlockGrad(probs, name='mask')
    return mx.sym.Group([loss, mask_output])
trainunet.py
import os
os.environ["MXNET_BACKWARD_DO_MIRROR"] = "1"
os.environ["MXNET_CUDNN_AUTOTUNE_DEFAULT"] = "0"
import mxnet as mx
from mxnet import ndarray as F
from skimage.transform import resize
from skimage.io import imsave
import numpy as np
from unetdataiter import UnetDataIter
import matplotlib.pyplot as plt
from unet import build_unet
def main():
    """Train the U-Net and save a checkpoint after every epoch.

    Reads the train/validation list files, trains with Adam on GPU 0 for a
    fixed number of epochs and prints an exponentially smoothed training
    loss next to the mean validation loss.

    Raises:
        ValueError: if the training list yields no complete batch.
    """
    root_path = '../datasets/'
    train_file = './data/seg_train.txt'
    val_file = './data/seg_test.txt'
    batch_size = 16
    n_classes = 2
    img_width = 96
    img_height = 96
    epochs = 20
    smoothing_constant = .01

    train_iter = UnetDataIter(root_path, train_file, batch_size, n_classes,
                              img_width, img_height, True)
    val_iter = UnetDataIter(root_path, val_file, batch_size, n_classes,
                            img_width, img_height, False)
    if train_iter.num_batches == 0:
        raise ValueError('no complete training batch in ' + train_file)

    ctx = [mx.gpu(0)]
    unet_sym = build_unet(batch_size, img_width, img_height)
    unet = mx.mod.Module(unet_sym, context=ctx, data_names=('data',),
                         label_names=('softmax_label',))
    # Bind with the shapes the iterator declares so the two cannot drift.
    unet.bind(data_shapes=train_iter.provide_data,
              label_shapes=train_iter.provide_label)
    unet.init_params(mx.initializer.Xavier(magnitude=6))
    unet.init_optimizer(optimizer='adam',
                        optimizer_params=(
                            ('learning_rate', 1E-4),
                            ('beta1', 0.9),
                            ('beta2', 0.99)
                        ))

    moving_loss = None
    for e in range(epochs):
        for batch in train_iter:
            unet.forward_backward(batch)
            loss = unet.get_outputs()[0]
            unet.update()
            curr_loss = F.mean(loss).asscalar()
            if moving_loss is None:
                moving_loss = curr_loss
            else:
                moving_loss = ((1 - smoothing_constant) * moving_loss
                               + smoothing_constant * curr_loss)
        train_iter.reset()

        val_losses = []
        for batch in val_iter:
            # The module is bound for training, so inference mode has to be
            # requested explicitly; otherwise validation runs as training.
            unet.forward(batch, is_train=False)
            loss = unet.get_outputs()[0]
            val_losses.append(F.mean(loss).asscalar())
        val_iter.reset()

        val_loss = np.mean(val_losses) if val_losses else float('nan')
        print("\nEpoch %i: Moving Training Loss %0.5f, Validation Loss %0.5f" % (e, moving_loss, val_loss))
        unet.save_checkpoint('./unet_person_segmentation', e)


if __name__ == '__main__':
    main()
以上是訓練程式碼。
預測程式碼如下,predict.py:
import os
os.environ["MXNET_BACKWARD_DO_MIRROR"] = "1"
os.environ["MXNET_CUDNN_AUTOTUNE_DEFAULT"] = "0"
import sys
import cv2
import mxnet as mx
from mxnet import ndarray as F
from skimage.transform import resize
from skimage.io import imsave
import numpy as np
from unetdataiter import UnetDataIter
import matplotlib.pyplot as plt
from unet import build_unet
def post_process_mask(label, img_cols, img_rows, n_classes, p=0.5):
    """Turn per-class scores into a displayable 8-bit mask.

    Args:
        label: scores for one image holding
            ``n_classes * img_rows * img_cols`` values in class-major order;
            either a NumPy-compatible array or an object exposing
            ``asnumpy()`` (such as an MXNet NDArray).
        img_cols: mask width in pixels.
        img_rows: mask height in pixels.
        n_classes: number of score planes.
        p: unused; kept so existing callers keep working.

    Returns:
        ``uint8`` array of shape ``(img_rows, img_cols)``: 0 where class 0
        has the highest score, 255 everywhere else.
    """
    scores = label.asnumpy() if hasattr(label, 'asnumpy') else np.asarray(label)
    # Planes are laid out row-major, i.e. (class, row, column).
    scores = scores.reshape(n_classes, img_rows, img_cols)
    class_ids = scores.argmax(axis=0)
    # Saturate at 255 so class ids above 1 cannot wrap around in uint8.
    return np.minimum(class_ids * 255, 255).astype(np.uint8)
def load_image(img, width, height):
    """Letterbox a BGR image onto a grey canvas and normalise it.

    The image is scaled with its aspect ratio preserved so that its limiting
    side fills the ``height`` x ``width`` target, centred on a canvas filled
    with 128, mapped to ``[-1, 1]`` and converted to channel-first layout.

    Args:
        img: image array of shape (rows, cols, 3), e.g. from ``cv2.imread``.
        width: target width in pixels.
        height: target height in pixels.

    Returns:
        A one-element list holding the ``float32`` array of shape
        ``(3, height, width)``.

    Raises:
        ValueError: if ``img`` is None (as ``cv2.imread`` returns for an
            unreadable file) or has an empty side.
    """
    if img is None:
        raise ValueError('image is None (unreadable or missing file)')
    src_height, src_width = img.shape[:2]
    if src_height == 0 or src_width == 0:
        raise ValueError('image has an empty side')

    if src_height * width >= src_width * height:
        # Height is the limiting side (for a square target: portrait input).
        new_height = height
        new_width = min(width, max(1, int(src_width / (src_height / height))))
    else:
        new_width = width
        new_height = min(height, max(1, int(src_height / (src_width / width))))
    top = (height - new_height) // 2
    left = (width - new_width) // 2

    im = np.full((height, width, 3), 128, dtype='uint8')
    im[top:top + new_height, left:left + new_width, :] = cv2.resize(
        img, (new_width, new_height))
    im = np.float32(im) / 127.5 - 1
    return [im.transpose((2, 0, 1))]
def main():
    """Segment the image given as ``sys.argv[1]`` and display the mask.

    Loads epoch 0 of the ``unet_person_segmentation`` checkpoint, runs the
    letterboxed image through the network on GPU 0 and shows the input and
    the predicted mask in two OpenCV windows until a key is pressed.
    """
    n_classes = 2
    img_width = 96
    img_height = 96

    testimg = cv2.imread(sys.argv[1], 1)
    if testimg is None:
        # cv2.imread signals failure by returning None instead of raising.
        print("cannot read image: " + sys.argv[1], file=sys.stderr)
        sys.exit(1)

    ctx = [mx.gpu(0)]
    sym, arg_params, aux_params = mx.model.load_checkpoint('unet_person_segmentation', 0)
    unet = mx.mod.Module(symbol=sym, context=ctx, label_names=None)
    # One image per call, so bind a batch of one; no labels at inference.
    unet.bind(for_training=False,
              data_shapes=[('data', (1, 3, img_height, img_width))],
              label_shapes=None)
    unet.set_params(arg_params, aux_params, allow_missing=True)

    # load_image returns a one-element list; stack it into an (N, C, H, W) batch.
    batch = np.asarray(load_image(testimg, img_width, img_height), dtype='float32')
    unet.predict(mx.io.NDArrayIter(data=batch, batch_size=1))
    outputs = unet.get_outputs()[0]

    cv2.imshow('test', testimg)
    cv2.imshow('mask', post_process_mask(outputs[0], img_width, img_height, n_classes))
    cv2.waitKey()
    cv2.destroyAllWindows()


if __name__ == '__main__':
    if len(sys.argv) < 2:
        print("illegal parameters", file=sys.stderr)
        # A usage error must not report success to the calling shell.
        sys.exit(1)
    main()
剝離softmax儲存引數用於ncnn模型轉換,train2infer.py
import os
os.environ["MXNET_BACKWARD_DO_MIRROR"] = "1"
os.environ["MXNET_CUDNN_AUTOTUNE_DEFAULT"] = "0"
import sys
import cv2
import mxnet as mx
from mxnet import ndarray as F
from skimage.transform import resize
from skimage.io import imsave
import numpy as np
from unetdataiter import UnetDataIter
import matplotlib.pyplot as plt
from unet import build_unet
def main():
    """Re-save a training checkpoint with the inference-only symbol.

    ``sys.argv[1]`` is the checkpoint prefix and ``sys.argv[2]`` the epoch to
    load. The parameters are loaded into the symbol built with
    ``train=False`` and written to ``./unet_person_segmentation`` as epoch 0.
    """
    batch_size = 1
    img_width = 96
    img_height = 96

    prefix = sys.argv[1]
    try:
        epoch = int(sys.argv[2])
    except ValueError:
        print("illegal parameters: epoch must be an integer", file=sys.stderr)
        sys.exit(1)

    ctx = [mx.gpu(0)]
    # Only the weights are needed; the stored training symbol is replaced.
    _, arg_params, aux_params = mx.model.load_checkpoint(prefix, epoch)
    unet_sym = build_unet(batch_size, img_width, img_height, False)
    unet = mx.mod.Module(symbol=unet_sym, context=ctx, label_names=None)
    unet.bind(for_training=False,
              data_shapes=[('data', (batch_size, 3, img_width, img_height))],
              label_shapes=None)
    unet.set_params(arg_params, aux_params, allow_missing=True)
    unet.save_checkpoint('./unet_person_segmentation', 0)


if __name__ == '__main__':
    if len(sys.argv) < 3:
        print("illegal parameters", file=sys.stderr)
        # A usage error must not report success to the calling shell.
        sys.exit(1)
    main()