程式人生 > PyTorch實現圖像風格遷移

PyTorch實現圖像風格遷移

首先定義兩個損失函式:

內容損失:

class Content_loss(torch.nn.Module):
    """Content loss: MSE between the weighted input features and a fixed target.

    The module is transparent inside a ``Sequential`` pipeline: ``forward``
    stores the loss on ``self.loss`` and hands ``input`` back unchanged.

    Args:
        weight: scalar multiplier applied to both the input features and the
            target; it controls how strongly the *content* term contributes.
        target: feature map of the content image at this layer. It is
            detached, so no gradient ever flows into it.
    """

    def __init__(self, weight, target):
        super(Content_loss, self).__init__()
        self.weight = weight
        frozen = target.detach()
        self.target = frozen * weight
        self.loss_fn = torch.nn.MSELoss()

    def forward(self, input):
        """Record the weighted MSE against the target; return ``input`` as-is."""
        scaled = input * self.weight
        self.loss = self.loss_fn(scaled, self.target)
        return input

    def backward(self):
        """Backpropagate the last recorded loss and return it.

        ``retain_graph=True`` keeps the graph alive because every loss module
        in the network backpropagates through the same shared forward pass.
        """
        loss = self.loss
        loss.backward(retain_graph=True)
        return loss

風格損失:

# Image style loss.


class Style_loss(torch.nn.Module):
    """Style loss: MSE between the weighted Gram matrix of the input and a target.

    The module is transparent inside a ``Sequential`` pipeline: ``forward``
    stores the loss on ``self.loss`` and hands ``input`` back unchanged.

    Args:
        weight: scalar multiplier applied to both the input's Gram matrix and
            the target; it controls how strongly the *style* term contributes.
        target: Gram matrix of the style image's features at this layer (the
            model-building code passes ``gram(features)``). It is detached, so
            no gradient flows into it.
    """

    def __init__(self, weight, target):
        super(Style_loss, self).__init__()
        self.weight = weight
        self.target = target.detach() * weight
        # Mean squared error between the two Gram matrices.
        self.loss_fn = torch.nn.MSELoss()
        self.gram = Gram_matrix()

    def forward(self, input):
        """Record the style loss for ``input``; return ``input`` unchanged."""
        # The clone matters: the Gram computation keeps views of its argument
        # for backward, and in this network an in-place ReLU follows this
        # module. Without the copy, that ReLU would modify saved tensors and
        # backward would fail.
        # The scaling is out-of-place so an autograd-tracked tensor is never
        # mutated in place.
        self.Gram = self.gram(input.clone()) * self.weight
        self.loss = self.loss_fn(self.Gram, self.target)
        return input

    def backward(self):
        """Backpropagate the last recorded loss and return it.

        ``retain_graph=True`` keeps the graph alive because every loss module
        in the network backpropagates through the same shared forward pass.
        """
        self.loss.backward(retain_graph=True)
        return self.loss


'''用這個類定義的例項參與風格損失的計算
格拉姆矩陣
卷積-》影象風格(由數字組成)  相當於進行內積運算
放大圖片風格在進行損失計算,能對合成的圖片產生更大的影響
'''


class Gram_matrix(torch.nn.Module):
    def forward(self, input):
        a, b, c, d = input.size()
        print("a",a,"b",b,"c",c,"d",d)
        # 轉為(ab行 cd列)
        feature = input.view(a * b, c * d)
        # 內積運算
        gram = torch.mm(feature, feature.t())
        # 除以abcd
        return gram.div(a * b * c * d)

搭建網路模型

# Style-transfer model: copy the leading layers of the pretrained CNN into a
# fresh Sequential, inserting content/style loss modules after selected convs.
new_model = torch.nn.Sequential()
# Deep copy, so building new_model never touches the original cnn; a shallow
# copy would share the same layer objects with it.
model = copy.deepcopy(cnn)
gram = Gram_matrix()

