1. 程式人生 > 其它 >風格遷移訓練實踐

風格遷移訓練實踐

前一篇文章分享了 PyTorch 簡單風格遷移的程式碼。本著「不跑掛伺服器不死心」的態度，這次不停地增加計算步驟，看看圖片融合生成的效果。

為了方便一次性執行，把程式碼簡單改造了一下，與前一篇文章大同小異：

  1 import torch
  2 import torch.nn as nn
  3 import torch.nn.functional as F
  4 import torch.optim as optim
  5 
  6 from PIL import Image
  7 import matplotlib.pyplot as plt
  8 
  9 import torchvision.transforms as transforms
import torchvision.models as models
import datetime

# Run on the GPU when one is available; otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def get_img_size(img_name):
    """Open an image file and report its dimensions.

    :param img_name: path of the image file to open
    :return: tuple ``(image, height, width)`` where ``image`` is the picture
        converted to RGB mode
    """
    # convert() returns a loaded copy, so the underlying file can be released
    # as soon as the ``with`` block exits.
    with Image.open(img_name) as src:
        im = src.convert('RGB')
    return im, im.height, im.width


def image_loader(img, im_h, im_w):
    """Convert a PIL image into a batched float tensor on ``device``.

    :param img: PIL image to convert
    :param im_h: image height; accepted for interface compatibility but
        currently unused because the resize target is fixed
    :param im_w: image width; accepted for interface compatibility but
        currently unused because the resize target is fixed
    :return: float tensor with a leading batch dimension of size 1
    """
    # Every image is resized to the same fixed 1000x1000 target, which also
    # guarantees that content and style tensors end up with identical shapes.
    loader = transforms.Compose([transforms.Resize([1000, 1000]), transforms.ToTensor()])
    im_l = loader(img).unsqueeze(0)
    return im_l.to(device, torch.float)


def im_show(tensor, save_file_path):
    """Display an image tensor and save it to disk.

    :param tensor: image tensor with a leading batch dimension of size 1
    :param save_file_path: destination path handed to ``plt.savefig``
    :return: None
    """
    # Work on a detached CPU copy so the caller's tensor is left untouched.
    image = tensor.detach().cpu().clone()
    image = image.squeeze(0)  # drop the batch dimension
    image = transforms.ToPILImage()(image)
    # Start from an empty figure: this function is called once per run, and
    # without clearing, every call would stack one more image on the same axes.
    plt.clf()
    plt.imshow(image, aspect='equal')
    plt.axis('off')
    plt.savefig(save_file_path, bbox_inches='tight', pad_inches=0.0)
    plt.pause(0.001)


class ContentLoss(nn.Module):
    """Pass-through layer that records the content loss of its input.

    The layer returns its input unchanged; the MSE between the input and the
    stored target is kept in ``self.loss`` after each forward pass.
    """

    def __init__(self, target):
        super(ContentLoss, self).__init__()
        # Detach so the target is treated as a constant, not as part of the
        # graph being differentiated.
        self.target = target.detach()

    def forward(self, cl_input):
        self.loss = F.mse_loss(cl_input, self.target)
        return cl_input


def gram_matrix(gm_input):
    """Compute the normalised Gram matrix of a 4-D feature tensor.

    :param gm_input: tensor whose ``size()`` unpacks into four dimensions
        ``(a, b, c, d)``
    :return: ``(a * b) x (a * b)`` Gram matrix divided by ``a * b * c * d``
    """
    a, b, c, d = gm_input.size()
    features = gm_input.view(a * b, c * d)
    G = torch.mm(features, features.t())

    # Normalise by the total number of elements in the feature tensor.
    return G.div(a * b * c * d)


class StyleLoss(nn.Module):
    """Pass-through layer that records the style loss of its input.

    The layer returns its input unchanged; the MSE between the Gram matrix of
    the input and the stored target Gram matrix is kept in ``self.loss``.
    """

    def __init__(self, target_feature):
        super(StyleLoss, self).__init__()
        self.target = gram_matrix(target_feature).detach()

    def forward(self, fw_input):
        G = gram_matrix(fw_input)
        self.loss = F.mse_loss(G, self.target)
        return fw_input


