1. 程式人生 > >Pytorch: detach 和 retain_graph,和 GAN的原理解析

Pytorch: detach 和 retain_graph,和 GAN的原理解析

轉載知乎文章:https://zhuanlan.zhihu.com/p/43843694

寫的特別好,看的很明白.謝謝作者

本人觀察 Pytorch 下的生成對抗網路(GAN)的實現程式碼,發現不同人的實現細節略有不同,其中用到了 detach 和 retain_graph,本文通過兩個 gan 的程式碼,介紹它們的作用,並分析,不同的更新策略對程式效率的影響。

這兩個 GAN 的實現中,其更新策略不同,前者是先更新判別器(discriminator)引數,再更新生成器(generator)引數,這正是原始論文Generative Adversarial Networks 中的演算法(下圖所示);後者是先更新 generator 引數,再更新 discriminator 引數,它們的實現孰優孰劣呢?

 

GAN 的基本原理

首先回顧一下生成對抗網路演算法的基本原理,對於熟悉的同學,可以跳過這部分。限於篇幅,只介紹最原始的兩種 GAN 損失函式,不失一般性。

原始 GAN 的損失函式,主要來源於 binary cross-entropy loss:

L(x^{(i)},y^{(i)})=-y^{(i)}\cdot\mathrm{log}P(x^{(i)})-(1-y^{(i)})\cdot\mathrm{log}(1-P(x^{(i)}))\tag{1}

其中,y 為真實標籤,取 1 (正樣本) 或者 0 (負樣本),P(x) 為 x 屬於正樣本的概率。

判別器有兩種輸入,一種是真實的樣本 (x\in P_r,y=1);一種是生成器通過噪音 z 生成的假樣本 (G(z)\in P_g, y=0) 。判別器的工作是將這兩種來源的樣本區分開,故採用公式 (1) 計算其損失函式 L_d

分別把真實樣本和生成的假樣本代入公式 (1),對於每個正樣本,其損失函式只剩下:

L_{d_{real}}(x^{(i)}\in P_r,y^{(i)}=1)=-\mathrm{log}D(x^{(i)})\tag{2}

而對於每個負樣本,損失函式為:

L_{d_{fake}}(G(z^{(i)}),y^{(i)}=0)=-\mathrm{log}\left(1-D(G(z^{(i)}))\right)\tag{3}

(這裡把公式 (1) 的 P(x) 改為 D(x),表示是 discriminator 的輸出,和上圖中的演算法保持一致)。把這兩個損失函式加起來,算一下平均值:

L_{d_{total}}=-\frac{1}{m}\sum_{i=1}^{m}\mathrm{log}D(x^{(i)})+\mathrm{log}\left(1-D(G(z^{(i)}))\right)\tag{4}

然後,將這個損失函式對 discriminator 的引數 \theta_d 求導,即得到上圖演算法中的梯度:

\frac{\mathrm{d}L_{d_{total}}}{\mathrm{d}\theta_d}=-\nabla_{\theta_d}\frac{1}{m}\sum_{i=1}^{m}\mathrm{log}D(x^{(i)})+\mathrm{log}\left(1-D(G(z^{(i)}))\right)\tag{5}

注意演算法截圖中用梯度上升優化,所以和上式差一個負號

生成器輸入噪聲 z^{(i)} ,輸出一個假資料 G(z^{(i)}) 。它希望這個假資料能騙過判別器。而上文判別器對假資料的損失函式為:

L_{d_{fake}}(G(z^{(i)}),y^{(i)}=0)=-\mathrm{log}\left(1-D(G(z^{(i)}))\right)\tag{3}

這個損失越大,意味著判別器的效能越差,同時也意味著生成器的效能越好。也就是說,我們希望朝著增大該損失的方向來調整生成器的引數。如果採用梯度下降優化生成器,也就是要減小它的相反數,即生成器的損失函式:

L_g(z^{(i)})=\mathrm{log}\left(1-D(G(z^{(i)}))\right)\tag{6}

這就是原始 GAN 生成器的第一種損失函式,對其取平均再對生成器引數 \theta_g 求梯度,得到演算法截圖裡面的公式:

