tensorflow學習系列六:mnist從訓練儲存模型再到載入模型測試
阿新 • • 發佈:2019-01-10
通過前面幾個系列的學習對tensorflow有了一個漸漸親切的感覺,本文主要是從tensorflow模型訓練與驗證的模型進行實踐一遍,以至於我們能夠通過tensorflow的訓練有一個整體的概念。下面主要是從訓練到儲存模型,然後載入模型進行預測。
# -*- coding: utf-8 -*- """ Created on Mon Jun 11 22:17:52 2018 func:搭建網路圖 @author: kuangyongjian """ import tensorflow as tf #構建圖 class Network(object): def __init__(self): self.learning_rate = 0.001 #機率已經訓練的次數 self.global_step = tf.Variable(0,trainable = False) self.x = tf.placeholder(tf.float32,[None,784]) self.label = tf.placeholder(tf.float32,[None,10]) self.w = tf.Variable(tf.zeros([784,10])) self.b = tf.Variable(tf.zeros([10])) self.y = tf.nn.softmax(tf.matmul(self.x,self.w) + self.b) self.loss = -tf.reduce_mean(self.label * tf.log(self.y) + 1e-10) self.train = tf.train.GradientDescentOptimizer(self.learning_rate).minimize(self.loss,global_step = self.global_step) predict = tf.equal(tf.argmax(self.label,1),tf.argmax(self.y,1)) self.accuracy = tf.reduce_mean(tf.cast(predict,tf.float32))
# -*- coding: utf-8 -*- """ Created on Tue Jun 12 09:16:52 2018 func:網路訓練,以及對應的模型儲存 @author: kuangyongjian """ import tensorflow as tf from tensorflow.examples.tutorials.mnist import input_data from model import Network CKPT_DIR = 'ckpt' class Train(object): def __init__(self): self.net = Network() self.sess = tf.Session() self.sess.run(tf.global_variables_initializer()) self.data = input_data.read_data_sets('../data_set',one_hot = True) def train(self): batch_size = 64 train_step = 10000 step = 0 #每隔1000步儲存一次模型 save_interval = 1000 #tf.train.Saver用於儲存訓練的結果 #max to keep 用於設定最多儲存多少個模型 #如果儲存的模型超過這個值,最舊的模型被刪除 saver = tf.train.Saver(max_to_keep = 10) ckpt = tf.train.get_checkpoint_state(CKPT_DIR) if ckpt and ckpt.get_checkpoint_state(CKPT_DIR): saver.restore(self.sess,ckpt.model_checkpoint_path) #讀取網路中的global_step的值,即當前已經訓練的次數 step = self.sess.run(self.net.global_step) print('continue from') print(' -> Minibatch update : ',step) while step < train_step: x,label = self.data.train.next_batch(batch_size) _,loss = self.sess.run([self.net.train,self.net.loss], feed_dict = {self.net.x: x,self.net.label:label}) step = self.sess.run(self.net.global_step) if step % 1000 == 0: print('第%6d步,當前loss: %.3f'%(step,loss)) #模型儲存在ckpt資料夾下 #模型檔名最後會增加global_step的值,比如2000的模型檔名為model-2000 if step % save_interval == 0: saver.save(self.sess,CKPT_DIR + '/model',global_step = step) def calculate_accuracy(self): test_x = self.data.test.images test_label = self.data.test.labels acc = self.sess.run(self.net.accuracy,feed_dict = {self.net.x:test_x,self.net.label:test_label}) print("準確率: %.3f,共測試了%d張圖片 " % (acc, len(test_label))) if __name__ == '__main__': model = Train() model.train() model.calculate_accuracy()
# -*- coding: utf-8 -*- """ Created on Tue Jun 12 09:36:55 2018 func:載入模型,進行模型測試 @author: kuangyongjian """ import tensorflow as tf import numpy as np from PIL import Image from model import Network CKPT_DIR = 'ckpt' class Predict(object): def __init__(self): #清除預設圖的堆疊,並設定全域性圖為預設圖 #若不進行清楚則在第二次載入的時候報錯,因為相當於重新載入了兩次 tf.reset_default_graph() self.net = Network() self.sess = tf.Session() self.sess.run(tf.global_variables_initializer()) #載入模型到sess中 self.restore() print('load susess') def restore(self): saver = tf.train.Saver() ckpt = tf.train.get_checkpoint_state(CKPT_DIR) print(ckpt.model_checkpoint_path) if ckpt and ckpt.model_checkpoint_path: saver.restore(self.sess,ckpt.model_checkpoint_path) else: raise FileNotFoundError('未儲存模型') def predict(self,image_path): #讀取圖片並灰度化 img = Image.open(image_path).convert('L') flatten_img = np.reshape(img,784) x = np.array([1 - flatten_img]) y = self.sess.run(self.net.y,feed_dict = {self.net.x:x}) print(image_path) print(' Predict digit',np.argmax(y[0])) if __name__ == '__main__': model = Predict() model.predict('0.png') model.predict('../test_images/1.png') model.predict('../test_images/4.png')
注意文中儲存模型和載入模型的方式,特別是在載入模型的時候比較容易出錯。
若有不當之處請指教,謝謝!