# Feature extractor of the 19-layer VGG network, in evaluation mode.
# NOTE(review): recent torchvision releases deprecate ``pretrained=True`` in
# favour of the ``weights=`` argument; kept as-is for compatibility with the
# torchvision version this script was written against.
cnn = models.vgg19(pretrained=True).features.to(device).eval()


# Per-channel mean / std used to normalise the input image before the network.
cnn_normalization_mean = torch.tensor([0.485, 0.456, 0.406]).to(device)
cnn_normalization_std = torch.tensor([0.229, 0.224, 0.225]).to(device)


class Normalization(nn.Module):
    """Normalise an input image with a fixed per-channel mean and std."""

    def __init__(self, mean, std):
        super(Normalization, self).__init__()
        # Reshape to (-1, 1, 1) so the statistics broadcast over the two
        # trailing dimensions of the image tensor.
        self.mean = mean.clone().detach().view(-1, 1, 1)
        self.std = std.clone().detach().view(-1, 1, 1)

    def forward(self, img):
        return (img - self.mean) / self.std


def get_style_model_and_losses(cn, normalization_mean, normalization_std, style_i, content_i, cld, sld):
    """Build the loss-instrumented model used for style transfer.

    The layers of ``cn`` are copied one by one into a new ``nn.Sequential``;
    after each layer named in ``cld`` / ``sld`` a content / style loss layer
    is inserted.

    :param cn: network whose children are copied into the new model
    :param normalization_mean: per-channel mean for the normalisation layer
    :param normalization_std: per-channel std for the normalisation layer
    :param style_i: style image tensor
    :param content_i: content image tensor
    :param cld: names of the layers after which a content loss is inserted
    :param sld: names of the layers after which a style loss is inserted
    :return: tuple ``(model, style_losses, content_losses)``
    :raises RuntimeError: if ``cn`` contains an unsupported layer type
    """
    normalization = Normalization(normalization_mean, normalization_std).to(device)
    content_losses = []
    style_losses = []

    model = nn.Sequential(normalization)

    i = 0  # incremented every time a convolution layer is seen
    for layer in cn.children():
        if isinstance(layer, nn.Conv2d):
            i += 1
            name = 'conv_{}'.format(i)
        elif isinstance(layer, nn.ReLU):
            name = 'relu_{}'.format(i)
            # Use an out-of-place ReLU so the activations consumed by the
            # inserted loss layers are not overwritten.
            layer = nn.ReLU(inplace=False)
        elif isinstance(layer, nn.MaxPool2d):
            name = 'pool_{}'.format(i)
        elif isinstance(layer, nn.BatchNorm2d):
            name = 'bn_{}'.format(i)
        else:
            raise RuntimeError('Unrecognized layer: {}'.format(layer.__class__.__name__))

        model.add_module(name, layer)

        if name in cld:
            # The target is the content image's activation at this depth.
            target = model(content_i).detach()
            content_loss = ContentLoss(target)
            model.add_module("content_loss_{}".format(i), content_loss)
            content_losses.append(content_loss)

        if name in sld:
            # The target is the style image's activation at this depth.
            target_feature = model(style_i).detach()
            style_loss = StyleLoss(target_feature)
            model.add_module("style_loss_{}".format(i), style_loss)
            style_losses.append(style_loss)

    # Walk backwards to the last loss layer and drop everything after it:
    # later layers do not contribute to any loss.
    last = 0
    for last in range(len(model) - 1, -1, -1):
        if isinstance(model[last], (ContentLoss, StyleLoss)):
            break

    model = model[:(last + 1)]

    return model, style_losses, content_losses


def get_input_optimizer(input_i):
    """Create the L-BFGS optimizer that updates the input image.

    :param input_i: image tensor to optimise
    :return: ``optim.LBFGS`` instance whose only parameter is ``input_i``
    """
    return optim.LBFGS([input_i])