\frac{\mathrm{d}L_{g_{total}}}{\mathrm{d}\theta_g}=\nabla_{\theta_g}\frac{1}{m}\sum_{i=1}^m\mathrm{log}\left(1-D(G(z^{(i)}))\right)\tag{7}

演算法截圖中用梯度下降優化,所以和上式相同

但是,公式 (6) 的損失函式在訓練初期梯度會特別小。因為剛開始訓練時,生成器生成的資料十分不真實,導致判別器很容易判斷出來它是假資料。因此 D(G(z^{(i)})) 是個接近於0 的常數。而損失函式 (6) 對 generator 引數 \theta_g 的導數為:

\frac{\mathrm{d}L_g(z^{(i)})}{\mathrm{d}\theta_g}=-\frac{1}{1-D(G(z^{(i)}))}\cdot\frac{\mathrm{d}D(G(z^{(i)}))}{\mathrm{d}G(z^{(i)})}\cdot\frac{\mathrm{d}G(z^{(i)})}{\mathrm{d}\theta_g}\tag{8}

由於 D(G(z^{(i)})) 是接近於0的常數,所以 (8) 中等號右邊第一項近似為 -1,而中間項為接近 0 的數(因為 D(G(z^{(i)})) 是接近 0 的常函式,而常函式的梯度為 0)。

因此,GAN 的作者又給出了第二種生成器損失函式:

L_g(z^{(i)})=-\mathrm{log}(D(G(z^{(i)})))\tag{9}

這個損失函式很好理解,就是要生成這樣的資料:使得輸入 discriminator 後,輸出一個較大的概率,即對生成器來說 \mathrm{log}(D(G(z^{(i)}))) 越大越好,取它的相反數,也就是公式 (9) 越小越好。我們計算公式 (9) 的損失函式對生成器引數 \theta_g 的導數:

\frac{\mathrm{d}L_g(z^{(i)})}{\mathrm{d}\theta_g}=-\frac{1}{D(G(z^{(i)}))}\cdot\frac{\mathrm{d}D(G(z^{(i)}))}{\mathrm{d}G(z^{(i)})}\cdot\frac{\mathrm{d}G(z^{(i)})}{\mathrm{d}\theta_g}\tag{10}

和 (8) 比只有等號右邊第一項不同,當 D(G(z^{(i)})) 接近 0 時, 1/D(G(z^{(i)})) 接近無窮大,和中間那項相乘,一定程度抵消了中間項接近 0 的問題,使計算出來的梯度增大了。

在實現過程中,基本上都是採用公式 (9) 作為生成器損失函式,因為這十分方便,只要在計算生成器損失函式時,把虛假資料的標籤標記為真,代入 Binary cross-entropy loss 的公式即可。

更新策略

下面進入本文正題,即,在 pytorch 中,detach 和 retain_graph 是幹什麼用的?本文將藉助兩段 GAN 的實現程式碼,來舉例介紹它們的作用。

第一段程式碼先更新判別器,再更新生成器。我們分析迴圈中一個 step 的程式碼:

valid = torch.Tensor(imgs.size(0), 1).fill_(1.0).to(device) # 真實標籤,都是1
fake = torch.Tensor(imgs.size(0), 1).fill_(0.0).to(device) # 假標籤,都是0
#----------
# 訓練判別器
#----------
real_imgs = imgs.to(device)
z = torch.randn((imgs.shape[0], 100)).to(device) # 噪聲
gen_imgs = generator(z) # 從噪聲中生成假資料
pred_gen = discriminator(gen_imgs) # 判別器對假資料的輸出
pred_real = discriminator(real_imgs) # 判別器對真資料的輸出
optimizer_D.zero_grad() # 把判別器中所有引數的梯度歸零
real_loss = adversarial_loss(pred_real, valid) # 判別器對真實樣本的損失
fake_loss = adversarial_loss(pred_gen, fake) # 判別器對假樣本的損失
d_loss = (real_loss + fake_loss) / 2  # 兩項損失相加取平均
# 下面這行程式碼十分重要,將在正文著重講解
d_loss.backward(retain_graph=True) # retain_graph 十分重要,否則計算圖記憶體將會被釋放
optimizer_D.step() # 判別器引數更新
#---------
#訓練生成器
#---------
g_loss = adversarial_loss(pred_gen, valid) # 生成器的損失函式
optimizer_G.zero_grad() # 生成器引數梯度歸零
g_loss.backward() # 生成器的損失函式梯度反向傳播
optimizer_G.step() # 生成器引數更新

