1. 程式人生 > 其它 >Pytorch風格遷移

Pytorch風格遷移

最近研究了一下風格遷移,主要是想應用於某些主題節日時動態融合背景,生成一些抽象的藝術圖片,這裡給大家分享一個現成的程式碼,我本地把環境搭建好後跑了試試,有興趣的可以直接拿去執行:

  1 import torch
  2 import torch.nn as nn
  3 import torch.nn.functional as F
  4 import torch.optim as optim
  5 
  6 from PIL import Image
  7 import matplotlib.pyplot as plt
  8 
  9 import torchvision.transforms as transforms
10 import torchvision.models as models 11 import datetime 12 13 device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 14 15 16 num_steps = 10000 17 save_path = "data/drew/img/end_%s.jpg" % datetime.datetime.now().strftime("%Y%m%d%H%M%S") 18 content_img_path = "data/drew/img/dancing.jpg
" 19 style_img_path = "data/drew/img/picasso.jpg" 20 21 22 def get_img_size(img_name): 23 im = Image.open(img_name).convert('RGB') 24 return im, im.height, im.width 25 26 27 def image_loader(img, im_h, im_w): 28 loader = transforms.Compose([transforms.Resize([im_h, im_w]), transforms.ToTensor()])
29 im_l = loader(img).unsqueeze(0) 30 return im_l.to(device, torch.float) 31 32 33 c_image, c_im_h, c_im_w = get_img_size(content_img_path) 34 s_image, s_im_h, s_im_w = get_img_size(style_img_path) 35 content_img = image_loader(c_image, c_im_h, c_im_w) 36 style_img = image_loader(s_image, c_im_h, c_im_w) 37 38 39 assert style_img.size() == content_img.size(), "we need to import style and content images of the same size" 40 unloader = transforms.ToPILImage() 41 42 plt.ion() 43 44 45 def imshow(tensor, title=None): 46 image = tensor.cpu().clone() # we clone the tensor to not do changes on it 47 image = image.squeeze(0) # remove the fake batch dimension 48 image = unloader(image) 49 plt.imshow(image) 50 if title is not None: 51 plt.title(title) 52 plt.pause(0.001) # pause a bit so that plots are updated 53 54 55 # plt.figure() 56 # imshow(style_img, title='Style Image') 57 # 58 # plt.figure() 59 # imshow(content_img, title='Content Image') 60 61 62 class ContentLoss(nn.Module): 63 64 def __init__(self, target,): 65 super(ContentLoss, self).__init__() 66 self.target = target.detach() 67 68 def forward(self, input): 69 self.loss = F.mse_loss(input, self.target) 70 return input 71 72 73 def gram_matrix(input): 74 a, b, c, d = input.size() # a=batch size(=1) 75 76 features = input.view(a * b, c * d) # resise F_XL into \hat F_XL 77 78 G = torch.mm(features, features.t()) # compute the gram product 79 80 return G.div(a * b * c * d) 81 82 83 class StyleLoss(nn.Module): 84 85 def __init__(self, target_feature): 86 super(StyleLoss, self).__init__() 87 self.target = gram_matrix(target_feature).detach() 88 89 def forward(self, input): 90 G = gram_matrix(input) 91 self.loss = F.mse_loss(G, self.target) 92 return input 93 94 95 cnn = models.vgg19(pretrained=True).features.to(device).eval() 96 97 98 cnn_normalization_mean = torch.tensor([0.485, 0.456, 0.406]).to(device) 99 cnn_normalization_std = torch.tensor([0.229, 0.224, 0.225]).to(device) 100 101 102 class 
Normalization(nn.Module): 103 def __init__(self, mean, std): 104 super(Normalization, self).__init__() 105 self.mean = mean.clone().detach().view(-1, 1, 1) 106 self.std = std.clone().detach().view(-1, 1, 1) 107 108 def forward(self, img): 109 # normalize img 110 return (img - self.mean) / self.std 111 112 113 content_layers_default = ['conv_4'] 114 style_layers_default = ['conv_1', 'conv_2', 'conv_3', 'conv_4', 'conv_5'] 115 116 117 def get_style_model_and_losses(cnn, normalization_mean, normalization_std, style_img, content_img, 118 content_layers=content_layers_default, style_layers=style_layers_default): 119 normalization = Normalization(normalization_mean, normalization_std).to(device) 120 121 content_losses = [] 122 style_losses = [] 123 124 model = nn.Sequential(normalization) 125 126 i = 0 # increment every time we see a conv 127 for layer in cnn.children(): 128 if isinstance(layer, nn.Conv2d): 129 i += 1 130 name = 'conv_{}'.format(i) 131 elif isinstance(layer, nn.ReLU): 132 name = 'relu_{}'.format(i) 133 layer = nn.ReLU(inplace=False) 134 elif isinstance(layer, nn.MaxPool2d): 135 name = 'pool_{}'.format(i) 136 elif isinstance(layer, nn.BatchNorm2d): 137 name = 'bn_{}'.format(i) 138 else: 139 raise RuntimeError('Unrecognized layer: {}'.format(layer.__class__.__name__)) 140 141 model.add_module(name, layer) 142 143 if name in content_layers: 144 # add content loss: 145 target = model(content_img).detach() 146 content_loss = ContentLoss(target) 147 model.add_module("content_loss_{}".format(i), content_loss) 148 content_losses.append(content_loss) 149 150 if name in style_layers: 151 # add style loss: 152 target_feature = model(style_img).detach() 153 style_loss = StyleLoss(target_feature) 154 model.add_module("style_loss_{}".format(i), style_loss) 155 style_losses.append(style_loss) 156 157 # now we trim off the layers after the last content and style losses 158 for i in range(len(model) - 1, -1, -1): 159 if isinstance(model[i], ContentLoss) or 
isinstance(model[i], StyleLoss): 160 break 161 162 model = model[:(i + 1)] 163 164 return model, style_losses, content_losses 165 166 167 input_img = content_img.clone() 168 169 # plt.figure() 170 # imshow(input_img, title='Input Image') 171 172 173 def get_input_optimizer(input_img): 174 optimizer = optim.LBFGS([input_img]) 175 return optimizer 176 177 178 def run_style_transfer(cnn, normalization_mean, normalization_std, 179 content_img, style_img, input_img, num_steps=num_steps, 180 style_weight=1000000, content_weight=1): 181 """Run the style transfer.""" 182 print('Building the style transfer model..') 183 model, style_losses, content_losses = get_style_model_and_losses(cnn, 184 normalization_mean, normalization_std, style_img, content_img) 185 186 # We want to optimize the input and not the model parameters so we 187 # update all the requires_grad fields accordingly 188 input_img.requires_grad_(True) 189 model.requires_grad_(False) 190 191 optimizer = get_input_optimizer(input_img) 192 193 print('Optimizing..') 194 run = [0] 195 while run[0] <= num_steps: 196 197 def closure(): 198 # correct the values of updated input image 199 with torch.no_grad(): 200 input_img.clamp_(0, 1) 201 202 optimizer.zero_grad() 203 model(input_img) 204 style_score = 0 205 content_score = 0 206 207 for sl in style_losses: 208 style_score += sl.loss 209 for cl in content_losses: 210 content_score += cl.loss 211 212 style_score *= style_weight 213 content_score *= content_weight 214 215 loss = style_score + content_score 216 loss.backward() 217 218 run[0] += 1 219 if run[0] % 50 == 0: 220 print("run {}:".format(run)) 221 print('Style Loss : {:4f} Content Loss: {:4f}'.format( 222 style_score.item(), content_score.item())) 223 print() 224 225 return style_score + content_score 226 227 optimizer.step(closure) 228 229 # a last correction... 
230 with torch.no_grad(): 231 input_img.clamp_(0, 1) 232 233 return input_img 234 235 236 begin_time = datetime.datetime.now() 237 print("******************開始時間*****************", begin_time) 238 output = run_style_transfer(cnn, cnn_normalization_mean, cnn_normalization_std, 239 content_img, style_img, input_img) 240 try: 241 plt.figure() 242 imshow(output, title='Output Image') 243 244 # sphinx_gallery_thumbnail_number = 4 245 plt.ioff() 246 plt.savefig(save_path) 247 except Exception as e: 248 print(e) 249 print("******************結束時間*****************", datetime.datetime.now()) 250 print("******************耗時*****************", datetime.datetime.now()-begin_time) 251 # plt.show()

dancing.jpg

picasso.jpg

我這遷移後的影象,還是不錯的。

風格:

內容:

遷移融合後:

有興趣的可以去研究一下原文:

原文地址:

https://pytorch.org/tutorials/advanced/neural_style_tutorial.html

原GitHub程式碼地址:

https://github.com/pytorch/tutorials/blob/master/advanced_source/neural_style_tutorial.py

需要準備:

有顯示卡並且支援pytorch訓練的伺服器,只有cpu的話就算了:GPU伺服器跑幾分鐘,cpu伺服器要跑一小時,而且cpu還一直100%!