1. 程式人生 > >收斂速度更快更穩定的Wasserstein GAN(WGAN)

收斂速度更快更穩定的Wasserstein GAN(WGAN)

生成對抗網路(GANs)是一種很有力的生成模型,它解決生成建模問題的方式就像在兩個對抗式網路中進行比賽:給出一些噪聲源,生成器網路能夠產生合成的資料,鑑別器網路在真實資料和生成器的輸出中進行鑑別。GAN能夠產生十分生動的樣例,但是很難訓練。儘管最近大量的研究工作都投入到了尋找能讓GAN穩定訓練的方法上,但GAN的持續穩定訓練成了依然是一個公開的問題。
概要最近提出的Wasserstein GAN(WGAN)在GAN的穩定訓練上取得了重大進展,但是依然會產生低質量的樣例或者出現在某些設定上不能收斂的情況。會產生這些訓練失敗的樣例一般都是因為,作者通過在WGAN中使用權重修剪來達到在鑑別器中強制利普希茨(Lipschitz)限制條件的目的,但這樣的方式會導致病態行為。
根據上述情況,文章提出了一種替代方法來強制Lipschitz限制條件:不修剪權重,而是根據輸入來懲罰鑑別器的梯度正則項。這一方法方法與權重修剪WGAN相比,收斂的更快並且能產生更高質量的樣例。這一替代方法基本能夠保證很穩定的GAN訓練。這是自深度學習發展以來,第一次可以訓練多種多樣的GAN結構,並且幾乎不用進行超引數的調整,包括101層的ResNet和離散資料上的語言模型。
本文將從提出背景,演算法介紹,實驗結果,論文實現等四個方面來介紹這篇文章。
提出背景作者發現,WGAN的權重修剪會導致優化困難,並且即使能夠優化成功,鑑別器也會出現病態等值面。作者測試了了WGAN中的權重限制條件(對每個權重的量級進行硬修剪),也測試了其他的權重限制條件(L2範數修剪,權重正則化等等),以及一些軟限制條件(L1和L2權重衰減等等),實驗結果發現他們都會出現類似的問題。
容量使用不當
作者在玩具資料庫上使用權重修剪來訓練WGAN鑑別器使其達到最優,保持生成器的分佈固定,以及單元變數高斯噪聲,對鑑別器的等值面曲線圖進行了繪製。我們省略了鑑別器中的批正則化(batch normalization)。在每個樣例中,作者發現,用權重修剪訓練的鑑別器忽略了資料分佈的高階矩,而是對很簡單的近似建模來進行優化函式。相比而言,梯度懲罰的方法不會因為這樣的行為受到影響。


梯度消失和爆炸作者發現,WGAN的優化過程之所以很難,是由於權重修剪和損失函式之間的互動。這一情況會不可避免的導致梯度消失或梯度爆炸,取決於修剪的閾值。如果權重由於限制變得太小,梯度就會在反向傳播到之前的層時消失。這會阻止鑑別器(和生成器)之前的層接受有用的訓練訊號,並且會使得深度網路學習速度變得很慢。


(a)深度WGAN的鑑別器的梯度正則項在玩具資料庫上訓練時的變化情況。用權重修剪的WGAN中的梯度總是爆炸或消失,而作者提出的梯度懲罰方法則為之前層提供了穩定梯度。
(b)分別使用權重修剪(左)和使用梯度懲罰(右)的WGAN的權重直方圖。權重修剪將權重推到了修剪範圍的極限,當這個範圍很高時,就會導致梯度爆炸,然後減慢訓練速度。
演算法——梯度懲罰由於WGAN中的權重修剪帶來的不好的結果,作者提出了一種替代方法,在訓練目標上加強Lipschitz限制條件:當且僅當一個可微函式梯度的正則項處處小於等於1時,它才滿足1-Lipschitz條件。所以作者直接根據輸入來限制鑑別器的梯度正則項,也就是梯度懲罰。
分以下幾個步驟實現:
·根據直線取樣。
·超引數:梯度懲罰引入了一個引數λ,實驗中設為10。
·去掉鑑別器的批正則化。
·使用Adam引數設定。
·雙面懲罰。
·二次懲罰。
實驗結果與權重修剪WGAN相比,梯度懲罰WGAN不僅提高了收斂速度,還提升了網路訓練的穩定性。
CIFAR-10 訓練速度和樣例質量


