tensorflow-保存與讀取使用模型
阿新 • • 發佈:2019-02-07
手寫 odin nis 圖像 info one rom lac Coding 1、MNIST是深度學習的經典入門demo,他是由6萬張訓練圖片和1萬張測試圖片構成的,每張圖片都是2828大小(如下圖),而且都是黑白色構成(這裏的黑色是一個0-1的浮點數,黑色越深表示數值越靠近1),這些圖片是采集的不同的人手寫從0到9的數字。
下面先訓練識別數字模型
再保存模型
最後,讀取保存的模型,對數字圖片進行識別。
下面先訓練識別數字模型
再保存模型
最後,讀取保存的模型,對數字圖片進行識別。
2、保存模型
#!/usr/bin/env python3 # -*- coding: utf-8 -*- """ Created on Sun Feb 3 20:28:26 2019 @author: myhaspl """ from tensorflow.examples.tutorials.mnist import input_data mnist=input_data.read_data_sets("MNIST_data/",one_hot=True) import tensorflow as tf import os x=tf.placeholder(tf.float32,[None,784]) w=tf.Variable(tf.zeros([784,10])) b=tf.Variable(tf.zeros([10])) y=tf.nn.softmax(tf.matmul(x,w)+b) y_=tf.placeholder(tf.float32,[None,10]) cross_entropy=tf.reduce_mean(-tf.reduce_sum(y_*tf.log(y),reduction_indices=[1])) train_step=tf.train.GradientDescentOptimizer(0.5).minimize(cross_entropy) init=tf.global_variables_initializer() sess=tf.Session() sess.run(init) saver=tf.train.Saver() for i in range(1000): sampleX,sampleY=mnist.train.next_batch(100) sess.run(train_step,feed_dict={x:sampleX,y_:sampleY}) print("訓練完成") print("保存生成模型...") model_dir="mnist_model" model_name="ml1" if not os.path.exists(model_dir): os.mkdir(model_dir) saver.save(sess,os.path.join(model_dir,model_name)) print("保存生成模型成功")
訓練完成
保存生成模型...
保存生成模型成功
[root@VM03centos learn]# ls mnist_model
checkpoint ml1.data-00000-of-00001 ml1.index ml1.meta
[root@VM03centos learn]# ls MNISTdata
t10k-images-idx3-ubyte.gz t10k-labels-idx1-ubyte.gz train-images-idx3-ubyte.gz train-labels-idx1-ubyte.gz
[root@VM03centos learn]#
讀取數字識別模型,對某個數字圖像進行識別
#!/usr/bin/env python3 # -*- coding: utf-8 -*- """ Created on Sun Feb 3 20:28:26 2019 @author: myhaspl """ from tensorflow.examples.tutorials.mnist import input_data mnist=input_data.read_data_sets("MNIST_data/",one_hot=True) import tensorflow as tf x=tf.placeholder(tf.float32,[None,784]) w=tf.Variable(tf.zeros([784,10])) b=tf.Variable(tf.zeros([10])) y=tf.nn.softmax(tf.matmul(x,w)+b) y_=tf.placeholder(tf.float32,[None,10]) cross_entropy=tf.reduce_mean(-tf.reduce_sum(y_*tf.log(y),reduction_indices=[1])) train_step=tf.train.GradientDescentOptimizer(0.5).minimize(cross_entropy) init=tf.global_variables_initializer() sess=tf.Session() sess.run(init) saver=tf.train.Saver() print("讀取模型...") saver.restore(sess,"mnist_model/ml1") print("讀取模型完成") print("根據模型進行計算...") img=mnist.test.images[5] result=sess.run(y,feed_dict={x:img.reshape(1,784)}) print("預測輸出結果:{}".format(result)) print("預測結果:{}".format(result.argmax())) print("實際結果:{}".format(mnist.test.labels[5].argmax()))
讀取模型...
INFO:tensorflow:Restoring parameters from mnist_model/ml1
讀取模型完成
根據模型進行計算...
預測輸出結果:[[1.8999807e-06 9.8351490e-01 3.0815993e-03 4.3848301e-03 4.1427880e-05
1.6864968e-04 7.6594086e-05 4.5587993e-03 3.2991443e-03 8.7222963e-04]]
預測結果:1
實際結果:1
tensorflow-保存與讀取使用模型