if use_gpu:
    # Move the copied layers themselves. Calling .cuda() on new_model alone
    # has no effect here: it is still empty, and .cuda() does not apply to
    # modules added afterwards.
    model = model.cuda()
    new_model = new_model.cuda()
    gram = gram.cuda()

index = 1
# Only the first 8 layers of the pretrained feature extractor are used.
for layer in list(model)[:8]:
    if isinstance(layer, torch.nn.Conv2d):
        name = "Conv_" + str(index)
        # Append the layer to the custom model under a readable name.
        new_model.add_module(name, layer)
        if name in content_layer:
            # Content target: this layer's response to the content image.
            target = new_model(content_img).clone()
            content_loss = Content_loss(content_weight, target)
            new_model.add_module("content_loss_" + str(index), content_loss)
            content_losses.append(content_loss)
        if name in style_layer:
            # Style target: Gram matrix of this layer's response to the style image.
            target = new_model(style_img).clone()
            target = gram(target)
            style_loss = Style_loss(style_weight, target)
            new_model.add_module("style_loss_" + str(index), style_loss)
            style_losses.append(style_loss)
    elif isinstance(layer, torch.nn.ReLU):
        name = "ReLU_" + str(index)
        new_model.add_module(name, layer)
        # A conv + ReLU pair is complete; the next conv gets the next index.
        index = index + 1
    elif isinstance(layer, torch.nn.MaxPool2d):
        name = "MaxPool2d_" + str(index)
        new_model.add_module(name, layer)

輸出自定義的網路結構

Sequential(
  (Conv_1): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (style_loss_1): Style_loss(
    (loss_fn): MSELoss()
    (gram): Gram_matrix()
  )
  (ReLU_1): ReLU(inplace)
  (Conv_2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (style_loss_2): Style_loss(
    (loss_fn): MSELoss()
    (gram): Gram_matrix()
  )
  (ReLU_2): ReLU(inplace)
  (MaxPool2d_3): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (Conv_3): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (content_loss_3): Content_loss(
    (loss_fn): MSELoss()
  )
  (style_loss_3): Style_loss(
    (loss_fn): MSELoss()
    (gram): Gram_matrix()
  )
  (ReLU_3): ReLU(inplace)
  (Conv_4): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (style_loss_4): Style_loss(
    (loss_fn): MSELoss()
    (gram): Gram_matrix()
  )
)

訓練

# L-BFGS optimizes the pixels of the generated image (`parameter`) directly.
optimizer = torch.optim.LBFGS([parameter])

epoch_n = 300
# Stored in a one-element list so the closure can update it in place.
epoch = [0]


def closure():
    """Evaluate the total loss for L-BFGS (called possibly several times per step)."""
    optimizer.zero_grad()
    style_score = 0
    content_score = 0
    # Keep pixel values in the valid [0, 1] range.
    parameter.data.clamp_(0, 1)
    # The forward pass fills in each loss module's `.loss`.
    new_model(parameter)
    for sl in style_losses:
        style_score += sl.backward()
    for cl in content_losses:
        content_score += cl.backward()

    epoch[0] += 1
    if epoch[0] % 50 == 0:
        # float() handles 0-dim tensors (`.data[0]` raises on them in
        # PyTorch >= 0.4) as well as the int 0 left when a loss list is empty.
        print("Epoch:{} StyleLoss :{:4f} Content Loss:{:4f}".format(
            epoch[0], float(style_score), float(content_score)))
        # Show the image being optimized. new_model(parameter) would return
        # feature maps of the last layer, not an image.
        # NOTE(review): assumes parameter is (1, 3, H, W); confirm.
        img_cs = parameter.detach().clamp(0, 1).cpu()[0].permute(1, 2, 0).numpy()
        plt.figure("Img_cs")
        plt.imshow(img_cs)
        plt.show()
    return style_score + content_score


while epoch[0] <= epoch_n:
    optimizer.step(closure)

# The last L-BFGS step may leave pixels outside [0, 1]; clip the final result.
parameter.data.clamp_(0, 1)

結果