GAN原理手寫資料集生成
阿新 • • 發佈:2021-06-20
GAN原理介紹
- GAN 來源於博弈論中的零和博弈,博弈雙方,分別為生成模型與判別模型。
- 生成模型G捕捉樣本資料的分佈,用服從某一分佈例如正太,高斯分佈的噪聲z來生成一個類似真實訓練資料的樣本,追求的效果是越像真實越好。
- 判別模型是一個二分類器,判別樣本來自於訓練資料還是真實資料的概率。如果來自於真實樣本輸出大概率,如果來自於訓練資料,輸出小概率。
例項demo
- 以造小狗的假圖片為例。首先生成小狗圖片的模型,稱之為generator,還有一個判斷小狗圖片是否是真假的判別模型 discrimator。
- 首先輸入一個的噪聲,然後送入生成器,生成器的生成假圖
- 把真圖與假圖。進行拼接,然後打上標籤,真圖標籤是1,假圖標籤是0,送入鑑別器,鑑別器輸出屬於真實樣本與訓練樣本的概率。
實際GAN(手寫資料集為例)
資料預處理
匯入函式庫
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
import matplotlib.pyplot as plt
劃分資料集
(train_image,train_labels),_=keras.datasets.mnist.load_data()
資料型別轉換于歸一化[-1,1]
train_images=2*tf.cast(train_image,tf.float32)/255.-1
expand_dims 設定通道,-1 加一維
train_images=2*tf.cast(train_image,tf.float32)/255.-1
# expand_dims 設定通道,-1 加一維
train_images=tf.expand_dims(train_images,-1)
train_images.shape
常用引數設定與資料集生成
Batch_Size=256 # 每回使用256 Buffer_Size=60000 #亂序範圍 # 構建demo使用的資料集 dataset=tf.data.Dataset.from_tensor_slices(train_images).shuffle(Buffer_Size).batch(Batch_Size)
GAN模型的生成器與鑑別器構建
- 定義圖片生成器
- BatchNormalization()計算出當前batch的每個channel的均值mean,計算出當前batch的每個channel的方差variance,令輸入減去均值再除以標準差delta,得到normalized輸出x-hat,最後乘以scale引數gamma,加上shift引數beta,得到最終變換後的輸出y
- https://www.jianshu.com/p/437fb1a5823e
- Keras中使用如Leaky ReLU等高階啟用函式的方法
- https://blog.csdn.net/hesongzefairy/article/details/86707352
def generator_model():
# 第一層
model=tf.keras.Sequential()
# 100->256
# 第一層
model.add(layers.Dense(256,input_shape=(100,),use_bias=False))
model.add(layers.BatchNormalization())
model.add(layers.LeakyReLU())
# 第二層
#256->512
model.add(layers.Dense(512,use_bias=False))
model.add(layers.BatchNormalization())
model.add(layers.LeakyReLU())
# 第三層
#512->28*28
model.add(layers.Dense(28*28,use_bias=False,activation="tanh"))
model.add(layers.BatchNormalization())
model.add(layers.Reshape([28,28,1]))
return model
定義判別器
def discriminator_model():
model=tf.keras.Sequential()
# 第一層
model.add(layers.Flatten())
model.add(layers.Dense(512,use_bias=False))
model.add(layers.BatchNormalization())
model.add(layers.LeakyReLU())
# 第二層
model.add(layers.Dense(512, use_bias=False))
model.add(layers.BatchNormalization())
model.add(layers.LeakyReLU())
model.add(layers.Dense(1))
# 輸出一個值
return model
設定損失函式
定義優化器
generator_opt=tf.keras.optimizers.Adam(0.0001)
discriminator_opt=tf.keras.optimizers.Adam(0.0001)
cross_entropy=keras.losses.BinaryCrossentropy(from_logits=True)
計算判別器損失
def discriminator_loss(real_out,fake_out):
real_loss=cross_entropy(tf.ones_like(real_out),real_out)
fake_loss=cross_entropy(tf.zeros_like(fake_out),fake_out)
return real_loss+fake_loss
計算生成器損失
def generator_loss(fake_out):
fake_loss = cross_entropy(tf.ones_like(fake_out), fake_out)
return fake_loss
定義訓練step
Epochs=100
input_dim=100
num_exp_to_generate=16
# 生成16*100
seed=tf.random.normal([num_exp_to_generate,input_dim])
# 定義訓練步驟
generator=generator_model()
discriminator=discriminator_model()
def train_step(images):
noise=tf.random.normal([Batch_Size,input_dim])
with tf.GradientTape() as gen_tape,tf.GradientTape() as dis_tape:
real_out=discriminator(images)
gen_img=generator(noise)
fake_out=discriminator(gen_img)
dis_loss=discriminator_loss(real_out,fake_out)
gen_loss=generator_loss(fake_out)
# 梯度下降引數計算
gen_gard=gen_tape.gradient(gen_loss,generator.trainable_variables)
dis_gard=dis_tape.gradient(dis_loss,discriminator.trainable_variables)
# 進行引數更新,並反傳
discriminator_opt.apply_gradients(zip(dis_gard,discriminator.trainable_variables))
generator_opt.apply_gradients(zip(gen_gard, generator.trainable_variables))
# 繪製訓練後的圖
def genrate_plot_image(gen_model,test_noise):
pre_images=gen_model(test_noise,training=False)
fig=plt.figure(figsize=(8,6))
for i in range(pre_images.shape[0]):
plt.subplot(4,4,i+1)
plt.imshow((pre_images[i,:,:,0]+1)/2*255.)
plt.axis('off')
plt.show()
epoch設定
def train(dataset,epochs):
for epoch in range(epochs):
print(epoch)
for image_batch in dataset:
train_step(image_batch)
print(epoch)
if epoch%10==0:
print(epoch)
genrate_plot_image(generator,seed)
main函式
if __name__ == '__main__':
train(dataset,500)
0
0
0
1
1
2
2
3
3
4
4
5
5
6
6
7
7
8
8
9
9
10
10
10
11
11
12
12
13
13
14
14
15
15
16
16
17
17
18
18
19
19
20
20
20
21
21
22
22
23
23
24
24
25
25
26
26
27
27
28
28
29
29
30
30
30
31
31
32
32
33
33
34
34
35
35
36
36
37
37
38
38
39
39
40
40
40
41
41
42
42
43
43
44
44
45
45
46
46
47
47
48
48
49
49
50
50
50
51
51
52
52
53
53
54
54
55
55
56
56
57
57
58
58
59
59
60
60
60
61
61
62
62
63
63
64
64
65
65
66
66
67
67
68
68
69
69
70
70
70
71
71
72
72
73
73
74
74
75
75
76
76
77
77
78
78
79
79
80
80
80
81
81
82
82
83
83
84
84
85
85
86
86
87
87
88
88
89
89
90
90
90
91
91
92
92
93
93
94
94
95
95
96
96
97
97
98
98
99
99
100
100
100
101
101
102
102
103
103
104
104
105
105
106
106
107
107
108
108
109
109
110
110
110
111
111
112
112
113
113
114
114
115
115
116
116
117
117
118
118
119
119
120
120
120
121
121
122
122
123
123
124
124
125
125
126
126
127
127
128
128
129
129
130
130
130
131
131
132
132
133
133
134
134
135
135
136
136
137
137
138
138
139
139
140
140
140
141
141
142
142
143
143
144
144
145
145
146
146
147
147
148
148
149
149
150
150
150
151
151
152
152
153
153
154
154
155
155
156
156
157
157
158
158
159
159
160
160
160
161
161
162
162
163
163
164
164
165
165
166
166
167
167
168
168
169
169
170
170
170
171
171
172
172
173
173
174
174
175
175
176
176
177
177
178
178
179
179
180
180
180
181
181
182
182
183
183
184
184
185
185
186
186
187
187
188
188
189
189
190
190
190
191
191
192
192
193
193
194
194
195
195
196
196
197
197
198
198
199
199
200
200
200
201
201
202
202
203
203
204
204
205
205
206
206
207
207
208
208
209
209
210
210
210
211
211
212
212
213
213
214
214
215
215
216
216
217
217
218
218
219
219
220
220
220
221
221
222
222
223
223
224
224
225
225
226
226
227
227
228
228
229
229
230
230
230
231
231
232
232
233
233
234
234
235
235
236
236
237
237
238
238
239
239
240
240
240
241
241
242
242
243
243
244
244
245
245
246
246
247
247
248
248
249
249
250
250
250
251
251
252
252
253
253
254
254
255
255
256
256
257
257
258
258
259
259
260
260
260
261
261
262
262
263
263
264
264
265
265
266
266
267
267
268
268
269
269
270
270
270
271
271
272
272
273
273
274
274
275
275
276
276
277
277
278
278
279
279
280
280
280
281
281
282
282
283
283
284
284
285
285
286
286
287
287
288
288
289
289
290
290
290
291
291
292
292
293
293
294
294
295
295
296
296
297
297
298
298
299
299
300
300
300
301
301
302
302
303
303
304
304
305
305
306
306
307
307
308
308
309
309
310
310
310
311
311
312
312
313
313
314
314
315
315
316
316
317
317
318
318
319
319
320
320
320
321
321
322
322
323
323
324
324
325
325
326
326
327
327
328
328
329
329
330
330
330
331
331
332
332
333
333
334
334
335
335
336
336
337
337
338
338
339
339
340
340
340
341
341
342
342
343
343
344
344
345
345
346
346
347
347
348
348
349
349
350
350
350
351
351
352
352
353
353
354
354
355
355
356
356
357
357
358
358
359
359
360
360
360
361
361
362
362
363
363
364
364
365
365
366
366
367
367
368
368
369
369
370
370
370
371
371
372
372
373
373
374
374
375
375
376
376
377
377
378
378
379
379
380
380
380
381
381
382
382
383
383
384
384
385
385
386
386
387
387
388
388
389
389
390
390
390
391
391
392
392
393
393
394
394
395
395
396
396
397
397
398
398
399
399
400
400
400
401
401
402
402
403
403
404
404
405
405
406
406
407
407
408
408
409
409
410
410
410
411
411
412
412
413
413
414
414
415
415
416
416
417
417
418
418
419
419
420
420
420
421
421
422
422
423
423
424
424
425
425
426
426
427
427
428
428
429
429
430
430
430
431
431
432
432
433
433
434
434
435
435
436
436
437
437
438
438
439
439
440
440
440
441
441
442
442
443
443
444
444
445
445
446
446
447
447
448
448
449
449
450
450
450
451
451
452
452
453
453
454
454
455
455
456
456
457
457
458
458
459
459
460
460
460
461
461
462
462
463
463
464
464
465
465
466
466
467
467
468
468
469
469
470
470
470
471
471
472
472
473
473
474
474
475
475
476
476
477
477
478
478
479
479
480
480
480
481
481
482
482
483
483
484
484
485
485
486
486
487
487
488
488
489
489
490
490
490
491
491
492
492
493
493
494
494
495
495
496
496
497
497
498
498
499
499
生成器描述
generator.summary()
Model: "sequential"
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
dense (Dense) (None, 256) 25600
_________________________________________________________________
batch_normalization (BatchNo (None, 256) 1024
_________________________________________________________________
leaky_re_lu (LeakyReLU) (None, 256) 0
_________________________________________________________________
dense_1 (Dense) (None, 512) 131072
_________________________________________________________________
batch_normalization_1 (Batch (None, 512) 2048
_________________________________________________________________
leaky_re_lu_1 (LeakyReLU) (None, 512) 0
_________________________________________________________________
dense_2 (Dense) (None, 784) 401408
_________________________________________________________________
batch_normalization_2 (Batch (None, 784) 3136
_________________________________________________________________
reshape (Reshape) (None, 28, 28, 1) 0
=================================================================
Total params: 564,288
Trainable params: 561,184
Non-trainable params: 3,104
_________________________________________________________________