機器學習與Tensorflow(7)——tf.train.Saver()、inception-v3的應用
阿新 • • 發佈:2019-01-08
1. tf.train.Saver()
- tf.train.Saver()是一個類,提供了變數、模型(也稱圖Graph)的儲存和恢復模型方法。
- TensorFlow是通過構造Graph的方式進行深度學習,任何操作(如卷積、池化等)都需要operator,儲存和恢復操作也不例外。
- 在tf.train.Saver()類初始化時,用於儲存和恢復的save和restore operator會被加入Graph。所以,下列類初始化操作應在搭建Graph時完成。
saver = tf.train.Saver()
TensorFlow的儲存和恢復分為兩種:
- 儲存和恢復變數
- 儲存和恢復模型
saver.save()儲存模型
#舉例:
儲存一個訓練好的手寫資料集識別模型
儲存在當前路徑的net資料夾中
1 import os 2 os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' 3 import tensorflow as tf 4 from tensorflow.examples.tutorials.mnist import input_data 5 6 #載入資料集 7 mnist = input_data.read_data_sets('MNIST_data', one_hot=True)View Code8 9 #每個批次100張照片 10 batch_size = 100 11 #計算一個需要多少個批次 12 n_batch = mnist.train.num_examples // batch_size 13 14 #定義兩個placeholder 15 x = tf.placeholder(tf.float32, [None, 784]) 16 y = tf.placeholder(tf.float32, [None, 10]) 17 18 #建立一個簡單的神經網路,輸入層784個神經元,輸出層10個神經元 19 W = tf.Variable(tf.zeros([784, 10])) 20 b = tf.Variable(tf.zeros([10]))21 prediction = tf.nn.softmax(tf.matmul(x, W) + b) 22 #代價函式 23 loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(labels=y, logits=prediction)) 24 #使用梯度下降法 25 train_step = tf.train.GradientDescentOptimizer(0.2).minimize(loss) 26 27 #初始化變數 28 init = tf.global_variables_initializer() 29 30 #結果存放在一個布林型列表中 31 correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(prediction, 1)) 32 33 accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32)) 34 35 saver = tf.train.Saver() 36 37 with tf.Session() as sess: 38 sess.run(init) 39 for epoch in range(11): 40 for batch in range(n_batch): 41 batch_xs, batch_ys = mnist.train.next_batch(batch_size) 42 sess.run(train_step, feed_dict={x:batch_xs, y:batch_ys}) 43 acc = sess.run(accuracy, feed_dict={x:mnist.test.images, y:mnist.test.labels}) 44 print('Iter = ' + str(epoch) +', Testing Accuracy = ' + str(acc)) 45 #儲存模型 46 saver.save(sess, 'net/my_net.ckpt')
#儲存路徑中的檔案為: checkpoint:儲存當前網路狀態的檔案 my_net.ckpt.data-00000-of-00001 my_net.ckpt.index my_net.ckpt.meta:儲存Graph結構的檔案
#關於函式saver.save(),常用的引數就是前三個: save( sess, # 必需引數,Session物件 save_path, # 必需引數,儲存路徑 global_step=None, # 可以是Tensor, Tensor name, 整型數 latest_filename=None, # 協議緩衝檔名,預設為'checkpoint',不用管 meta_graph_suffix='meta', # 圖檔案的字尾,預設為'.meta',不用管 write_meta_graph=True, # 是否儲存Graph write_state=True, # 建議選擇預設值True strip_default_attrs=False # 是否跳過具有預設值的節點
saver.restore()載入已經訓練好的模型
#舉例:
通過載入剛才儲存的訓練好的手寫資料集識別模型進行手寫資料集的識別
1 import os 2 os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' 3 import tensorflow as tf 4 from tensorflow.examples.tutorials.mnist import input_data 5 6 mnist = input_data.read_data_sets('MNIST_data', one_hot=True) 7 batch_size = 100 8 n_batch = mnist.train.num_examples // batch_size 9 10 x = tf.placeholder(tf.float32, [None, 784]) 11 y = tf.placeholder(tf.float32, [None, 10]) 12 13 W = tf.Variable(tf.zeros([784, 10])) 14 b = tf.Variable(tf.zeros([10])) 15 prediction = tf.nn.softmax(tf.matmul(x, W) + b) 16 17 loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(labels=y, logits=prediction)) 18 train_step = tf.train.GradientDescentOptimizer(0.2).minimize(loss) 19 20 init = tf.global_variables_initializer() 21 22 correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(prediction, 1)) 23 accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32)) 24 25 saver = tf.train.Saver() 26 27 with tf.Session() as sess: 28 sess.run(init) 29 print(sess.run(accuracy, feed_dict={x:mnist.test.images, y:mnist.test.labels})) 30 saver.restore(sess, 'net/my_net.ckpt') 31 print(sess.run(accuracy, feed_dict={x:mnist.test.images, y:mnist.test.labels}))View Code
#執行結果: 0.098 0.9178 #直接得到的準確率相當低,通過載入訓練好的模型,識別準確率大大提升。
2. 下載google影象識別網路inception-v3並檢視結構
模型背景:
Inception(v3) 模型是Google 訓練好的最新一個影象識別模型,我們可以利用它來對我們的影象進行識別。
下載地址:
https://storage.googleapis.com/download.tensorflow.org/models/inception_dec_2015.zip
檔案描述:
- classify_image_graph_def.pb 檔案就是訓練好的Inception-v3模型。
- imagenet_synset_to_human_label_map.txt是類別檔案,包含人類標籤和uid之間的對映的檔案。
- imagenet_2012_challenge_label_map_proto.pbtxt是包含類號和uid之間的對映的檔案。
程式碼實現
1 import os 2 os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' 3 import tensorflow as tf 4 import tarfile 5 import requests 6 7 #inception模型下載地址 8 inception_pretrain_model_url = 'http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz' 9 10 #inception模型存放地址 11 inception_pretrain_model_dir = 'inception_model' 12 if not os.path.exists(inception_pretrain_model_dir): 13 os.makedirs(inception_pretrain_model_dir) 14 #獲取檔名,以及檔案路徑 15 filename = inception_pretrain_model_url.split('/')[-1] 16 filepath = os.path.join(inception_pretrain_model_dir, filename) 17 18 #下載模型 19 if not os.path.exists(filepath): 20 print('download: ', filename) 21 r = requests.get(inception_pretrain_model_url, stream=True) 22 with open(filepath, 'wb') as f: 23 for chunk in r.iter_content(chunk_size=1024): 24 if chunk: 25 f.write(chunk) 26 print('finish: ', filename) 27 #解壓檔案 28 tarfile.open(filepath, 'r:gz').extractall(inception_pretrain_model_dir) 29 30 #模型結構存放檔案 31 log_dir = 'inception_log' 32 if not os.path.exists(log_dir): 33 os.makedirs(log_dir) 34 35 #classify_image_graph_def.pb為google訓練好的模型 36 inception_graph_def_file = os.path.join(inception_pretrain_model_dir, 'classify_image_graph_def.pb') 37 with tf.Session() as sess: 38 #建立一個圖來存放google訓練好的模型 39 with tf.gfile.FastGFile(inception_graph_def_file, 'rb') as f: 40 graph_def = tf.GraphDef() 41 graph_def.ParseFromString(f.read()) 42 tf.import_graph_def(graph_def, name='') 43 #儲存圖的結構 44 writer = tf.summary.FileWriter(log_dir, sess.graph) 45 writer.close()View Code
#在下載過程中,下的特別慢,不知道是網路原因還是什麼 #程式總卡著不動 #所以我就手動下載壓縮包並進行解壓
下載結果
3. 使用inception-v3做各種影象的識別
#程式碼實現:
1 import os 2 os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' 3 import tensorflow as tf 4 import numpy as np 5 import re 6 from PIL import Image 7 import matplotlib.pyplot as plt 8 9 #這部分是對標籤號和類別號檔案進行一個預處理 10 11 class NodeLookup(object): 12 def __init__(self): 13 label_lookup_path = 'inception_model/imagenet_2012_challenge_label_map_proto.pbtxt' 14 uid_lookup_path = 'inception_model/imagenet_synset_to_human_label_map.txt' 15 self.node_lookup = self.load(label_lookup_path, uid_lookup_path) 16 def load(self, label_lookup_path, uid_lookup_path): 17 #載入分類字串n********對應分類名稱的檔案 18 proto_as_ascii_lines = tf.gfile.GFile(uid_lookup_path).readlines() 19 uid_to_human={} 20 #一行一行讀取資料 21 for line in proto_as_ascii_lines: 22 #去掉換行符 23 line = line.strip('\n') 24 #按照‘\t’進行分割 25 parsed_items = line.split('\t') 26 #獲取分類編號 27 uid = parsed_items[0] 28 #獲取分類名稱 29 human_string = parsed_items[1] 30 #儲存編號字串n********與分類名稱的對映關係 31 uid_to_human[uid] = human_string 32 33 #載入分類字串n********對應分類編號1-1000的檔案 34 proto_as_ascii = tf.gfile.GFile(label_lookup_path).readlines() 35 node_id_to_uid = {} 36 for line in proto_as_ascii: 37 if line.startswith(' target_class:'): 38 #獲取分類編號1-1000 39 target_class = int(line.split(': ')[1]) 40 if line.startswith(' target_class_string:'): 41 #獲取編號字串nn******** 42 target_class_string = line.split(': ')[1] 43 # 儲存分類編號1-1000與編號字串n********對映關係 44 node_id_to_uid[target_class] = target_class_string[1:-2] 45 # 建立分類編號1-1000對應分類名稱的對映關係 46 node_id_to_name = {} 47 for key, val in node_id_to_uid.items(): 48 #獲取分類名稱 49 name = uid_to_human[val] 50 # 建立分類編號1-1000到分類名稱的對映關係 51 node_id_to_name[key] = name 52 return node_id_to_name 53 # 傳入分類編號1-1000返回分類名稱 54 def id_to_string(self, node_id): 55 if node_id not in self.node_lookup: 56 return '' 57 return self.node_lookup[node_id] 58 59 #建立一個圖來存放google訓練好的模型 60 61 with tf.gfile.FastGFile('inception_model/classify_image_graph_def.pb', 'rb') as f: 62 graph_def = tf.GraphDef() 63 graph_def.ParseFromString(f.read()) 64 tf.import_graph_def(graph_def, name='') 65 66 with tf.Session() as sess: 67 softmax_tensor = sess.graph.get_tensor_by_name('softmax:0') 68 #遍歷目錄 69 for root, dirs, files in os.walk('images/'): 70 for file in files: 71 #載入圖片 72 image_data = tf.gfile.FastGFile(os.path.join(root, file), 'rb').read() 73 predictions = sess.run(softmax_tensor, {'DecodeJpeg/contents:0': image_data})#圖片格式是jpg格式 74 predictions = np.squeeze(predictions)#把結果轉為1維資料 75 76 #列印圖片路徑及名稱 77 image_path = os.path.join(root, file) 78 print(image_path) 79 80 # 顯示圖片 81 img = Image.open(image_path) 82 plt.imshow(img) 83 plt.axis('off') 84 plt.show() 85 86 #排序 87 top_k = predictions.argsort()[-5:][::-1] 88 node_lookup = NodeLookup() 89 for node_id in top_k: 90 # 獲取分類名稱 91 human_string = node_lookup.id_to_string(node_id) 92 # 獲取該分類的置信度 93 score = predictions[node_id] 94 print('%s(score = %.5f)' % (human_string, score)) 95 print()View Code
#執行結果:
images/1.jpg giant panda, panda, panda bear, coon bear, Ailuropoda melanoleuca(score = 0.87265) badger(score = 0.00260) lesser panda, red panda, panda, bear cat, cat bear, Ailurus fulgens(score = 0.00205) brown bear, bruin, Ursus arctos(score = 0.00102) ice bear, polar bear, Ursus Maritimus, Thalarctos maritimus(score = 0.00099) images/2.jpg French bulldog(score = 0.94474) bull mastiff(score = 0.00559) pug, pug-dog(score = 0.00352) Staffordshire bullterrier, Staffordshire bull terrier(score = 0.00165) boxer(score = 0.00116) images/3.jpg zebra(score = 0.94011) tiger, Panthera tigris(score = 0.00080) pencil box, pencil case(score = 0.00066) hartebeest(score = 0.00059) tiger cat(score = 0.00042) images/4.jpg hare(score = 0.87019) wood rabbit, cottontail, cottontail rabbit(score = 0.04802) Angora, Angora rabbit(score = 0.00612) wallaby, brush kangaroo(score = 0.00181) fox squirrel, eastern fox squirrel, Sciurus niger(score = 0.00056) images/5.jpg fox squirrel, eastern fox squirrel, Sciurus niger(score = 0.95047) marmot(score = 0.00265) mongoose(score = 0.00217) weasel(score = 0.00201) mink(score = 0.00199)