Implementing Image Style Transfer in PyTorch
阿新 · Published: 2018-12-22
First, define the two loss functions.
Content loss:
class Content_loss(torch.nn.Module):
    # weight controls how strongly the content influences the result;
    # target holds the content of the input image extracted by the conv layers
    def __init__(self, weight, target):
        super(Content_loss, self).__init__()
        self.weight = weight
        # detach locks the extracted content so no gradients flow into it
        self.target = target.detach() * weight
        # use mean squared error as the loss function
        self.loss_fn = torch.nn.MSELoss()

    # compute the loss between the image and the content
    def forward(self, input):
        self.loss = self.loss_fn(input * self.weight, self.target)
        return input

    # backpropagate the loss
    def backward(self):
        self.loss.backward(retain_graph=True)
        return self.loss
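Note that forward returns its input unchanged: the module is meant to sit transparently inside a Sequential, recording its loss as a side effect. A quick sanity check on dummy tensors (the shapes below are illustrative, not from the original post):

import torch

# pretend these are conv feature maps of the content image and of the
# image being synthesized
target = torch.rand(1, 128, 64, 64)
loss_module = Content_loss(weight=1, target=target)
out = loss_module(torch.rand(1, 128, 64, 64))
print(out.shape)                # unchanged: torch.Size([1, 128, 64, 64])
print(loss_module.loss.item())  # MSE between input and target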
Style loss:
# Image style loss
class Style_loss(torch.nn.Module):
    # weight controls how strongly the style influences the result;
    # target holds the style representation extracted by the conv layers
    def __init__(self, weight, target):
        super(Style_loss, self).__init__()
        self.weight = weight
        self.target = target.detach() * weight
        # use mean squared error as the loss function
        self.loss_fn = torch.nn.MSELoss()
        self.gram = Gram_matrix()

    # compute the loss between the image and the style
    def forward(self, input):
        self.Gram = self.gram(input.clone())
        self.Gram.mul_(self.weight)
        self.loss = self.loss_fn(self.Gram, self.target)
        return input

    # backpropagate the loss
    def backward(self):
        self.loss.backward(retain_graph=True)
        return self.loss


# Instances of this class take part in the style-loss computation.
# The Gram matrix maps conv features to an image-style representation:
# it is essentially an inner product of the feature maps, which amplifies
# the style signal before the loss is computed, giving the style a
# stronger influence on the synthesized image.
class Gram_matrix(torch.nn.Module):
    def forward(self, input):
        a, b, c, d = input.size()
        # reshape to a*b rows and c*d columns
        feature = input.view(a * b, c * d)
        # inner product
        gram = torch.mm(feature, feature.t())
        # normalize by a*b*c*d
        return gram.div(a * b * c * d)
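To see what Gram_matrix produces: for conv features of shape (a, b, c, d) it returns an (a*b) × (a*b) matrix of inner products between feature maps, normalized by the total number of elements. A small illustration with made-up shapes:

import torch

gram = Gram_matrix()
features = torch.rand(1, 64, 32, 32)  # (a, b, c, d) conv output, shapes illustrative
g = gram(features)
print(g.shape)  # torch.Size([64, 64]): one entry per pair of feature maps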
Building the network model
# Image style transfer model
new_model = torch.nn.Sequential()
# deepcopy gives an independent copy, so changing one does not affect
# the other; a shallow copy would let changes propagate back to cnn
model = copy.deepcopy(cnn)
gram = Gram_matrix()
if use_gpu:
    new_model = new_model.cuda()
    gram = gram.cuda()
index = 1
# only the first eight layers of the pretrained model are needed for feature extraction
for layer in list(model)[:8]:
    # check the layer type with isinstance
    if isinstance(layer, torch.nn.Conv2d):
        name = "Conv_" + str(index)
        # add the layer to the empty model, building up the custom model
        new_model.add_module(name, layer)
        if name in content_layer:
            target = new_model(content_img).clone()
            content_loss = Content_loss(content_weight, target)
            new_model.add_module("content_loss_" + str(index), content_loss)
            content_losses.append(content_loss)
        if name in style_layer:
            target = new_model(style_img).clone()
            target = gram(target)
            style_loss = Style_loss(style_weight, target)
            new_model.add_module("style_loss_" + str(index), style_loss)
            style_losses.append(style_loss)
    if isinstance(layer, torch.nn.ReLU):
        name = "ReLU_" + str(index)
        new_model.add_module(name, layer)
        index = index + 1
    if isinstance(layer, torch.nn.MaxPool2d):
        name = "MaxPool2d_" + str(index)
        new_model.add_module(name, layer)
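The assembly code above depends on several names defined earlier in the post and not shown in this excerpt: cnn, content_layer, style_layer, the two weights, and the two loss lists. A minimal setup sketch, with the layer choices read off the printed structure below; the backbone and the weight values are assumptions:

import copy
import torch
from torchvision import models

use_gpu = torch.cuda.is_available()

# feature-extraction half of a pretrained VGG (an assumption; any VGG-style
# backbone matches the Conv/ReLU/MaxPool order printed below)
cnn = models.vgg16(pretrained=True).features
if use_gpu:
    cnn = cnn.cuda()

# read off the printed structure: content loss after Conv_3,
# style losses after Conv_1 through Conv_4
content_layer = ["Conv_3"]
style_layer = ["Conv_1", "Conv_2", "Conv_3", "Conv_4"]

# relative weights are assumptions; style is typically weighted much more heavily
content_weight = 1
style_weight = 1000

content_losses = []
style_losses = []
# content_img and style_img are assumed to be preprocessed (1, 3, H, W)
# tensors produced by an image-loading step earlier in the post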
Printing the custom network structure gives:
Sequential(
  (Conv_1): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (style_loss_1): Style_loss(
    (loss_fn): MSELoss()
    (gram): Gram_matrix()
  )
  (ReLU_1): ReLU(inplace)
  (Conv_2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (style_loss_2): Style_loss(
    (loss_fn): MSELoss()
    (gram): Gram_matrix()
  )
  (ReLU_2): ReLU(inplace)
  (MaxPool2d_3): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (Conv_3): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (content_loss_3): Content_loss(
    (loss_fn): MSELoss()
  )
  (style_loss_3): Style_loss(
    (loss_fn): MSELoss()
    (gram): Gram_matrix()
  )
  (ReLU_3): ReLU(inplace)
  (Conv_4): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (style_loss_4): Style_loss(
    (loss_fn): MSELoss()
    (gram): Gram_matrix()
  )
)
Training
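The loop below optimizes a tensor named parameter directly, rather than any network weights. The original post creates it elsewhere; a common initialization, assumed here, is a copy of the content image marked as trainable:

# assumed setup: optimize a copy of the content image pixel by pixel
parameter = torch.nn.Parameter(content_img.data.clone())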
optimizer = torch.optim.LBFGS([parameter])
epoch_n = 300
epoch = [0]
while epoch[0] <= epoch_n:
    def closure():
        optimizer.zero_grad()
        style_score = 0
        content_score = 0
        # keep the optimized image in the valid pixel range
        parameter.data.clamp_(0, 1)
        # the forward pass records the losses inside each loss module
        new_model(parameter)
        for sl in style_losses:
            style_score += sl.backward()
        for cl in content_losses:
            content_score += cl.backward()
        epoch[0] += 1
        if epoch[0] % 50 == 0:
            print("Epoch:{} Style Loss:{:.4f} Content Loss:{:.4f}".format(
                epoch[0], style_score.item(), content_score.item()))
            # show the intermediate result; convert the (1, 3, H, W) tensor
            # to an H x W x 3 array that imshow can display
            img_cs = parameter.data.clamp(0, 1).squeeze(0).cpu()
            plt.figure("Img_cs")
            plt.imshow(img_cs.permute(1, 2, 0).numpy())
            plt.show()
        return style_score + content_score
    optimizer.step(closure)
Results
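After the loop ends, parameter holds the stylized image. A minimal sketch for saving it to disk, assuming torchvision is available (the filename is illustrative):

import torchvision

# clamp to the valid pixel range, then write the (1, 3, H, W) tensor as a PNG
result = parameter.data.clamp(0, 1)
torchvision.utils.save_image(result, "style_transfer_result.png")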