四個模型在CIFAR-10資料集上的測試分數與生成器迭代次數(左)和系統時間(右)的曲線圖。四個模型分別為:權重修剪WGAN,梯度懲罰以及RMSProp(控制優化器)WGAN,梯度懲罰以及Adam引數設定WGAN和DCGAN。從圖中可以看出,即使在同樣的學習速率下,梯度懲罰方法的表現比權重修剪有顯著的提高。DCGAN收斂的更快,但是使用梯度懲罰的WGAN達到相似的分數時,穩定性提高了。
LSUN臥室資料集作者在LSUN臥室資料集上訓練了多種GAN模型。除了作為基準的DCGAN,還選擇了六個比較難訓練的結構:
生成器沒有批正則化以及的連續數量濾波
1.4層的512維ReLU MLP生成器
2.生成器和鑑別器中都沒有正則化
3.門控相乘非線性
4.Tanh非線性
5.101層ResNet 生成器和鑑別器


從圖中可以看出,使用不同方法訓練的GAN結構,只有作者提出的使用梯度懲罰WGAN方法在每個結構的訓練中都成功了。
字元級語言建模


上部:在Billion Word資料集上測試的WGAN字元級語言模型樣例,縮短到32個字元。
底部:使用標準GAN訓練的同樣結構的模型樣例。
據我們所知,這是第一個完全用對抗式訓練的語言生成模型,而不需要相似性最大化損失的監督。從圖中可以看出它會產生很多拼寫錯誤,但是依然能成功的學習到自然語言統計學的很多資料。
損失曲線和過擬合檢測


(a)作者的模型在LSUN臥室資料集上的鑑別器損失,隨著網路的訓練收斂到最小值。
(b)在1000位MNIST子資料集上的WGAN訓練和驗證損失。可以看出使用我們的方法(左)或權重修剪法(右)都會產生過擬合。我們的方法中,鑑別器比生成器過擬合速度更快,是的訓練損失隨時間逐漸增加,而驗證損失隨之減少。
論文地址及實現
論文連結:https://arxiv.org/abs/1704.00028
github連結:https://github.com/igul222/improved_wgan_training
環境要求:
·Python, NumPy, TensorFlow,SciPy, Matplotlib
·NVIDIA GPU
模型
所有模型的配置在檔案最上面的一列常數中已經進行了特別說明。有兩個模型應該可以直接使用:
·python gan_toy.py: 玩具資料庫(8 Gaussians, 25 Gaussians, SwissRoll).
·python gan_mnist.py: MNIST手寫數字識別庫
對於其他的模型,在執行之前請修改指令碼,指定資料集的路徑為DATA_DIR下。每個模型的資料集現在都已公開;下載連結已包含在檔案中。
·python gan_64x64.py: 64x64 結構 (論文中,該模型是在ImageNet上進行訓練的,而不是LSUN臥室圖片庫)
·python gan_language.py: 字元級別的語言模型
·python gan_cifar.py: CIFAR-10資料庫
AIJob社是《全球人工智慧》旗下專門為AI開發工程師免費服務的求職平臺。我們將竭盡全力幫助每一個ai工程師對接自己喜歡的企業,推薦給你喜歡的直接領導,幫你談一個最好的薪資待遇。
微信諮詢:aihr007簡歷投遞:
[email protected]
企業合作:[email protected]
《全球人工智慧》招聘5名兼職翻譯及10名兼職VIP社群專家:影象技術、語音技術、自然語言、機器學習、資料探勘等專業技術領域,工作內容及待遇請在公眾號內回覆“兼職+個人微訊號”聯絡工作人員。
熱門文章推薦
重磅|全球AI報告:看看谷歌|蘋果|Facebook等十幾家巨頭都收購了哪些牛逼的AI公司?