Keras實現Unet結構
在之前發的部落格“基於卷積神經網路特徵圖的二值影象分割”中(https://blog.csdn.net/shi2xian2wei2/article/details/84329511)也提到,Unet結構主要是通過多個多通道特徵圖最大化地利用輸入圖片的特徵,使得網路在訓練集較小的情況下也能夠得到較好的目標分割結果。Unet論文見https://arxiv.org/abs/1505.04597,這裡通過Keras框架對Unet結構進行搭建,並使用之前部落格中的偽造資料集對網路進行訓練以及測試。
Unet論文中所提出的網路結構如下圖所示:
Unet大量使用了拼接結構,以實現對影象不同尺度資訊的採集,這樣做也是為了能儘可能利用圖片中的資訊。論文中卷積的padding採用的是valid方式(即不補零,每次卷積後特徵圖的尺寸都會縮小),因此在進行拼接時需要對編碼端的輸出進行裁剪來保證尺寸的一致性。個人感覺padding選用same方式不會對效能產生任何不好的影響,並且實現起來也更加方便。keras框架實現的網路結構如下:
# U-Net with 'same' padding, so no cropping is needed before concatenation.
# ---- Encoder: two 3x3 Conv-BN-LeakyReLU layers per level, then a stride-2 max-pool ----
inpt = Input(shape=(input_size_1, input_size_2, 3))
conv1 = Conv2d_BN(inpt, 8, (3, 3))
conv1 = Conv2d_BN(conv1, 8, (3, 3))
pool1 = MaxPooling2D(pool_size=(2,2),strides=(2,2),padding='same')(conv1)
conv2 = Conv2d_BN(pool1, 16, (3, 3))
conv2 = Conv2d_BN(conv2, 16, (3, 3))
pool2 = MaxPooling2D(pool_size=(2,2),strides=(2,2),padding='same')(conv2)
conv3 = Conv2d_BN(pool2, 32, (3, 3))
conv3 = Conv2d_BN(conv3, 32, (3, 3))
pool3 = MaxPooling2D(pool_size=(2,2),strides=(2,2),padding='same')(conv3)
conv4 = Conv2d_BN(pool3, 64, (3, 3))
conv4 = Conv2d_BN(conv4, 64, (3, 3))
pool4 = MaxPooling2D(pool_size=(2,2),strides=(2,2),padding='same')(conv4)
# ---- Bottleneck ----
conv5 = Conv2d_BN(pool4, 128, (3, 3))
conv5 = Dropout(0.5)(conv5)
# The second bottleneck conv must consume conv5. Feeding pool4 again would
# leave the conv5/Dropout pair above disconnected from the output.
conv5 = Conv2d_BN(conv5, 128, (3, 3))
conv5 = Dropout(0.5)(conv5)
# ---- Decoder: stride-2 transposed conv, then concatenate with the matching encoder map ----
convt1 = Conv2dT_BN(conv5, 64, (3, 3))
concat1 = concatenate([conv4, convt1], axis=3)
concat1 = Dropout(0.5)(concat1)
conv6 = Conv2d_BN(concat1, 64, (3, 3))
conv6 = Conv2d_BN(conv6, 64, (3, 3))
convt2 = Conv2dT_BN(conv6, 32, (3, 3))
concat2 = concatenate([conv3, convt2], axis=3)
concat2 = Dropout(0.5)(concat2)
conv7 = Conv2d_BN(concat2, 32, (3, 3))
conv7 = Conv2d_BN(conv7, 32, (3, 3))
convt3 = Conv2dT_BN(conv7, 16, (3, 3))
concat3 = concatenate([conv2, convt3], axis=3)
concat3 = Dropout(0.5)(concat3)
conv8 = Conv2d_BN(concat3, 16, (3, 3))
conv8 = Conv2d_BN(conv8, 16, (3, 3))
convt4 = Conv2dT_BN(conv8, 8, (3, 3))
concat4 = concatenate([conv1, convt4], axis=3)
concat4 = Dropout(0.5)(concat4)
conv9 = Conv2d_BN(concat4, 8, (3, 3))
conv9 = Conv2d_BN(conv9, 8, (3, 3))
conv9 = Dropout(0.5)(conv9)
# 1x1 conv to 3 output channels, sigmoid keeps each value in (0, 1).
outpt = Conv2D(filters=3, kernel_size=(1,1), strides=(1,1), padding='same', activation='sigmoid')(conv9)
model = Model(inpt, outpt)
# NOTE(review): 'accuracy' says little for an MSE regression on continuous targets.
model.compile(loss='mean_squared_error', optimizer='Nadam', metrics=['accuracy'])
model.summary()
構建的網路引數進行了大幅度減少,主要是我的電腦顯示卡不給力/(ㄒoㄒ)/~~……加入大量的Dropout層是為了防止網路過擬合,因為樣本數量比較少。網路之所以使用最大池化層進行下采樣,我覺得主要是考慮到對邊緣特徵的最大化利用。在1000張圖片的訓練集上訓練約22個Epoch後,結果如下:
原始影象 真實標籤 檢測標籤
雖然在訓練過程中,訓練樣本並沒有包含任何的紋理資訊,但網路輸出的結果中仍可以看到部分物體的一些紋理。這也從一個側面反映了Unet特徵提取能力的強大。
附上網路完整程式碼,資料還請自行替換:
import numpy as np
import random
import os
from keras.models import save_model, load_model, Model
from keras.layers import Input, Dropout, BatchNormalization, LeakyReLU, concatenate
from keras.layers import Conv2D, MaxPooling2D, AveragePooling2D, Conv2DTranspose
import matplotlib.pyplot as plt
from skimage import io
from skimage.transform import resize
# File names of the training inputs. The same names are read from the
# matching target directory when batches are built.
input_name = os.listdir('train_data3/JPEGImages')
n = len(input_name)  # number of available training samples

batch_size = 8  # module-level default batch size
# Network input height and width.
input_size_1, input_size_2 = 256, 256
def batch_data(input_name, n, batch_size=8, input_size_1=256, input_size_2=256,
               root='train_data3'):
    """Sample a random batch of (input, target) image pairs from *root*.

    Each sample draws an index uniformly at random (with replacement) from
    0 .. n-1. It loads ``<root>/JPEGImages/<name>`` as the input and
    ``<root>/TargetImages/<name>`` as the target. Both images are resized to
    ``(input_size_1, input_size_2, 3)`` and divided by 255.

    Args:
        input_name: sequence of image file names.
        n: number of names to sample from (indices 0 .. n-1).
        batch_size: number of pairs to draw. At least one pair is always
            drawn, even for batch_size < 1.
        input_size_1: output height of every image.
        input_size_2: output width of every image.
        root: dataset directory. Defaults to the previously hard-coded
            'train_data3', so existing callers are unaffected.

    Returns:
        (batch_input, batch_output): two float arrays of shape
        (max(batch_size, 1), input_size_1, input_size_2, 3).

    Raises:
        ValueError: from random.randint when n < 1.
    """
    def _load(path):
        # Cast to float before resizing. skimage only rescales *integer*
        # images to [0, 1], so the explicit /255 below does the normalisation.
        img = io.imread(path).astype("float")
        img = resize(img, [input_size_1, input_size_2, 3])
        return img / 255

    inputs, targets = [], []
    for _ in range(max(batch_size, 1)):
        name = input_name[random.randint(0, n - 1)]
        inputs.append(_load(os.path.join(root, 'JPEGImages', name)))
        targets.append(_load(os.path.join(root, 'TargetImages', name)))
    # Stack once at the end instead of re-concatenating the growing batch
    # on every iteration.
    return np.stack(inputs), np.stack(targets)
def Conv2d_BN(x, nb_filter, kernel_size, strides=(1,1), padding='same'):
    """Apply a Conv2D -> BatchNormalization -> LeakyReLU(0.1) unit to *x*.

    Args:
        x: input Keras tensor.
        nb_filter: number of convolution filters (output channels).
        kernel_size: convolution kernel size.
        strides: convolution strides.
        padding: convolution padding mode.

    Returns:
        The activated output tensor.
    """
    convolved = Conv2D(nb_filter, kernel_size, strides=strides, padding=padding)(x)
    # axis=3 normalises the last axis of the 4-D tensor (channels-last layout).
    normalized = BatchNormalization(axis=3)(convolved)
    return LeakyReLU(alpha=0.1)(normalized)
def Conv2dT_BN(x, filters, kernel_size, strides=(2,2), padding='same'):
    """Apply a Conv2DTranspose -> BatchNormalization -> LeakyReLU(0.1) unit.

    With the defaults (strides=(2, 2), padding='same') the transposed
    convolution upsamples *x* by a factor of two in height and width.

    Args:
        x: input Keras tensor.
        filters: number of output channels.
        kernel_size: transposed-convolution kernel size.
        strides: transposed-convolution strides.
        padding: transposed-convolution padding mode.

    Returns:
        The activated output tensor.
    """
    upsampled = Conv2DTranspose(filters, kernel_size, strides=strides, padding=padding)(x)
    # axis=3 normalises the last axis of the 4-D tensor (channels-last layout).
    normalized = BatchNormalization(axis=3)(upsampled)
    return LeakyReLU(alpha=0.1)(normalized)
# ---- Encoder: two 3x3 Conv-BN-LeakyReLU layers per level, then a stride-2 max-pool ----
# NOTE(review): the skip concatenations assume input_size_1/2 are divisible
# by 16 (true for 256). Other sizes may produce mismatched shapes.
inpt = Input(shape=(input_size_1, input_size_2, 3))
conv1 = Conv2d_BN(inpt, 8, (3, 3))
conv1 = Conv2d_BN(conv1, 8, (3, 3))
pool1 = MaxPooling2D(pool_size=(2,2),strides=(2,2),padding='same')(conv1)
conv2 = Conv2d_BN(pool1, 16, (3, 3))
conv2 = Conv2d_BN(conv2, 16, (3, 3))
pool2 = MaxPooling2D(pool_size=(2,2),strides=(2,2),padding='same')(conv2)
conv3 = Conv2d_BN(pool2, 32, (3, 3))
conv3 = Conv2d_BN(conv3, 32, (3, 3))
pool3 = MaxPooling2D(pool_size=(2,2),strides=(2,2),padding='same')(conv3)
conv4 = Conv2d_BN(pool3, 64, (3, 3))
conv4 = Conv2d_BN(conv4, 64, (3, 3))
pool4 = MaxPooling2D(pool_size=(2,2),strides=(2,2),padding='same')(conv4)
# ---- Bottleneck ----
conv5 = Conv2d_BN(pool4, 128, (3, 3))
conv5 = Dropout(0.5)(conv5)
# The second bottleneck conv must consume conv5. Feeding pool4 again (as
# before) left the conv5/Dropout pair above disconnected from the model output.
conv5 = Conv2d_BN(conv5, 128, (3, 3))
conv5 = Dropout(0.5)(conv5)
# ---- Decoder: stride-2 transposed conv, then concatenate with the matching encoder map ----
convt1 = Conv2dT_BN(conv5, 64, (3, 3))
concat1 = concatenate([conv4, convt1], axis=3)
concat1 = Dropout(0.5)(concat1)
conv6 = Conv2d_BN(concat1, 64, (3, 3))
conv6 = Conv2d_BN(conv6, 64, (3, 3))
convt2 = Conv2dT_BN(conv6, 32, (3, 3))
concat2 = concatenate([conv3, convt2], axis=3)
concat2 = Dropout(0.5)(concat2)
conv7 = Conv2d_BN(concat2, 32, (3, 3))
conv7 = Conv2d_BN(conv7, 32, (3, 3))
convt3 = Conv2dT_BN(conv7, 16, (3, 3))
concat3 = concatenate([conv2, convt3], axis=3)
concat3 = Dropout(0.5)(concat3)
conv8 = Conv2d_BN(concat3, 16, (3, 3))
conv8 = Conv2d_BN(conv8, 16, (3, 3))
convt4 = Conv2dT_BN(conv8, 8, (3, 3))
concat4 = concatenate([conv1, convt4], axis=3)
concat4 = Dropout(0.5)(concat4)
conv9 = Conv2d_BN(concat4, 8, (3, 3))
conv9 = Conv2d_BN(conv9, 8, (3, 3))
conv9 = Dropout(0.5)(conv9)
# 1x1 conv to 3 output channels, sigmoid keeps each value in (0, 1).
outpt = Conv2D(filters=3, kernel_size=(1,1), strides=(1,1), padding='same', activation='sigmoid')(conv9)
model = Model(inpt, outpt)
# NOTE(review): 'accuracy' says little for an MSE regression on continuous
# targets. It is kept as in the original.
model.compile(loss='mean_squared_error', optimizer='Nadam', metrics=['accuracy'])
model.summary()
itr = 3000  # number of training steps; each step uses a freshly sampled batch
S = []  # NOTE(review): never read in the visible code
for i in range(itr):
    print("iteration = ", i+1)
    # Batch-size schedule: larger batches later in training.
    # With itr = 3000 the last tier (i >= 5000, bs = 32) is never reached.
    if i < 500:
        bs = 4
    elif i < 2000:
        bs = 8
    elif i < 5000:
        bs = 16
    else:
        bs = 32
    train_X, train_Y = batch_data(input_name, n, batch_size = bs)
    # One epoch over this single small batch. With Keras' default fit
    # batch_size (32) that is one optimisation step.
    model.fit(train_X, train_Y, epochs=1, verbose=0)
    if i % 100 == 99:
        save_model(model, 'unet.h5')  # periodic checkpoint
# Always save after the loop. Otherwise, when itr is not a multiple of 100,
# the reload below would silently discard every step after the last checkpoint.
save_model(model, 'unet.h5')
model = load_model('unet.h5')
def batch_data_test(input_name, n, batch_size=8, input_size_1=256, input_size_2=256,
                    root='test_data3'):
    """Sample a random batch of (input, target) image pairs from the test set.

    Each sample draws an index uniformly at random (with replacement) from
    0 .. n-1. It loads ``<root>/JPEGImages/<name>`` as the input and
    ``<root>/TargetImages/<name>`` as the target. Both images are resized to
    ``(input_size_1, input_size_2, 3)`` and divided by 255.

    Args:
        input_name: sequence of image file names.
        n: number of names to sample from (indices 0 .. n-1).
        batch_size: number of pairs to draw. At least one pair is always
            drawn, even for batch_size < 1.
        input_size_1: output height of every image.
        input_size_2: output width of every image.
        root: dataset directory. Defaults to the previously hard-coded
            'test_data3', so existing callers are unaffected.

    Returns:
        (batch_input, batch_output): two float arrays of shape
        (max(batch_size, 1), input_size_1, input_size_2, 3).

    Raises:
        ValueError: from random.randint when n < 1.
    """
    def _load(path):
        # Cast to float before resizing. skimage only rescales *integer*
        # images to [0, 1], so the explicit /255 below does the normalisation.
        img = io.imread(path).astype("float")
        img = resize(img, [input_size_1, input_size_2, 3])
        return img / 255

    inputs, targets = [], []
    for _ in range(max(batch_size, 1)):
        name = input_name[random.randint(0, n - 1)]
        inputs.append(_load(os.path.join(root, 'JPEGImages', name)))
        targets.append(_load(os.path.join(root, 'TargetImages', name)))
    # Stack once at the end instead of re-concatenating the growing batch
    # on every iteration.
    return np.stack(inputs), np.stack(targets)
# Predict on one random test image. Show the input, the ground-truth target
# and the network output, each in its own figure.
test_name = os.listdir('test_data3/JPEGImages')
n_test = len(test_name)
test_X, test_Y = batch_data_test(test_name, n_test, batch_size = 1)
pred_Y = model.predict(test_X)
ii = 0  # index of the sample within the batch to display
for image in (test_X[ii, :, :, :], test_Y[ii, :, :, :], pred_Y[ii, :, :, :]):
    plt.figure()
    plt.imshow(image)
    plt.axis('off')
# Without plt.show(), a plain script run (python file.py) exits before any
# figure is displayed.
plt.show()