TensorFlow學習筆記(六) tensorflow 中的儲存訓練的引數以及載入引數檢測test集
阿新 • • 發佈:2019-02-20
如何儲存訓練好的引數
以前面練習的一個小例子,來儲存訓練好的引數:
import tensorflow as tf from tensorflow.examples.tutorials.mnist import input_data from tensorflow import float32 #載入資料,會自動通過一個指令碼下載好資料集 mnist = input_data.read_data_sets("MNIST_data", one_hot=True) #每個批次大小以及多少批次 batch_size = 100 n_batch = mnist.train.num_examples // batch_size #設定兩個佔位符 x = tf.placeholder(dtype=tf.float32, shape=[None, 784]) y = tf.placeholder(dtype=tf.float32, shape=[None, 10]) #建立一個簡單的神經網路 W = tf.Variable(tf.zeros([784, 10]),float32) b = tf.Variable(tf.zeros([10]),float32) prediction = tf.nn.softmax(tf.matmul(x, W)+b) #二次代價函式 loss = tf.reduce_mean(tf.square(y-prediction)) #使用梯度下降方法 train_step = tf.train.GradientDescentOptimizer(0.2).minimize(loss) #初始化變數 init = tf.global_variables_initializer() #結果放在布林型列表中,其中argmax返回數列中最大值所在的位置 correct_prediction = tf.equal(tf.argmax(y,1), tf.argmax(prediction, 1)) #求準確性 accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32)) saver = tf.train.Saver()#為下面儲存做準備 with tf.Session() as sess: sess.run(init) for epoch in range (21): for batch in range (n_batch): batch_xs, batch_ys = mnist.train.next_batch(batch_size) sess.run(train_step, feed_dict={x:batch_xs, y:batch_ys}) acc = sess.run(accuracy, feed_dict={x:mnist.test.images, y:mnist.test.labels}) print("Iter" + str(epoch) + "Testing Accuracy" + str(acc)) #儲存模型 saver.save(sess,'net/my_net.ckpt')
你將會在你建立的/net資料夾中看到儲存的資料檔案
使用載入儲存好的引數檢驗準確性(和沒訓練過的引數比較)
其中沒訓練的引數的檢測準確率很低, 而通過匯入的引數檢測的卻很高import tensorflow as tf from tensorflow.examples.tutorials.mnist import input_data from tensorflow import float32 #載入資料,會自動通過一個指令碼下載好資料集 mnist = input_data.read_data_sets("MNIST_data", one_hot=True) #每個批次大小以及多少批次 batch_size = 100 n_batch = mnist.train.num_examples // batch_size #設定兩個佔位符 x = tf.placeholder(dtype=tf.float32, shape=[None, 784]) y = tf.placeholder(dtype=tf.float32, shape=[None, 10]) #建立一個簡單的神經網路 W = tf.Variable(tf.zeros([784, 10]),float32) b = tf.Variable(tf.zeros([10]),float32) prediction = tf.nn.softmax(tf.matmul(x, W)+b) #二次代價函式 loss = tf.reduce_mean(tf.square(y-prediction)) #使用梯度下降方法 train_step = tf.train.GradientDescentOptimizer(0.2).minimize(loss) #初始化變數 init = tf.global_variables_initializer() #結果放在布林型列表中,其中argmax返回數列中最大值所在的位置 correct_prediction = tf.equal(tf.argmax(y,1), tf.argmax(prediction, 1)) #求準確性 accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32)) saver = tf.train.Saver() with tf.Session() as sess: sess.run(init) print(sess.run(accuracy,feed_dict = {x:mnist.test.images, y:mnist.test.labels})) saver.restore(sess, 'net/my_net.ckpt') print(sess.run(accuracy,feed_dict={x:mnist.test.images, y:mnist.test.labels}))