Masking GAN
阿新 • Published: 2018-11-06
GitHub code: https://github.com/tgeorgy/mgan
Contributions of the paper:
1. The generator takes an input x and produces both a segmentation mask and an intermediate image y. The mask is used to blend x with y, so the final generated image keeps the background of x while the foreground is the generated content.
2. Training is end to end: on top of the CycleGAN loss, an extra constraint is placed on the generated output image.
The model structure is as follows. The generator first outputs the segmentation mask and the intermediate image y; blending y with the mask gives the final generated image. The generator code is shown below:
import torch.nn as nn
import torch.nn.functional as F  # used by forward below

class Generator(nn.Module):
    def __init__(self, input_nc=3, output_nc=4, ngf=64, norm_layer=nn.BatchNorm2d,
                 use_dropout=False, n_blocks=6):
        assert(n_blocks >= 0)
        super(Generator, self).__init__()
        self.input_nc = input_nc
        self.output_nc = output_nc
        self.ngf = ngf

        model = [nn.ReflectionPad2d(3),
                 nn.Conv2d(input_nc, ngf, kernel_size=7, padding=0),
                 norm_layer(ngf),
                 nn.ReLU(True)]

        # Downsampling: two stride-2 convolutions.
        n_downsampling = 2
        for i in range(n_downsampling):
            mult = 2**i
            model += [nn.Conv2d(ngf * mult, ngf * mult * 2, kernel_size=3,
                                stride=2, padding=1),
                      norm_layer(ngf * mult * 2),
                      nn.ReLU(True)]

        # Residual blocks at the bottleneck (ResnetBlock is defined elsewhere in the repo).
        mult = 2**n_downsampling
        for i in range(n_blocks):
            model += [ResnetBlock(ngf * mult, norm_layer=norm_layer,
                                  use_dropout=use_dropout)]

        # Upsampling: 1x1 convolution to 4x channels, then PixelShuffle(2).
        for i in range(n_downsampling):
            mult = 2**(n_downsampling - i)
            model += [nn.ReflectionPad2d(1),
                      nn.Conv2d(ngf * mult, int(ngf * mult / 2),
                                kernel_size=3, stride=1),
                      norm_layer(int(ngf * mult / 2)),
                      nn.ReLU(True),
                      nn.Conv2d(int(ngf * mult / 2), int(ngf * mult / 2)*4,
                                kernel_size=1, stride=1),
                      nn.PixelShuffle(2),
                      norm_layer(int(ngf * mult / 2)),
                      nn.ReLU(True),
                      ]

        model += [nn.ReflectionPad2d(3)]
        model += [nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0)]
        self.model = nn.Sequential(*model)
In the code, the generator input has 3 channels and its output has 4: the first channel is the mask, and the remaining three channels are the intermediate generated image.
    def forward(self, input):
        output = self.model(input)
        mask = F.sigmoid(output[:, :1])    # channel 0 -> mask in (0, 1)
        oimg = output[:, 1:]               # channels 1-3 -> intermediate image y
        mask = mask.repeat(1, 3, 1, 1)     # broadcast mask to 3 channels
        oimg = oimg*mask + input*(1-mask)  # foreground generated, background from input
        return oimg, mask
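As a standalone check, the mask blending in `forward` can be reproduced with random tensors (the batch size and image size below are assumptions chosen just for illustration):

```python
import torch

x = torch.rand(2, 3, 64, 64)      # network input (2 RGB images, 64x64)
raw = torch.randn(2, 4, 64, 64)   # 4-channel generator output

mask = torch.sigmoid(raw[:, :1])  # channel 0 -> mask in (0, 1)
oimg = raw[:, 1:]                 # channels 1-3 -> intermediate image y
mask = mask.repeat(1, 3, 1, 1)    # broadcast mask to 3 channels
blended = oimg * mask + x * (1 - mask)

# Where the mask is near 0 the output equals the input, so the
# background of x is preserved pixel-for-pixel.
print(blended.shape)  # torch.Size([2, 3, 64, 64])
```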
The model uses the CycleGAN structure, i.e. two generators and two discriminators.
For each generator the loss has three parts. The first, loss_P2N_cyc, is the same cycle-consistency loss as in CycleGAN: feeding the output of generator g1 into generator g2 should reconstruct the original input. The second, loss_P2N_gan, is the GAN loss: the discriminator should judge the generated image as real. The third, loss_N2P_idnt, is an identity loss that pushes the generator's output to stay close to its input; this is the extra end-to-end constraint, which plain CycleGAN (not being end to end) does not have:
from torch.autograd import Variable  # legacy PyTorch 0.x autograd wrapper

criterion_cycle = nn.L1Loss()
criterion_identity = nn.L1Loss()
criterion_gan = nn.MSELoss()  # least-squares GAN loss

# Train P2N generator
real_pos_v = Variable(real_pos)
fake_neg, mask_neg = netP2N(real_pos_v)
rec_pos, _ = netN2P(fake_neg)
fake_neg_lbl = netDN(fake_neg)
loss_P2N_cyc = criterion_cycle(rec_pos, real_pos_v)             # cycle consistency
loss_P2N_gan = criterion_gan(fake_neg_lbl, Variable(real_lbl))  # fool the discriminator
loss_N2P_idnt = criterion_identity(fake_neg, real_pos_v)        # stay close to the input

# Train N2P generator
real_neg_v = Variable(real_neg)
fake_pos, mask_pos = netN2P(real_neg_v)
rec_neg, _ = netP2N(fake_pos)
fake_pos_lbl = netDP(fake_pos)
loss_N2P_cyc = criterion_cycle(rec_neg, real_neg_v)
loss_N2P_gan = criterion_gan(fake_pos_lbl, Variable(real_lbl))
loss_P2N_idnt = criterion_identity(fake_pos, real_neg_v)

loss_G = ((loss_P2N_gan + loss_N2P_gan)*0.5 +
          (loss_P2N_cyc + loss_N2P_cyc)*lambda_cycle +
          (loss_P2N_idnt + loss_N2P_idnt)*lambda_identity)
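The snippet stops once loss_G is assembled. A minimal sketch of the update step that would follow, using toy stand-in generators and a toy loss (the optimizer choice, learning rate, and loss weight here are assumptions, not taken from the repo):

```python
import itertools
import torch
import torch.nn as nn

# Hypothetical stand-ins for the two generators.
netP2N = nn.Conv2d(3, 3, 1)
netN2P = nn.Conv2d(3, 3, 1)

# As in CycleGAN, both generators share one optimizer (assumed hyperparameters).
optimizer_G = torch.optim.Adam(
    itertools.chain(netP2N.parameters(), netN2P.parameters()),
    lr=2e-4, betas=(0.5, 0.999))

lambda_cycle = 10.0  # assumed weight
x = torch.rand(1, 3, 8, 8)
criterion_cycle = nn.L1Loss()
rec = netN2P(netP2N(x))                        # x -> fake -> reconstruction
loss_G = criterion_cycle(rec, x) * lambda_cycle  # toy cycle loss for the sketch

optimizer_G.zero_grad()
loss_G.backward()
optimizer_G.step()
```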
The discriminators judge whether their input is real or generated:
# Train discriminators
netDN.zero_grad()
netDP.zero_grad()
fake_neg_score = netDN(fake_neg.detach())  # detach: no gradient flows into the generator
loss_D = criterion_gan(fake_neg_score, Variable(fake_lbl))
fake_pos_score = netDP(fake_pos.detach())
loss_D += criterion_gan(fake_pos_score, Variable(fake_lbl))
real_neg_score = netDN(real_neg_v)
loss_D += criterion_gan(real_neg_score, Variable(real_lbl))
real_pos_score = netDP(real_pos_v)
loss_D += criterion_gan(real_pos_score, Variable(real_lbl))
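The least-squares GAN objective used above (MSELoss against real/fake labels) can be demonstrated end to end with a toy discriminator; the network, optimizer, and input sizes below are assumptions for illustration only:

```python
import torch
import torch.nn as nn

# Toy one-layer discriminator standing in for netDN / netDP.
netD = nn.Conv2d(3, 1, 4, stride=2, padding=1)
criterion_gan = nn.MSELoss()
optimizer_D = torch.optim.Adam(netD.parameters(), lr=2e-4)  # assumed optimizer

real = torch.rand(1, 3, 8, 8)
fake = torch.rand(1, 3, 8, 8)  # would come from the generator, detached

real_score = netD(real)
fake_score = netD(fake)
# LSGAN targets: 1 for real, 0 for fake.
loss_D = (criterion_gan(real_score, torch.ones_like(real_score)) +
          criterion_gan(fake_score, torch.zeros_like(fake_score)))

optimizer_D.zero_grad()
loss_D.backward()
optimizer_D.step()
```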