1. 程式人生 > >tensorflow中moving average的正確用法

tensorflow中moving average的正確用法

ply mov 基礎 github log 方差 如果 均值 nbsp

一般在保存模型參數的時候,都會保存一份moving average,是取了不同叠代次數模型的移動平均,移動平均後的模型往往在性能上會比最後一次叠代保存的模型要好一些。

tensorflow-models項目中tutorials下cifar中相關的代碼寫的有點問題,在這寫下我自己的做法:

1.構建訓練模型時,添加如下代碼

1 variable_averages = tf.train.ExponentialMovingAverage(0.999, global_step)
2 variables_averages_op = variable_averages.apply(tf.trainable_variables())
3 ave_vars = [variable_averages.average(var) for var in tf.trainable_variables()] 4 train_op = tf.group(train_op, variables_averages_op)

第1行創建了一個指數移動平均類 variable_averages

第2行將variable_averages作用於當前模型中所有可訓練的變量上,得到 variables_averages_op操作符

第3行獲得所有可訓練變量對應的移動平均變量列表集合,後續用於保存模型

第4行在原有的訓練操作符基礎上,再添加variables_averages_op操作符,後續session執行run的時候,除了訓練時前向後向,梯度更新,還會對相應的變量做移動平均

2.開始訓練前,創建saver時,使用如下代碼

1 save_vars = tf.trainable_variables() + ave_vars
2 saver = tf.train.Saver(var_list=save_vars, max_to_keep=5)

第1行獲取所有需要保存的變量列表,這個時候 ave_vars就派上用場了。

第2行創建saver,指定var_list為所有可訓練變量及其對應的移動平均變量。

另外需要註意的是,如果你的模型中有bn或者類似層,包含有統計參數(均值、方差等),這些不屬於可訓練參數,還需要額外添加進save_vars中,可以參考我的這篇博客

3.在做inference的時候,利用如下代碼從checkpoint中恢復出移動平均模型

1 variable_averages = tf.train.ExponentialMovingAverage(0.999)
2 variables_to_restore = variable_averages.variables_to_restore()
3 saver = tf.train.Saver(variables_to_restore)
4 saver.restore(sess, model_path)

這幾行很簡單,就不做解釋了。

實際上,在inference的時候,剛剛的做法除了可以從checkpoint文件中恢復出移動平均參數,還可以恢復出對應叠代的模型參數,可以用來對比兩種方式,哪種效果更好,這時只需要將上面代碼的第3行改為saver = tf.train.Saver(tf.trainable_variables())即可(和保存時相同,如果有bn,也需要額外考慮)。在我的測試中,使用移動平均參數效果更佳。

tensorflow中moving average的正確用法