上面的程式碼中 d_loss.backward(retain_graph=True) 十分關鍵,它用於反向傳播 discriminator loss 的梯度。那麼,具體傳播到什麼地方去呢?

這要看 d_loss 是由哪幾部分構成的:real_loss 和 fake_loss,而 fake_loss 又是從 noise 經過 generator 來的。因此,d_loss 反向傳播,將一傳到底,不但計算了 discriminator 的梯度,同時還計算了 generator 的梯度,雖然這一步optimizer_D.step()只更新 discriminator 的引數。

也正是這個原因,下面在更新 generator 引數時,要先將生成器引數的梯度重新歸零,避免受到 discriminator loss 回傳過來的梯度影響。注意:它在反向傳播時,設定了 retain graph = 0 這個引數。它的作用是保持計算圖,因為 pytorch 預設一個計算圖只計算一次反向傳播,反向傳播後,這個計算圖的記憶體就被釋放了。而後面的 generator 算梯度時還要用到這個計算圖,所以用這個引數控制計算圖不被釋放。

generator 的 損失在回傳時,同樣要經過 discriminator 網路才能傳遞迴自身(系統從輸入噪聲到 Discriminator 輸出,從頭到尾只有一次前向傳播,而有兩次反向傳播,故在第一次反向傳播時,要保持計算圖不被釋放)。因此,在回傳梯度時,同樣也計算了一遍 discriminator 的引數梯度,只不過這次 discriminator 的引數不更新,只更新 generator 的引數,即 optimizer_G.step()。同時,我們看到,下一個 step 首先將 discriminator 的梯度重置為 0,就是為了防止 generator loss 反向傳播時順帶計算的梯度對其造成影響(還有上一步 discriminator loss 回傳時累積的梯度)。

綜上,我們看到,為了完成一步引數更新,我們進行了兩次反向傳播,第一次反向傳播為了更新 discriminator 的引數,但多餘計算了 generator 的梯度。第二次反向傳播為了更新 generator 的引數,但是不得不多計算 discriminator 的梯度。

 

對於先更新生成器引數的情況,我們也分析其迴圈中一個 step 的程式碼:

valid = Variable(Tensor(imgs.size(0), 1).fill_(1.0), requires_grad=False) # 真實樣本的標籤,都是 1
fake = Variable(Tensor(imgs.size(0), 1).fill_(0.0), requires_grad=False) # 生成樣本的標籤,都是 0
real_imgs = Variable(imgs.type(Tensor))
#-----------
# 訓練生成器
#-----------
optimizer_G.zero_grad() # 生成器引數梯度歸零
z = Variable(Tensor(np.random.normal(0, 1, (imgs.shape[0], opt.latent_dim)))) # 噪聲
gen_imgs = generator(z) # 根據噪聲生成虛假樣本
g_loss = adversarial_loss(discriminator(gen_imgs), valid) # 用真實的標籤+假樣本,計算生成器損失
g_loss.backward() # 生成器梯度反向傳播,反向傳播經過了判別器,故此時判別器引數也有梯度
optimizer_G.step() # 生成器引數更新,判別器引數雖然有梯度,但是這一步不能更新判別器
#----------
# 訓練判別器
#----------
optimizer_D.zero_grad() # 把生成器損失函式梯度反向傳播時,順帶計算的判別器引數梯度清空
real_loss = adversarial_loss(discriminator(real_imgs), valid) # 真樣本+真標籤:判別器損失
fake_loss = adversarial_loss(discriminator(gen_imgs.detach()), fake) # 假樣本+假標籤:判別器損失
d_loss = (real_loss + fake_loss) / 2  # 判別器總的損失函式
d_loss.backward() # 判別器損失回傳
optimizer_D.step() # 判別器引數更新