def run_style_transfer(cn, norma_mean, normalization_std, ct_img, sl_img, in_img, steps, style_weight, content_weight):
    """Optimise ``in_img`` against the weighted style and content losses.

    :param cn: network passed to ``get_style_model_and_losses``
    :param norma_mean: per-channel normalisation mean
    :param normalization_std: per-channel normalisation std
    :param ct_img: content image tensor
    :param sl_img: style image tensor
    :param in_img: image tensor to optimise; modified in place
    :param steps: the loop keeps calling the optimizer while the number of
        closure evaluations is ``<= steps``
    :param style_weight: multiplier applied to the summed style losses
    :param content_weight: multiplier applied to the summed content losses
    :return: ``in_img`` after optimisation, clamped to ``[0, 1]``
    """
    content_layers = ['conv_4']
    style_layers = ['conv_1', 'conv_2', 'conv_3', 'conv_4', 'conv_5']
    model, style_losses, content_losses = get_style_model_and_losses(
        cn, norma_mean, normalization_std, sl_img, ct_img, content_layers, style_layers)
    # Only the input image is optimised; the network weights stay frozen.
    in_img.requires_grad_(True)
    model.requires_grad_(False)

    optimizer = get_input_optimizer(in_img)
    print('Optimizing..')
    # One-element list so the closure can update the counter in place.
    run = [0]

    def closure():
        # Keep the image inside the valid [0, 1] range before evaluating it.
        with torch.no_grad():
            in_img.clamp_(0, 1)

        optimizer.zero_grad()
        # The forward pass populates ``.loss`` on every loss layer.
        model(in_img)
        style_score = 0
        content_score = 0

        for sl in style_losses:
            style_score += sl.loss
        for cl in content_losses:
            content_score += cl.loss

        style_score *= style_weight
        content_score *= content_weight

        loss = style_score + content_score
        loss.backward()

        run[0] += 1
        if run[0] % 50 == 0:
            # Print the counter value itself, not the list that wraps it.
            print("run {}:".format(run[0]))
            print('Style Loss : {:4f} Content Loss: {:4f}'.format(style_score.item(), content_score.item()))
        return loss

    # The optimizer calls the closure itself, so the counter is advanced
    # inside ``closure`` rather than in this loop.
    while run[0] <= steps:
        optimizer.step(closure)

    # Final clamp so the returned image is inside [0, 1].
    with torch.no_grad():
        in_img.clamp_(0, 1)
    return in_img


def style_transfer(content_image_path, style_image_path, image_save_path, run_steps,
                   style_weight=1000000, content_weight=1):
    """Run one complete style transfer and save the resulting image.

    :param content_image_path: path of the content image
    :param style_image_path: path of the style image
    :param image_save_path: path the generated image is saved to
    :param run_steps: step budget passed to ``run_style_transfer``
    :param style_weight: multiplier for the style loss
    :param content_weight: multiplier for the content loss
    :return: None
    """
    c_image, c_im_h, c_im_w = get_img_size(content_image_path)
    s_image, _, _ = get_img_size(style_image_path)
    content_img = image_loader(c_image, c_im_h, c_im_w)
    # The style image is loaded with the content image's dimensions so that
    # both tensors are requested at the same size.
    style_img = image_loader(s_image, c_im_h, c_im_w)
    assert style_img.size() == content_img.size()
    # Start the optimisation from a copy of the content image.
    input_img = content_img.clone()
    begin_time = datetime.datetime.now()
    print("******************開始時間*****************", begin_time)
    output = run_style_transfer(cnn, cnn_normalization_mean, cnn_normalization_std, content_img, style_img,
                                input_img, run_steps, style_weight, content_weight)
    # Best effort: a failure to display or save must not abort the caller's
    # remaining runs, so the error is reported and execution continues.
    try:
        im_show(output, image_save_path)
    except Exception as e:
        print(e)
    end_time = datetime.datetime.now()
    print("******************結束時間*****************", end_time)
    print("******************耗時*****************", end_time - begin_time)


if __name__ == '__main__':
    s_weight = 1000000
    c_weight = 1
    content_img_path = "/data/drew/img/dancing.jpg"
    style_img_path = "/data/drew/img/picasso.jpg"
    # Repeat the transfer from scratch with a growing step budget to compare
    # the generated images.
    for steps in range(100, 3200, 200):
        save_path = "/data/drew/img/end_%s_%s.jpg" % (steps, datetime.datetime.now().strftime("%Y%m%d%H%M%S"))
        style_transfer(content_img_path, style_img_path, save_path, steps,
                       style_weight=s_weight, content_weight=c_weight)
View Code