keras學習筆記-黑白照片自動著色的神經網路-Beta版
正文共3894個字,8張圖,預計閱讀時間11分鐘。
Alpha版本不能很好地給未經訓練的影象著色。接下來,我們將在Beta版本中做到這一點——將上面的將神經網路泛化。
以下是使用Beta版本對測試影象著色的結果。
特徵提取器
我們的神經網路要做的是發現將灰度影象與其彩色版本相連結的特徵。
試想,你必須給黑白影象上色,但一次只能看到9個畫素。你可以從左上角到右下角掃描每個影象,並嘗試預測每個畫素應該是什麼顏色。
例如,這9個畫素就是上面那張女性人臉照片上鼻孔的邊緣。要很好的著色幾乎是不可能的,所以你必須把它分解成好幾個步驟。
首先,尋找簡單的模式:對角線,所有黑色畫素等。在每個濾波器的掃描方塊中尋找相同的精確的模式,並刪除不匹配的畫素。這樣,就可以從64個迷你濾波器生成64個新影象。
如果再次掃描影象,你會看到已經檢測到的相同的模式。要獲得對影象更高級別的理解,你可以將影象尺寸減小一半。
你仍然只有3×3個濾波器來掃描每個影象。但是,通過將新的9個畫素與較低級別的濾波器相結合,可以檢測更復雜的圖案。一個畫素組合可能形成一個半圓,一個小點或一條線。再一次地,你從影象中反覆提取相同的圖案。這次,你會生成128個新的過濾影象。
經過幾個步驟,生成的過濾影象可能看起來像這樣:
這個過程就像大多數處理視覺的神經網路,也即卷積神經網路的行為。結合幾個過濾影象瞭解影象中的上下文。
from keras.layers import Conv2D, Conv2DTranspose, UpSampling2D
from keras.layers import Activation, Dense, Dropout, Flatten, InputLayer
from keras.layers.normalization import BatchNormalization
from keras.callbacks import TensorBoard
from keras.models import Sequential
from keras.preprocessing.image import ImageDataGenerator, array_to_img, img_to_array, load_img
from skimage.color import rgb2lab, lab2rgb, rgb2gray
from skimage.io import imsave
import numpy as np
import osimport random
import tensorflow as tf
Using TensorFlow backend.
# Get imagesX = []
for filename in os.listdir('data/color/Train/'):
X.append(img_to_array(load_img('data/color/Train/'+filename))) X = np.array(X, dtype=float)
# Set up train and test data
split = int(0.95*len(X)) Xtrain = X[:split] Xtrain = 1.0/255*Xtrainmodel = Sequential()
model.add(InputLayer(input_shape=(256, 256, 1))) model.add(Conv2D(64, (3, 3), activation='relu', padding='same')) model.add(Conv2D(64, (3, 3), activation='relu', padding='same', strides=2)) model.add(Conv2D(128, (3, 3), activation='relu', padding='same')) model.add(Conv2D(128, (3, 3), activation='relu', padding='same', strides=2)) model.add(Conv2D(256, (3, 3), activation='relu', padding='same')) model.add(Conv2D(256, (3, 3), activation='relu', padding='same', strides=2)) model.add(Conv2D(512, (3, 3), activation='relu', padding='same')) model.add(Conv2D(256, (3, 3), activation='relu', padding='same')) model.add(Conv2D(128, (3, 3), activation='relu', padding='same')) model.add(UpSampling2D((2, 2))) model.add(Conv2D(64, (3, 3), activation='relu', padding='same')) model.add(UpSampling2D((2, 2))) model.add(Conv2D(32, (3, 3), activation='relu', padding='same')) model.add(Conv2D(2, (3, 3), activation='tanh', padding='same')) model.add(UpSampling2D((2, 2))) model.compile(optimizer='rmsprop', loss='mse')
# Image transformerdatagen = ImageDataGenerator( shear_range=0.2, zoom_range=0.2, rotation_range=20, horizontal_flip=True)
# Generate training databatch_size = 10def image_a_b_gen(batch_size): for batch in datagen.flow(Xtrain, batch_size=batch_size): lab_batch = rgb2lab(batch) X_batch = lab_batch[:,:,:,0] Y_batch = lab_batch[:,:,:,1:] / 128 yield (X_batch.reshape(X_batch.shape+(1,)), Y_batch)
# Train model
tensorboard = TensorBoard(log_dir="data/color/output/first_run") model.fit_generator(image_a_b_gen(batch_size), callbacks=[tensorboard], epochs=1, steps_per_epoch=10)
Epoch 1/1 10/10 [==============================] - 178s - loss: 0.5208 <keras.callbacks.History at 0x1092b5ac8>
# Save modelmodel_json = model.to_json()with open("model.json", "w") as json_file: json_file.write(model_json) model.save_weights("model.h5")
# Test imagesXtest = rgb2lab(1.0/255*X[split:])[:,:,:,0]
Xtest = Xtest.reshape(Xtest.shape+(1,)) Ytest = rgb2lab(1.0/255*X[split:])[:,:,:,1:] Ytest = Ytest / 128
print(model.evaluate(Xtest, Ytest, batch_size=batch_size))
1/1 [==============================] - 0s 0.00189386657439
color_me = []for filename in os.listdir('data/color/Test/'): color_me.append(img_to_array(load_img('data/color/Test/'+filename))) color_me = np.array(color_me, dtype=float) color_me = rgb2lab(1.0/255*color_me)[:,:,:,0] color_me = color_me.reshape(color_me.shape+(1,))# Test modeloutput = model.predict(color_me) output = output * 128# Output colorizationsfor i in range(len(output)): cur = np.zeros((256, 256, 3)) cur[:,:,0] = color_me[i][:,:,0] cur[:,:,1:] = output[i] imsave("data/color/output/img1_"+str(i)+".png", lab2rgb(cur))
/usr/local/lib/python3.6/site-packages/skimage/util/dtype.py:122: UserWarning: Possible precision loss when converting from float64 to uint8 .format(dtypeobj_in, dtypeobj_out))
原文連結:https://www.jianshu.com/p/ed4b49c2d16b
查閱更為簡潔方便的分類文章以及最新的課程、產品資訊,請移步至全新呈現的“LeadAI學院官網”:
www.leadai.org
請關注人工智慧LeadAI公眾號,檢視更多專業文章
大家都在看