
Balancing the Training of G and D in a GAN

The training log emitted by the GAN is shown below:

Epoch: [ 1] [ 468/ 500] time: 4728.2349, d_loss: 4.07427502, fake_loss: 4.04856014, real_loss: 0.02571472, g_loss: 6.09300184
Epoch: [ 1] [ 469/ 500] time: 4731.1465, d_loss: 3.99250102, fake_loss: 3.96739197, real_loss: 0.02510899, g_loss: 5.91527224
Epoch: [ 1] [ 470/ 500] time: 4734.1464, d_loss: 2.95108366, fake_loss: 2.02483487, real_loss: 0.92624891, g_loss: 6.98860884
Epoch: [ 1] [ 471/ 500] time: 4737.0128, d_loss: 3.63805246, fake_loss: 3.60547066, real_loss: 0.03258171, g_loss: 6.16138554
Epoch: [ 1] [ 472/ 500] time: 4740.0376, d_loss: 2.79514408, fake_loss: 2.63232374, real_loss: 0.16282040, g_loss: 6.15056276
Epoch: [ 1] [ 473/ 500] time: 4742.9008, d_loss: 3.38974905, fake_loss: 3.32761765, real_loss: 0.06213144, g_loss: 5.87860966
Epoch: [ 1] [ 474/ 500] time: 4745.8732, d_loss: 3.48325801, fake_loss: 3.47971010, real_loss: 0.00354792, g_loss: 5.61753035
Epoch: [ 1] [ 475/ 500] time: 4748.7309, d_loss: 3.30880904, fake_loss: 2.85620737, real_loss: 0.45260173, g_loss: 6.53223276
Epoch: [ 1] [ 476/ 500] time: 4751.7712, d_loss: 3.08520770, fake_loss: 2.22547674, real_loss: 0.85973096, g_loss: 7.28343534
Epoch: [ 1] [ 477/ 500] time: 4754.6127, d_loss: 3.23970866, fake_loss: 3.21050930, real_loss: 0.02919928, g_loss: 6.13153982
Epoch: [ 1] [ 478/ 500] time: 4757.5939, d_loss: 3.72340441, fake_loss: 3.64948988, real_loss: 0.07391454, g_loss: 6.10789824
Epoch: [ 1] [ 479/ 500] time: 4760.4764, d_loss: 3.33632994, fake_loss: 3.21529436, real_loss: 0.12103555, g_loss: 5.90366840
Epoch: [ 1] [ 480/ 500] time: 4763.4599, d_loss: 3.04063058, fake_loss: 2.57415485, real_loss: 0.46647564, g_loss: 6.37139893
Epoch: [ 1] [ 481/ 500] time: 4766.3039, d_loss: 3.18814778, fake_loss: 3.18276262, real_loss: 0.00538522, g_loss: 5.81689548
Epoch: [ 1] [ 482/ 500] time: 4769.3117, d_loss: 2.88918209, fake_loss: 2.44751954, real_loss: 0.44166249, g_loss: 6.34483957
Epoch: [ 1] [ 483/ 500] time: 4772.1745, d_loss: 3.35039783, fake_loss: 3.18180418, real_loss: 0.16859373, g_loss: 6.23565578
Epoch: [ 1] [ 484/ 500] time: 4775.1729, d_loss: 4.13123512, fake_loss: 4.10078049, real_loss: 0.03045487, g_loss: 5.98916435
Epoch: [ 1] [ 485/ 500] time: 4778.0524, d_loss: 3.72344589, fake_loss: 3.65170097, real_loss: 0.07174486, g_loss: 6.19677067
Epoch: [ 1] [ 486/ 500] time: 4781.0614, d_loss: 2.75713587, fake_loss: 2.32209563, real_loss: 0.43504018, g_loss: 6.68624544
Epoch: [ 1] [ 487/ 500] time: 4783.9377, d_loss: 3.99661183, fake_loss: 3.99253273, real_loss: 0.00407907, g_loss: 5.86462307
Epoch: [ 1] [ 488/ 500] time: 4786.9539, d_loss: 2.94620275, fake_loss: 2.19436288, real_loss: 0.75183994, g_loss: 6.48085737
Epoch: [ 1] [ 489/ 500] time: 4791.4465, d_loss: 3.95884061, fake_loss: 3.90659142, real_loss: 0.05224927, g_loss: 6.15112591
Epoch: [ 1] [ 490/ 500] time: 4797.7769, d_loss: 3.83195233, fake_loss: 3.80750990, real_loss: 0.02444242, g_loss: 5.97502804
Epoch: [ 1] [ 491/ 500] time: 4801.4432, d_loss: 2.78101063, fake_loss: 2.34172320, real_loss: 0.43928733, g_loss: 6.86819363
Epoch: [ 1] [ 492/ 500] time: 4807.5275, d_loss: 3.35740018, fake_loss: 3.18634367, real_loss: 0.17105654, g_loss: 6.35232353
Epoch: [ 1] [ 493/ 500] time: 4812.1661, d_loss: 3.57710028, fake_loss: 3.53673315, real_loss: 0.04036722, g_loss: 6.07262945
Epoch: [ 1] [ 494/ 500] time: 4818.2569, d_loss: 3.33233833, fake_loss: 3.20232630, real_loss: 0.13001196, g_loss: 6.00899410
Epoch: [ 1] [ 495/ 500] time: 4823.2864, d_loss: 3.86144495, fake_loss: 3.81807637, real_loss: 0.04336853, g_loss: 5.99790001
Epoch: [ 1] [ 496/ 500] time: 4828.9741, d_loss: 3.29455090, fake_loss: 3.05517721, real_loss: 0.23937358, g_loss: 6.35387850
Epoch: [ 1] [ 497/ 500] time: 4834.2971, d_loss: 3.20185351, fake_loss: 2.88803077, real_loss: 0.31382266, g_loss: 6.17356730
Epoch: [ 1] [ 498/ 500] time: 4840.0634, d_loss: 2.85306931, fake_loss: 2.68211555, real_loss: 0.17095384, g_loss: 6.24095392
Epoch: [ 1] [ 499/ 500] time: 4844.8105, d_loss: 3.75695109, fake_loss: 3.74643278, real_loss: 0.01051831, g_loss: 6.07072544
[Sample] d_loss: 3.90224910, g_loss: 13.18194199

From the magnitudes of the losses in the log we can make a preliminary judgment that D has not been trained enough: the fake loss is high, which means D is classifying the images generated by G as real with high probability. Taking an output of 0.5 as the reference point, the discriminator's loss should sit at around 0.69 per term (since −log 0.5 ≈ 0.693). Clearly this discriminator is classifying both the real samples and the generated samples as real with high probability, which is not a good sign; in this situation the number of D training steps needs to be increased. The G loss here is not the original GAN's −log(D(G(z))), so there is no reference value to judge it against. If the G loss here were the original GAN loss, it should be on the small side; if an original GAN showed this behavior, the number of D training steps should likewise be increased. Conversely, if the real loss is high, or G's loss is high, then the number of G training steps should be increased.
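The "around 0.69" reference value above comes from the standard binary-cross-entropy discriminator loss: a D that is maximally confused outputs 0.5 for every input, so each loss term equals −log 0.5. A minimal numeric check:

```python
import math

# For a BCE discriminator, real_loss = -log(D(x_real)) and
# fake_loss = -log(1 - D(G(z))). A maximally confused D outputs 0.5
# everywhere, so each term hits the same baseline:
ideal_term = -math.log(0.5)            # the ~0.69 reference cited above

real_loss = -math.log(0.5)             # D(x_real) = 0.5
fake_loss = -math.log(1 - 0.5)         # D(G(z))  = 0.5
d_loss = real_loss + fake_loss         # combined baseline, ~1.386

print(round(ideal_term, 3))            # 0.693
print(round(d_loss, 3))                # 1.386
```

By comparison, the fake-loss values of 2 to 4 in the log above are far past this baseline, which is what motivates the "D is undertrained" diagnosis.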

After applying the correction described above, the results are as follows:

Epoch: [ 0] [   1/ 500] time: 64.8840, d_loss: 0.85116410, fake_loss: 0.83811581, real_loss: 0.01304832, g_loss: 7.76426697
Epoch: [ 0] [   2/ 500] time: 67.8389, d_loss: 1.03823590, fake_loss: 1.03505087, real_loss: 0.00318509, g_loss: 7.51490021
Epoch: [ 0] [   3/ 500] time: 70.7466, d_loss: 0.32450223, fake_loss: 0.31290203, real_loss: 0.01160018, g_loss: 8.52584076
Epoch: [ 0] [   4/ 500] time: 73.6104, d_loss: 0.56391406, fake_loss: 0.55389690, real_loss: 0.01001717, g_loss: 7.78080940
Epoch: [ 0] [   5/ 500] time: 86.5351, d_loss: 0.56352222, fake_loss: 0.55384904, real_loss: 0.00967315, g_loss: 7.54131317
Epoch: [ 0] [   6/ 500] time: 90.5828, d_loss: 0.64782655, fake_loss: 0.63883263, real_loss: 0.00899389, g_loss: 7.57851410
Epoch: [ 0] [   7/ 500] time: 102.4376, d_loss: 0.90727967, fake_loss: 0.90105110, real_loss: 0.00622859, g_loss: 7.74293327
Epoch: [ 0] [   8/ 500] time: 107.6158, d_loss: 1.21272802, fake_loss: 1.20931554, real_loss: 0.00341244, g_loss: 7.27319956
Epoch: [ 0] [   9/ 500] time: 122.2964, d_loss: 0.86367071, fake_loss: 0.86000025, real_loss: 0.00367043, g_loss: 7.26012707
Epoch: [ 0] [  10/ 500] time: 133.1148, d_loss: 1.27289641, fake_loss: 1.27081740, real_loss: 0.00207904, g_loss: 7.01438046
Epoch: [ 0] [  11/ 500] time: 146.6726, d_loss: 1.10992002, fake_loss: 1.10505104, real_loss: 0.00486900, g_loss: 6.46944427

As the log shows, the fake loss has dropped markedly, and the real loss has dropped as well, which indicates that increasing the number of D updates did a good job of optimizing D.
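The correction amounts to running several D updates for every G update. A minimal sketch of that control flow, with `train_d_step` / `train_g_step` as hypothetical stand-ins for the real update functions (the post does not show its actual training code, and the ratio of 3 is only an assumed example):

```python
N_D = 3  # discriminator updates per generator update (assumed ratio)

d_calls, g_calls = [], []

def train_d_step(batch):
    # stand-in for one real D update (forward, backward, optimizer step)
    d_calls.append(batch)

def train_g_step(batch):
    # stand-in for one real G update
    g_calls.append(batch)

for batch in range(4):          # pretend mini-batches
    for _ in range(N_D):        # update D several times...
        train_d_step(batch)
    train_g_step(batch)         # ...then update G once

print(len(d_calls), len(g_calls))  # 12 4
```

In a real run, `N_D` would be tuned against the logged losses rather than fixed in advance.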

However, since my experiments are on pix2pix, G's loss cannot be perfectly matched against D's loss, because G's loss is dominated by the context loss. So when balancing D in this kind of model, how to ensure that the gradients backpropagated from D are not just noise interfering with G's generation remains an open question. I will write more once I have figured it out; pointers from anyone who understands this better would be welcome.
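To make the "dominated by the context loss" point concrete, here is a hedged sketch of a pix2pix-style generator objective: an adversarial term plus a heavily weighted L1 ("context") term. The weight of 100 follows the pix2pix paper's default; the function and the toy pixel values are illustrative, not the post's actual code:

```python
LAMBDA = 100.0  # L1 weight from the pix2pix paper; it dominates the objective

def g_total_loss(adv_loss, fake_img, target_img):
    # mean absolute error between generated and target pixels (the "context" term)
    l1 = sum(abs(f - t) for f, t in zip(fake_img, target_img)) / len(fake_img)
    return adv_loss + LAMBDA * l1

# Toy example: adversarial term 0.7, mean pixel error 0.05
loss = g_total_loss(adv_loss=0.7,
                    fake_img=[0.2, 0.5, 0.9],
                    target_img=[0.25, 0.45, 0.95])
print(round(loss, 2))  # 5.7 -- the L1 term contributes 5.0 of it
```

This is why the logged `g_loss` tracks reconstruction quality much more than it tracks D's opinion, and why tuning D's update count does not directly rebalance G here.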