TensorFlow——學習率衰減的使用方法
阿新 • • 發佈:2019-06-02
在TensorFlow的優化器中, 都要設定學習率。學習率是在精度和速度之間找到一個平衡:
學習率太大,訓練的速度會有提升,但是結果的精度不夠,而且還可能導致不能收斂出現震盪的情況。
學習率太小,精度會有所提升,但是訓練的速度慢,耗費較多的時間。
因而我們可以使用退化學習率,又稱為衰減學習率。它的作用是在訓練的過程中,對學習率的值進行衰減,訓練到達一定程度後,使用小的學習率來提高精度。
在TensorFlow中的方法如下:tf.train.exponential_decay(),該方法的引數如下:
learning_rate, 初始的學習率的值
global_step, 迭代步數變數
decay_steps, 帶迭代多少次進行衰減
decay_rate, 迭代decay_steps次衰減的值
staircase=False, 預設為False,為True則不衰減
例如
tf.train.exponential_decay(initial_learning_rate, global_step=global_step, decay_steps=1000, decay_rate=0.9)表示沒經過1000次的迭代,學習率變為原來的0.9。
增大批次處理樣本的數量也可以起到退化學習率的作用。
下面我們寫了一個例子,每迭代10次,則較小為原來的0.5,程式碼如下:
import tensorflow as tf import numpy as np global_step = tf.Variable(0, trainable=False) initial_learning_rate = 0.1 learning_rate = tf.train.exponential_decay(initial_learning_rate, global_step=global_step, decay_steps=10, decay_rate=0.5) opt = tf.train.GradientDescentOptimizer(learning_rate) add_global = global_step.assign_add(1) with tf.Session() as sess: tf.global_variables_initializer().run() print(sess.run(learning_rate)) for i in range(50): g, rate = sess.run([add_global, learning_rate]) print(g, rate)
下面是程式的結果,我們發現沒10次就變為原來的一般:
隨後,又在MNIST上面進行了測試,發現使用學習率衰減使得準確率有較好的提升。程式碼如下:
import tensorflow as tf import numpy as np from tensorflow.examples.tutorials.mnist import input_data import matplotlib.pyplot as plt mnist = input_data.read_data_sets('MNIST_data', one_hot=True) tf.reset_default_graph() x = tf.placeholder(tf.float32, [None, 784]) y = tf.placeholder(tf.float32, [None, 10]) w = tf.Variable(tf.random_normal([784, 10])) b = tf.Variable(tf.zeros([10])) pred = tf.matmul(x, w) + b pred = tf.nn.softmax(pred) cost = tf.reduce_mean(-tf.reduce_sum(y * tf.log(pred), reduction_indices=1)) global_step = tf.Variable(0, trainable=False) initial_learning_rate = 0.1 learning_rate = tf.train.exponential_decay(initial_learning_rate, global_step=global_step, decay_steps=1000, decay_rate=0.9) opt = tf.train.GradientDescentOptimizer(learning_rate) add_global = global_step.assign_add(1) optimizer = tf.train.GradientDescentOptimizer(learning_rate).minimize(cost) training_epochs = 50 batch_size = 100 display_step = 1 with tf.Session() as sess: sess.run(tf.global_variables_initializer()) for epoch in range(training_epochs): avg_cost = 0 total_batch = int(mnist.train.num_examples/batch_size) for i in range(total_batch): batch_xs, batch_ys = mnist.train.next_batch(batch_size) _, c, add, rate = sess.run([optimizer, cost, add_global, learning_rate], feed_dict={x:batch_xs, y:batch_ys}) avg_cost += c / total_batch if (epoch + 1) % display_step == 0: print('epoch= ', epoch+1, ' cost= ', avg_cost, 'add_global=', add, 'rate=', rate) print('finished') correct_prediction = tf.equal(tf.argmax(pred, 1), tf.argmax(y, 1)) accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32)) print('accuracy: ', accuracy.eval({x:mnist.test.images, y:mnist.test.labels}))
在使用衰減學習率我們最後的精度達到0.8897,在使用固定的學習率時,精度只有0.8586。
&n