tensorflow學習筆記(北京大學) tf4_6.py 完全解析 滑動平均
阿新 • • 發佈:2018-12-19
#coding:utf-8
#tensorflow學習筆記(北京大學) tf4_6.py 完全解析 滑動平均
#QQ群:476842922(歡迎加群討論學習)
#如有錯誤還望留言指正,謝謝?
import tensorflow as tf
#1. 定義變數及滑動平均類
#定義一個32位浮點變數,初始值為0.0 這個程式碼就是不斷更新w1引數,優化w1引數,滑動平均做了個w1的影子
w1 = tf.Variable(0, dtype=tf.float32)
#定義num_updates(NN的迭代輪數),初始值為0,不可被優化(訓練),這個引數不訓練
global_step = tf.Variable (0, trainable=False)
#例項化滑動平均類,給衰減率為0.99,當前輪數global_step
MOVING_AVERAGE_DECAY = 0.99
ema = tf.train.ExponentialMovingAverage(MOVING_AVERAGE_DECAY, global_step)#滑動平均
#ema.apply後的括號裡是更新列表,每次執行sess.run(ema_op)時,對更新列表中的元素求滑動平均值。
#在實際應用中會使用tf.trainable_variables()自動將所有待訓練的引數彙總為列表
#ema_op = ema.apply([w1])
#apply (func [, args [, kwargs ]]) 函式用於當函式引數已經存在於一個元組或字典中時,間接地呼叫函式。
ema_op = ema.apply(tf.trainable_variables())
#2. 檢視不同迭代中變數取值的變化。
with tf.Session() as sess:
# 初始化
init_op = tf.global_variables_initializer()#初始化
sess.run(init_op)#計算初始化
#用ema.average(w1)獲取w1滑動平均值 (要執行多個節點,作為列表中的元素列出,寫在sess.run中)
#打印出當前引數w1和w1滑動平均值
print "current global_step:" , sess.run(global_step)#列印global_step
print "current w1", sess.run([w1, ema.average(w1)]) #計算滑動平均
# 引數w1的值賦為1
#tf.assign(A, new_number): 這個函式的功能主要是把A的值變為new_number
sess.run(tf.assign(w1, 1))
sess.run(ema_op)
print "current global_step:", sess.run(global_step)
print "current w1", sess.run([w1, ema.average(w1)])
# 更新global_step和w1的值,模擬出輪數為100時,引數w1變為10, 以下程式碼global_step保持為100,每次執行滑動平均操作,影子值會更新
sess.run(tf.assign(global_step, 100)) #設定global_step為100
sess.run(tf.assign(w1, 10))#設定W1為10
sess.run(ema_op)#執行ema_op
print "current global_step:", sess.run(global_step)#列印
print "current w1:", sess.run([w1, ema.average(w1)]) #列印
# 每次sess.run會更新一次w1的滑動平均值
sess.run(ema_op)
print "current global_step:" , sess.run(global_step)
print "current w1:", sess.run([w1, ema.average(w1)])
sess.run(ema_op)
print "current global_step:" , sess.run(global_step)
print "current w1:", sess.run([w1, ema.average(w1)])
sess.run(ema_op)
print "current global_step:" , sess.run(global_step)
print "current w1:", sess.run([w1, ema.average(w1)])
sess.run(ema_op)
print "current global_step:" , sess.run(global_step)
print "current w1:", sess.run([w1, ema.average(w1)])
sess.run(ema_op)
print "current global_step:" , sess.run(global_step)
print "current w1:", sess.run([w1, ema.average(w1)])
sess.run(ema_op)
print "current global_step:" , sess.run(global_step)
print "current w1:", sess.run([w1, ema.average(w1)])
#更改MOVING_AVERAGE_DECAY 為 0.1 看影子追隨速度
"""
current global_step: 0
current w1 [0.0, 0.0]
current global_step: 0
current w1 [1.0, 0.9]
current global_step: 100
current w1: [10.0, 1.6445453]
current global_step: 100
current w1: [10.0, 2.3281732]
current global_step: 100
current w1: [10.0, 2.955868]
current global_step: 100
current w1: [10.0, 3.532206]
current global_step: 100
current w1: [10.0, 4.061389]
current global_step: 100
current w1: [10.0, 4.547275]
current global_step: 100
current w1: [10.0, 4.9934072]
"""