深度學習tensorflow實戰筆記(4)利用儲存的VGG-16CNN網路模型提取特徵
阿新 • • 發佈:2019-01-09
前幾篇部落格寫了如何處理資料,如何把用自己的資料訓練VGG-16,如何把訓練好的模型儲存。而在實際應用中,並不是所有的操作都是為了分類的,有時候需要提取影象的特徵,那麼怎麼利用已經儲存的模型提取特徵呢?
“桃葉兒尖上尖,柳葉兒就遮滿了天”
測試資料轉換成tfrecords,教程:點選開啟連結
儲存訓練好的VGG-16模型,教程:點選開啟連結
1、讀取測試資料
首先把測試資料轉換成tfrecords,然後讀取出來,程式碼和前面部落格寫的一致:
#讀取檔案 def read_and_decode(filename,batch_size): #根據檔名生成一個佇列 filename_queue = tf.train.string_input_producer([filename]) reader = tf.TFRecordReader() _, serialized_example = reader.read(filename_queue) #返回檔名和檔案 features = tf.parse_single_example(serialized_example, features={ 'label': tf.FixedLenFeature([], tf.int64), 'img_raw' : tf.FixedLenFeature([], tf.string), }) img = tf.decode_raw(features['img_raw'], tf.uint8) img = tf.reshape(img, [300, 300, 3]) #影象歸一化大小 # img = tf.cast(img, tf.float32) * (1. / 255) - 0.5 #影象減去均值處理 label = tf.cast(features['label'], tf.int32) #特殊處理 img_batch, label_batch = tf.train.shuffle_batch([img, label], batch_size= batch_size, num_threads=64, capacity=2000, min_after_dequeue=1500) return img_batch, tf.reshape(label_batch,[batch_size])
2、調取儲存的訓練好的VGG-16模型
最核心的部分是使用saver類中的restore方法,核心程式碼如下:
saver = tf.train.import_meta_graph("model/checkpoint/model.ckpt.meta") #注意路徑
saver.restore(sess, "./model/checkpoint/model.ckpt") #儲存模型的路徑
3、把測試資料傳進去模型提取特徵
利用的是graph.get_tensor_by_name(“名字”),則首先獲取模型中佔位符,然後將測試資料傳進去,這是最核心的地方,想要提取特徵也是通過名字獲取張量,比如要提取fc7的特徵,則fc7_features=graph.get_tensor_by_name("fc7:0") 。核心程式碼如下:
graph = tf.get_default_graph() #獲取恢復模型的圖模型 x_holder = graph.get_tensor_by_name("x_holder:0") # 獲取佔位符 fc7_features=graph.get_tensor_by_name("fc7:0") #獲取要提取的特徵,用該層的名字 keep_prob=graph.get_tensor_by_name("keep_prob:0") #同上 # 通過張量的名稱來獲取張量 print(sess.run(fc7_features,feed_dict={x_holder:image,keep_prob:dropout})) #給佔位符重新賦值,則可以提取輸入影象的特徵
4、完整的程式碼
整個過程,博主用了好幾天的時間才調通,中間的心酸歷程就不多說了,直接放完整的提取特徵程式碼吧,如果想用儲存的模型做分類,而不是提特徵,則舉一反三,我覺得並不難,修改一下即可:
完整程式碼:
# -*- coding: utf-8 -*-
"""
Created on Mon Apr 2 17:12:00 2018
@author: Heroin 高永標,upc
"""
import tensorflow as tf
#讀取檔案
def read_and_decode(filename,batch_size):
#根據檔名生成一個佇列
filename_queue = tf.train.string_input_producer([filename])
reader = tf.TFRecordReader()
_, serialized_example = reader.read(filename_queue) #返回檔名和檔案
features = tf.parse_single_example(serialized_example,
features={
'label': tf.FixedLenFeature([], tf.int64),
'img_raw' : tf.FixedLenFeature([], tf.string),
})
img = tf.decode_raw(features['img_raw'], tf.uint8)
img = tf.reshape(img, [300, 300, 3]) #影象歸一化大小
# img = tf.cast(img, tf.float32) * (1. / 255) - 0.5 #影象減去均值處理
label = tf.cast(features['label'], tf.int32)
#特殊處理
img_batch, label_batch = tf.train.shuffle_batch([img, label],
batch_size= batch_size,
num_threads=64,
capacity=2000,
min_after_dequeue=1500)
return img_batch, tf.reshape(label_batch,[batch_size])
batch_size=4
dropout=1.0
tfrecords_file = 'train.tfrecords' #儲存的測試資料
BATCH_SIZE = 4
image_batch, label_batch = read_and_decode(tfrecords_file,BATCH_SIZE)
#print(image_batch)
#sess=tf.InteractiveSession()
with tf.Session() as sess:
coord = tf.train.Coordinator()
threads = tf.train.start_queue_runners(sess = sess,coord = coord)
image,label=sess.run([image_batch,label_batch])
saver = tf.train.import_meta_graph("model/checkpoint/model.ckpt.meta") #儲存的模型路徑
saver.restore(sess, "./model/checkpoint/model.ckpt")
graph = tf.get_default_graph()
x_holder = graph.get_tensor_by_name("x_holder:0") # 獲取佔位符
fc7_features=graph.get_tensor_by_name("fc7:0") #獲取要提取的特徵,用名字
keep_prob=graph.get_tensor_by_name("keep_prob:0")
# 通過張量的名稱來獲取張量
print(sess.run(fc7_features,feed_dict={x_holder:image,keep_prob:dropout})) #給佔位符重新賦值
sess.close()