上述程式碼先更新生成器引數,再更新判別器引數。那麼除了順序上的區別,和先更新判別器,再更新生成器,還有什麼不同?答案是,計算圖的遍歷次數不同。

為了更新生成器引數,用生成器的損失函式計算梯度,然後反向傳播,傳播圖中經過了discriminator,根據鏈式法則,不得不順帶計算一下判別器的引數梯度,雖然在這一步不會更新判別器引數。反向傳播過後,noise 到 fake image 再到 discriminator 的輸出這個前向傳播的計算圖就被釋放掉了,後面也不會再用到。

接著更新判別器引數,此時注意到,我們輸入判別器的是兩部分,一部分是真實資料,另一部分是生成器的輸出,也就是假資料。注意觀察細節,在判別器前向傳播過程,輸入的假資料被 detach 了:discriminator(gen_imgs.detach()),detach 的意思是,這個資料和生成它的計算圖“脫鉤”了,即梯度傳到它那個地方就停了,不再繼續往前傳播(實際上也不會再往前傳播了,因為 generator 的計算圖在第一次反向傳播過後就被釋放了)。因此,判別器梯度反向傳播,就到它自己身上為止。

因此,比起第一種策略,要少計算一次 generator 的所有引數的梯度,同時,也不必刻意儲存一次計算圖,佔用不必要的記憶體。

但需要注意的是,在第一種策略中,noise 從 generator 輸入,到 discriminator 輸出,只經歷了一次前向傳播,discriminator 端的輸出,被用了兩次,一次是計算 discriminator 的損失函式,另一次是計算 generator 的損失函式。

而在第二種策略中,noise 從 generator 輸入,到discriminator 輸出,計算 generator 損失,回傳,這一步更新了 generator 的引數,並釋放了計算圖。下一步更新 discriminator 的引數時,generator 的輸出經過 detach 後,又通過了一遍 discriminator,相當於,generator 的輸出前後兩次通過了 discriminator ,得到相同的輸出。顯然,這也是冗餘的。

 

綜上,這兩段程式碼各有利弊:

第一段程式碼,好處是 noise 只進行了一次前向傳播,缺點是,更新 discriminator 引數時,多計算了一次 generator 的梯度,同時,第一次更新 discriminator 需要保留計算圖,保證算 generator loss 時計算圖不被銷燬。

第二段程式碼,好處是通過先更新 generator ,使更新後的前向傳播計算圖可以放心被銷燬,因此不用保留計算圖佔用記憶體。同時,在更新 discriminator 的時候,也不會像上面的那段程式碼,計算冗餘的 generator 的梯度。缺點是,在 discriminator 上,對 generator 的輸出算了兩次前向傳播,第二次又產生了新的計算圖(但比第一次的小)。

一個多計算了一次 generator 梯度,一個多計算一次 discriminator 前向傳播。因此,兩者差別不大。如果 discriminator 比generator 複雜,那麼應該採取第一種策略,如果 discriminator 比 generator 簡單,那麼應該採取第二種策略,通常情況下,discriminator 要比 generator 簡單,故應該採取第二種策略居多。

 

但是第二種先更新generator,再更新 discriminator 總是給人感覺怪怪得,因為 generator 的更新需要 discriminator 提供準確的 loss 和 gradient,否則豈不是在瞎更新?

 

還有一種沒提到的策略,noise 從 generator 輸入,輸出 fake data,然後 detach 一下,隨著 true data 一起輸入 discriminator,計算 discriminator 損失,並更新 discriminator 引數。接下來,再把沒經過 detach 的 fake data 輸入到discriminator 中,計算 generator loss,再反向傳播梯度,更新 generator 的引數。這種策略,計算了兩次 discriminator 梯度,一次 generator 梯度。感覺這種比較符合先更新 discriminator 的習慣。缺點是,之前的 generator 生成的計算圖得保留著,直到 discriminator 更新完,再釋放。不像策略二,馬上用完馬上釋放。綜合來說,還是策略二最好,策略三其次,策略一最差(差在多計算一次 generator gradient 上,而通常多計算一次 generator gradient 的運算量比多計算一次 discriminator 前向傳播的運算量大),因此,detach 還是很有必要的。