更快更穩定:這就是Wasserstein GAN
這篇論文介紹了一種名叫 Wasserstein GAN(WGAN)的全新演算法,這是一種可替代標準生成對抗網路(GAN)的訓練方法。這項研究沒有應用傳統 GAN 所用的那種 minimax 形式,而是基於一種名為“Wasserstein 距離”的新型距離指標做了某些修改。
這是基於 MLP 生成器的 WGAN(左上圖)和 GAN(右上圖)生成的樣本,很顯然,這裡 WGAN 的影象質量優於標準 GAN。
簡單來說,WGAN 有兩個改變。第一個是取出了判別器中的 sigmoid,這是用於計算輸出均值之間的差異的。第二個改變是判別器(這篇論文稱之為 Critic),這就只是一個函式,其目標是讓假資料有較低的預期值,讓真實資料有較高的預期值。注意這些輸出不再是對數概率,這樣這些損失現在就與二元交叉熵無關了。
Wasserstein GAN
近期一些 GAN 論文提出了一些不同的生成對抗訓練架構。但是,這些架構的一個共同點是 f-距離(包括 KL-距離、總變差散度(total variation divergence))。f-距離是真實資料分佈和生成資料分佈之間的密度比 P_r(x)/P_θ(x) 的函式,非常類似於 Jenson-Shannon(JS)距離。
上式是標準 GAN 的目標。在 GAN 的訓練過程中,判別器的目標是最大化上述目標(最大值為 0,最小值為負無窮)。GAN 的估計可對應於 JS 距離度量。我們再看看 f-距離。如果兩個分佈沒有顯著的重疊,我們又能做什麼?如果不能,那麼其概率密度比將為零或無窮,而且其對整體概率估計(比如由 (0, z) 點組成的真實資料,其中 z ~ U (0,1))會有巨大的負面影響,於是樣本就會從 y=0 到 y=1 沿垂直軸 x=0 均勻分佈。但如果該模型生成樣本 (θ, z),則其分佈根本不會重疊。在這種情況下,會發生梯度消失問題,會使標準 GAN 崩潰。
所以基於這一事實,這篇論文的作者提出使用 Wasserstein 距離,而不是 JS 距離。Wasserstein 距離定義為:
我們可以這樣解讀這一等式:首先,所有可能的配置都會被選取,假設是 P_r(x) 和 P_g(x)。然後這些點會根據這兩個分佈來配對。在那之後,它會計算每組配置中配對的平均距離。這裡的 inf 可以被視為最小值,這樣最後它將從所有可能的配對配置中選擇出最小的平均距離。這篇論文提出使用這一距離度量來替代 f-距離,這樣它就不再是密度比的函式的。通過這種方式,即使兩個分佈沒有重疊,Wasserstein 距離也仍然可以描述它們相距多遠,並且通過這種方式能從根本上解決梯度消失問題。
由於初始的 Wasserstein 距離定義具有難以解決的計算複雜性,所以研究者使用了一種替代定義:
這會導致 Kantorovich-Rubinstein二元性。
值得注意的是,當且僅當 f(x) 的梯度的幅度由 K 在該空間的所有部分設定了上界時,f(x) 是 K-Lipschitz。這篇論文通過將權重限制在一定範圍內,使用網路來近似建模 K-Lipschitz。這裡的上界可以被視為是一個最大值(二元表示式)。理論上,其目標是尋找到一個 critic 函式,以最大化真實樣本均值和偽造樣本均值之間餘量。
WGAN 演算法
上面描述了 Wasserstein 生成對抗網路(WGAN)演算法。經過前面的知識介紹之後,這個演算法看起來就更簡單一些了。總結如下:
- 更新 Critic n 次迭代,之後更新生成器;
- 對於 Critic 的每次迭代,基於 Wasserstein 距離更新梯度,然後剪下權重;
- 使用 RMSProp;
- 像普通 GAN 那樣更新生成器。
下面給出了實現 WGAN 演算法的程式碼示例:
# (1) update Critic Network
for p in netD.parameters():
p.requires_grad = True
netD.zero_grad()
# train with real
real_cpu, _ = data
netD.zero_grad()
batch_size = real_cpu.size(0)
input.data.resize_(real_cpu.size()).copy_(real_cpu)
errD_real = netD(input)
errD_real.backward(one)
# train with fake
noise.data.resize_(batch_size, nz, 1, 1)
noise.data.normal_(0, 1)
fake = netG(noise)
input.data.copy_(fake.data)
errD_fake = netD(input)
errD_fake.backward(mone)
errD = errD_real - errD_fake
optimizerD.step()
# (2) Update G network
for p in netD.parameters():
p.requires_grad = False # to avoid computation
netG.zero_grad()
noise.data.resize_(opt.batchSize, nz, 1, 1)
noise.data.normal_(0, 1)
fake = netG(noise)
errG = netD(fake)
errG.backward(one)
optimizerG.step()
實證實驗
研究者使用 Wasserstein GAN 進行了一些定量實驗,並且表明相比於標準 GAN,使用 WGAN有顯著的實際好處。
他們提到了兩個優勢:
- WGAN 的損失表現出了收斂的特性。
如上所示,上圖為 WGAN,下圖為標準 GAN。對於 WGAN,隨著損失快速下降,樣本質量也會增長。相比於 WGAN,標準 GAN 演算法的誤差曲線是不穩定的,甚至會增大。
- 優化過程的穩定性提升。
上圖是使用無批歸一化的該演算法得到的生成器的結果。左上基於 WGAN 演算法,右上基於標準 GAN 演算法。標準 GAN 不能學習的地方,WGAN 依然能穩定地生成合理的樣本。
分析師簡評
這篇論文提出了一種名為 Wasserstein GAN 的新型生成對抗網路。它從理論上向我們說明了已有的 GAN 模型失敗的原因以及 WGAN 有效的原因。相比於 DCGAN 等標準 GAN,這篇論文表明即使沒有批歸一化,WGAN 也能穩定地訓練。但也仍然存在一些值得關注的地方。首先,在更新生成器之前他們更新了 critic n 次迭代,這意味著 critic 的迭代次數仍是人工調節的。是否存在優化兩者的更好方法呢?第二,WGAN 在非常深度的網路上的泛化情況如何,比如 152 層的殘差網路?第三,他們限制了權重的範圍以確保 Lipschitz 連續性,但是否存在建模這種情況的方法?最後,生成對抗訓練能否用於詞預測等 NLP 任務,同時還能保持穩定性?