
Masking GAN

GitHub code: https://github.com/tgeorgy/mgan

The paper's main contributions:

1. The generator takes an input x and outputs both a segmentation mask and an intermediate image y. The mask is used to blend the input x with the intermediate image y to form the generated image. As a result, the generated image keeps the same background as x, while its foreground is the generated part.

2. Training is end-to-end: on top of the CycleGAN loss functions, an extra constraint on the output generated image is added.

The model structure is as follows:

[model architecture diagram]

The generator first outputs a segmentation mask and an intermediate image y; blending y with the mask produces the final generated image. The generator code is as follows:

class Generator(nn.Module):
    def __init__(self, input_nc=3, output_nc=4, ngf=64,
                 norm_layer=nn.BatchNorm2d, use_dropout=False, n_blocks=6):
        assert(n_blocks >= 0)
        super(Generator, self).__init__()
        self.input_nc = input_nc
        self.output_nc = output_nc
        self.ngf = ngf
        model = [nn.ReflectionPad2d(3),
                 nn.Conv2d(input_nc, ngf, kernel_size=7, padding=0),
                 norm_layer(ngf),
                 nn.ReLU(True)]
        n_downsampling = 2
        for i in range(n_downsampling):
            mult = 2**i
            model += [nn.Conv2d(ngf * mult, ngf * mult * 2,
                                kernel_size=3, stride=2, padding=1),
                      norm_layer(ngf * mult * 2),
                      nn.ReLU(True)]
        mult = 2**n_downsampling
        for i in range(n_blocks):
            model += [ResnetBlock(ngf * mult, norm_layer=norm_layer,
                                  use_dropout=use_dropout)]
        for i in range(n_downsampling):
            mult = 2**(n_downsampling - i)
            model += [nn.ReflectionPad2d(1),
                      nn.Conv2d(ngf * mult, int(ngf * mult / 2),
                                kernel_size=3, stride=1),
                      norm_layer(int(ngf * mult / 2)),
                      nn.ReLU(True),
                      nn.Conv2d(int(ngf * mult / 2), int(ngf * mult / 2) * 4,
                                kernel_size=1, stride=1),
                      nn.PixelShuffle(2),
                      norm_layer(int(ngf * mult / 2)),
                      nn.ReLU(True)]
        model += [nn.ReflectionPad2d(3)]
        model += [nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0)]
        self.model = nn.Sequential(*model)

In this code, the generator's input has 3 channels and its output has 4: the first output channel is the mask, and the other three channels are the intermediate generated image.
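To make the channel layout concrete, here is a tiny NumPy sketch (the array names and sizes are my own, not the repo's) of how a 4-channel output splits into mask and image:

```python
import numpy as np

# Stand-in for the generator's raw 4-channel output:
# batch of 2, 4 channels, 8x8 spatial resolution.
output = np.random.randn(2, 4, 8, 8)

mask = output[:, :1]   # channel 0 is the mask, shape (2, 1, 8, 8)
oimg = output[:, 1:]   # channels 1..3 are the image, shape (2, 3, 8, 8)
```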

def forward(self, input):
    output = self.model(input)
    mask = F.sigmoid(output[:, :1])   # channel 0: soft segmentation mask in [0, 1]
    oimg = output[:, 1:]              # channels 1-3: intermediate image
    mask = mask.repeat(1, 3, 1, 1)    # broadcast the mask to all 3 color channels
    # foreground comes from the generated image, background is copied from the input
    oimg = oimg*mask + input*(1-mask)

    return oimg, mask
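The blending step has a property worth verifying: wherever the mask is 0 the output equals the input exactly, so the background is copied through untouched, and wherever it is 1 the output is the generated image. A small NumPy sketch of the same arithmetic (my own toy data, not the repo's code):

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def blend(x, y, mask_logits):
    # mask in [0, 1]; repeat over the 3 color channels like mask.repeat(1, 3, 1, 1)
    mask = np.repeat(sigmoid(mask_logits), 3, axis=1)
    return y * mask + x * (1 - mask)

x = np.ones((1, 3, 4, 4))    # "input photo"
y = np.zeros((1, 3, 4, 4))   # "intermediate generated image"
bg = blend(x, y, np.full((1, 1, 4, 4), -100.0))  # mask ~ 0: output equals input
fg = blend(x, y, np.full((1, 1, 4, 4), 100.0))   # mask ~ 1: output equals y
```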

The model follows the CycleGAN structure, i.e., it contains two generators and two discriminators.

For each generator, the loss function has three parts. The first, loss_P2N_cyc, is the same as the CycleGAN cycle-consistency loss: the input is passed through generator g1, its output is passed through generator g2, and the reconstruction should match the original input as closely as possible. The second, loss_P2N_gan, is the GAN loss: the discriminator should classify the generated image as real. The third, loss_N2P_idnt, requires generator g1's output to stay as close as possible to the label (in this snippet, the input image itself). This term is possible because the paper trains end-to-end on input-label pairs; CycleGAN is not end-to-end, so it has no such loss.

criterion_cycle = nn.L1Loss()
criterion_identity = nn.L1Loss()
criterion_gan = nn.MSELoss()
# Train P2N Generator
real_pos_v = Variable(real_pos)
fake_neg, mask_neg = netP2N(real_pos_v)
rec_pos, _ = netN2P(fake_neg)
fake_neg_lbl = netDN(fake_neg)

loss_P2N_cyc = criterion_cycle(rec_pos, real_pos_v)
loss_P2N_gan = criterion_gan(fake_neg_lbl, Variable(real_lbl))
loss_N2P_idnt = criterion_identity(fake_neg, real_pos_v)
# Train N2P Generator
real_neg_v = Variable(real_neg)
fake_pos, mask_pos = netN2P(real_neg_v)
rec_neg, _ = netP2N(fake_pos)
fake_pos_lbl = netDP(fake_pos)

loss_N2P_cyc = criterion_cycle(rec_neg, real_neg_v)
loss_N2P_gan = criterion_gan(fake_pos_lbl, Variable(real_lbl))
loss_P2N_idnt = criterion_identity(fake_pos, real_neg_v)

loss_G = ((loss_P2N_gan + loss_N2P_gan)*0.5 +
          (loss_P2N_cyc + loss_N2P_cyc)*lambda_cycle +
          (loss_P2N_idnt + loss_N2P_idnt)*lambda_identity)
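As a sanity check on the loss arithmetic, here is a minimal NumPy version of the three criteria; the lambda values below are placeholders of my own, not the repo's settings:

```python
import numpy as np

def l1(a, b):   # criterion_cycle / criterion_identity (nn.L1Loss)
    return np.abs(a - b).mean()

def mse(a, b):  # criterion_gan: least-squares (LSGAN-style) loss (nn.MSELoss)
    return ((a - b) ** 2).mean()

a = np.array([1.0, 3.0])
b = np.array([2.0, 5.0])
l1_val = l1(a, b)    # (|1-2| + |3-5|) / 2 = 1.5
mse_val = mse(a, b)  # ((1-2)**2 + (3-5)**2) / 2 = 2.5

# combined objective for one direction, mirroring loss_G above
lambda_cycle, lambda_identity = 10.0, 5.0  # placeholder weights
loss_G = mse_val * 0.5 + l1_val * lambda_cycle + l1_val * lambda_identity
```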

The discriminators judge whether an input image is real or fake:

# Train Discriminators
netDN.zero_grad()
netDP.zero_grad()
fake_neg_score = netDN(fake_neg.detach())
loss_D = criterion_gan(fake_neg_score, Variable(fake_lbl))
fake_pos_score = netDP(fake_pos.detach())
loss_D += criterion_gan(fake_pos_score, Variable(fake_lbl))

real_neg_score = netDN(real_neg_v)
loss_D += criterion_gan(real_neg_score, Variable(real_lbl))
real_pos_score = netDP(real_pos_v)
loss_D += criterion_gan(real_pos_score, Variable(real_lbl))
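Numerically, this least-squares objective pushes the discriminator's scores on real images toward 1 and its scores on (detached) generated images toward 0. A toy example with made-up scores:

```python
import numpy as np

def mse(a, b):  # same least-squares criterion as criterion_gan
    return ((a - b) ** 2).mean()

fake_score = np.array([0.2, 0.1])          # D's output on generated images
real_score = np.array([0.9, 0.8])          # D's output on real images
loss_D = (mse(fake_score, np.zeros(2)) +   # fakes should score 0
          mse(real_score, np.ones(2)))     # reals should score 1
```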