乾貨丨一文看懂生成對抗網路:從架構到訓練技巧
文章來源:機器之心
論文地址:https://arxiv.org/pdf/1710.07035.pdf
生成對抗網路(GAN)提供了一種不需要大量標註訓練資料就能學習深度表徵的方式。它們通過反向傳播演算法分別更新兩個網路以執行競爭性學習而達到訓練目的。GAN 學習的表徵可用於多種應用,包括影象合成、語義影象編輯、風格遷移、影象超解析度技術和分類。這篇論文的目的是為訊號處理社群做一個 GAN 的綜述。除了介紹訓練和構建 GAN 的不同方法以外,我們還討論了其理論和應用中尚存在的挑戰。
1. 引言
生成器和判別器通常由包含卷積和(或)全連線層的多層網路構成。生成器和判別器必須是可微的,但並不必要是直接可逆的(理論分析上必須可逆)。如果生成網路由某些表徵空間對映得到,即對映一個隱空間到資料空間中(我們需要把注意力集中在影象上),用更加形式化的表示為:G : G(z) → R^|x|,其中 z ∈ R^|z|是隱空間中的一個樣本,x ∈ R^|x|是一張影象,| · |表示維數。
原始 GAN 的判別網路 D 可以看成是將影象資料對映到(該影象是來自真實資料分佈,而不是生成器分佈)判別概率的函式 D : D(x) → (0, 1)。對於一個固定的生成器 G,判別器 D 可能被訓練用於分辨影象是來自訓練資料(真,概率接近 1)還是來自生成器(假,概率接近 0)。若判別器已經是最好的,它將變得無法被欺騙,而這時生成器 G 需要繼續訓練以降低判別器的準確率。如果生成器分佈足以完美匹配真實資料分佈,那麼判別器將會被最大地迷惑而對所有輸入給出 0.5 的概率值。在實踐中,判別器可能無法訓練到理想狀態,我們將在第 4 節更深入地探索訓練過程。
圖 1. 這張圖分別展示了 GAN 中的判別器 D 和生成器 G 的訓練流程。通常利用神經網路實現,但實際上它們可以用任意形式的(將資料從一個空間對映到另一個的)可微系統實現。
3.GAN 架構
圖 2. 在 GAN 的訓練過程中,生成器可以生成一個樣本分佈 p_g(x)以匹配真實資料 p_data(x)。通過適當引數化並訓練 GAN,這些分佈將變得難以區分。GAN 的表徵分佈可以通過生成網路和判別網路的學習引數(權重)而構建。
A. 全連線 GAN
首個 GAN 架構在生成器與鑑別器上皆使用全連線神經網路。這種架構型別被應用於相對簡單的影象資料庫,即 MNIST(手寫數字)、CIFAR-10(自然影象)和多倫多人臉資料集(TFD)。
B. 卷積 GAN
因為 CNN 非常適合處理影象資料,所以從全連線到卷積神經網路是一個自然的擴充套件。早期在 CIFAR-10 上進行的試驗表明,用監督學習相同水平和表徵能力的 CNN 訓練生成器和鑑別器網路會更加困難。
通過使用多尺度分解生成過程的形式,拉普拉斯金字塔形對抗網路(LAPGAN)[13] 為這一問題貢獻了一個方案:真值影象本身被分解成拉普拉斯金字塔(Laplacian pyramid),並且條件性卷積 GAN 被訓練在給定上一層的情況下生成每一層。
此外,Radford et al.[5](為「深度卷積 GAN」)提出了一種稱之為 DCGAN 的網路架構族,它允許訓練一對深度卷積生成器和判別器網路。DCGAN 在訓練中使用帶步長的卷積(strided convolution)和小步長卷積(fractionally-strided convolution),並在訓練中學習空間下采樣和上取樣運算元。這些運算元處理取樣率和位置的變化,這是影象空間到低維潛在空間、從影象空間到判別器對映的重要要求。Section IV-B 將詳細介紹 DCGAN 架構和訓練。
作為 2D 影象合成的擴充套件,Wu et al. [14] 展示了使用體積卷積(volumetric convolution)合成 3D 資料樣本的 GAN。Wu et al. [14] 合成包括新物體,如椅子、桌子和汽車;此外,他們還展示了一種從 2D 影象對映到 3D 影象的方法。
C. 條件 GAN
Mirza 等人通過將生成器和判別器改造成條件類(class-conditional)而將(2D)GAN 框架擴充套件成條件設定。條件 GNN 的優勢在於可以對多形式的資料生成提供更好的表徵。條件 GAN 和 InfoGAN[16] 是平行的,它可以將噪聲源分解為不可壓縮源和一個「隱編碼」(latent code),並可以通過最大化隱編碼和生成器之間的互動資訊而發現變化的隱藏因子。這個隱編碼可用於在完全無監督的資料中發現目標類,即使這個隱編碼是不明確的。由 InfoGAN 學到的表徵看起來像是具備語義特徵的,可以處理圖貌中的複雜糾纏因素(包括姿勢變化、光照和麵部影象的情緒內容)。
D.GAN 推斷模型
GAN 的初始形式無法將給定的輸入 x 對映為隱空間中的向量(在 GAN 的文獻中,這通常被稱為一種推斷機制)。人們提出了幾種反轉預訓練 GAN 的生成器的技術,比如各自獨立提出的對抗性學習推斷(Adversarially Learned Inference,ALI)和雙向 GAN(Bidirectional GANs),它們能提供簡單而有效的擴充套件,通過加入一個推斷網路,使判別器共同測試資料空間和隱空間。
這種形式下的生成器由兩個網路組成:即編碼器(推斷網路)和解碼器。它們同時被訓練用於欺騙判別器。而判別器將接收到一個向量對(x,z)(見圖 4 所示),並決定其是否包含一個真實影象以及其編碼,或者一個生成的影象樣本以及相關的生成器的隱空間輸入。
理想情況下,在一個編碼-解碼模型中的輸出(作為輸入的重構)應該和輸入相似。通常情況下,使用 ALI/BiGAN 合成的重構資料樣本的保真度都很低。通過額外新增(資料樣本和它們的重構的)對抗代價函式可以提高樣本的保真度。
E. 對抗自編碼器(AAE)
自編碼器是由編碼器和解碼器組成的網路,學習將資料對映到內部隱表徵中,再映射出來,即從資料空間中學習將影象(或其它)通過編碼對映到隱空間中,再通過解碼從隱空間映射回資料空間。這兩個對映形成了一種重構運算,而這兩個對映將被訓練直到重構影象儘可能的接近初始影象。
圖 3. 圖左,條件 GAN(Conditional GAN),由 Mirza 等人提出,可以根據條件(class-condition)合成影象;判別器根據條件類判定影象為真或假。圖右,InfoGAN,其判別器還可以估計類別標籤。
圖 4. ALI/BiGAN 的結構,由三個網路構成。其中一個是判別器,還有一個網路將噪聲向量從隱空間對映到影象空間(解碼器,以 G 表示),最後一個網路(編碼器,以 E 表示)將真實影象從影象空間對映到隱空間。
4. 訓練 GAN
A. 引言
GAN 的訓練目標是獲得令判別器分類準確率最大化的引數,以及獲得最大化欺騙判別器的生成器引數。訓練過程已在圖 5 中總結。
訓練的代價由一個價值函式 V(G,D) 評估,其包含了生成器和判別器的引數。
訓練過程可表示如下:
訓練過程中,其中一個模型的引數被更新,同時另一個模型的引數固定不變。Goodfellow 等人證明對於一個固定的生成器存在唯一的最優判別器 D∗ (x) = p_data(x) /(p_data(x)+p_g(x))。他們還證明了生成器 G 在 pg(x) = pdata(x) 的時候是最優的,這等價於最優判別器對所有的樣本 x 都給出 0.5 的概率值。換個說法就是當判別器 D 被最大化迷惑而無法區分真假樣本的時候,生成器 G 是最優的。
理想情況下,判別器將被訓練到對當前生成器為最優;然後生成器再次被更新。然而實際情況下,判別器可能無法訓練到最優,因此常用的策略是,判別器只會經過少量的迭代訓練,而且生成器和判別器同步更新。此外,生成器通常使用一種交替的、非飽和的訓練規範,即使用 max_G log D(G(z)),而不使用 min_G log(1 − D(G(z)))。
儘管理論上存在唯一的解,GAN 的訓練還是存在著挑戰,經常出現不穩定的狀況,原因有好幾個。一種提升 GAN 訓練效率的辦法是追蹤訓練過程中可能遇到的經驗症狀,這包括:
-
兩個模型(生成器、判別器)無法收斂 [5];
-
生成器「崩潰」,即用不同的輸入生成相似的樣本 [25];
-
判別器的損失很快收斂為零,從而沒有足夠強的梯度路徑可以繼續更新生成器。
圖 5. GAN 訓練的主要迴圈。新型資料樣本 x' 可以通過傳遞隨機樣本得出,z 為抽取的隨機噪聲並作為生成器網路的輸入。生成器更新一次之前,判別器可能需要更新 k 次。
B. 訓練技巧
用於影象生成的 GAN 訓練的第一個重大改進是 Radford et al. [5] 提出的 DCGAN 架構。該研究是計算機視覺領域之前使用的 CNN 架構的進一步探索,並且帶來了一系列構建和訓練生成器、判別器的準則。在 Section III-B 中,我們提到帶步長的卷積和小步長卷積的重要性 [27],它們是架構設計的重要元件。這允許生成器和判別器學習優秀的上取樣和下采樣操作,這些操作可能提升影象合成的質量。具體到訓練中,研究者推薦在兩種網路中使用批量歸一化 [28],以穩定深層模型中的訓練。另一個建議是最小化用於提升深層模型訓練可行性的全連線層的數量。最後,Radford et al. [5] 認為在判別器中間層使用 leaky ReLU 啟用函式的效能優於使用常規的 ReLU 函式。
Salimans et al. [25] 進一步提出穩定 GAN 訓練的啟發式方法。首先,特徵匹配稍稍改變生成器的目標,以增加可獲取的資訊量。具體來說,判別器仍然需要訓練來區分真樣本和假樣本,但是生成器的訓練目標是匹配判別器中假樣本和真樣本的期望中間啟用值(特徵)。第二,小批量判別(mini-batch discrimination)向判別器額外新增輸入,該特徵對小批量中的給定樣本和其他樣本的距離進行編碼。目的在於防止模式崩潰(mode collapse),因為判別器能夠輕易判斷生成器是否生成同樣的輸出。
第三個啟發式技巧是啟發式平均(heuristic averaging),如果網路引數偏離之前值的執行平均值,則會受到懲罰,這有助於收斂到平衡態。第四個技巧是虛擬批量歸一化(virtual batch normalization),它可減少小批量內樣本對其他樣本的依賴性,方法是使用訓練開始就確定的固定參考小批量(reference mini-batch)樣本計算歸一化的批量統計(batch statistics)。
最後,單邊標籤平滑(one-sided label smoothing)將判別器的目標從 1 替換為 0.9,使判別器的分類邊界變得平滑,從而阻止判別器過於自信,為生成器提供較差的梯度。Sønderby et al. [29] 改進了這個想法,他在將樣本饋送至判別器之前向樣本中新增噪聲來挑戰判別器。Sønderby et al. [29] 認為單邊標籤平滑偏向最佳判別器,而他們的技術——樣本噪聲(instance noise)使真樣本和假樣本的流形(manifold)更加接近,同時防止判別器輕易找到完全分離真假樣本的判別器邊界。該技巧在實踐中可以通過向合成影象和真實影象中新增高斯噪聲來實現,使標準差隨著時間逐漸減小。後來 Arjovsky et al. [26] 把向資料樣本新增噪聲來穩定訓練的過程形式化。
C. 成本函式的修正
1)GAN 成本函式的泛化:Nowozin et al. [30] 認為 GAN 訓練可以泛化至不僅最小化 JS 散度(Jensen-Shannon divergence),同時還最小化 f 散度的估計值。這些可以稱為 f-GAN。
2)防止梯度消失的其他成本函式:Arjovsky et al. [32] 提出 WGAN,一種帶有其他成本函式的 GAN 網路,該函式從近似 Wasserstein 距離中推匯出。
7. 討論
A. 開放性問題
1)Mode Collapse:如第四章所述,GAN 的普遍問題是生成器崩潰(generator collapsing),只能輸出一小類類似樣本(部分崩潰),最壞的情況是僅能生成一個簡單樣本(完全崩潰)[26], [48]。
2)訓練不穩定——鞍點:在 GAN 中,損失函式的海森矩陣(Hessian)變得非正定。因此,二階矩陣非正定只能找到一個鞍點而不是區域性極小值。
B. 結論
對 GAN 的研究興趣不只是因為它們可以從潛在空間和資料空間之間的對映和逆對映中學習深度、高度非線性的對映,還因為它們可以利用大量非標註影象資料,這種操作類似深度表徵學習。在 GAN 訓練中,存在很多理論和演算法的發展機會。有了深度網路的幫助,新應用存在大量機會。