使用tensorflow儲存、載入和使用模型
使用Tensorflow進行深度學習訓練的時候,需要對訓練好的網路模型和各種引數進行儲存,以便在此基礎上繼續訓練或者使用。介紹這方面的部落格有很多,我發現寫的最好的是這一篇官方英文介紹:
我對這篇文章進行了整理和彙總。
首先是模型的儲存。直接上程式碼:
需要說明的有以下幾點:#!/usr/bin/env python #-*- coding:utf-8 -*- ############################ #File Name: tut1_save.py #Author: Wang #Mail: [email protected] #Created Time:2017-08-30 11:04:25 ############################ import tensorflow as tf # prepare to feed input, i.e. feed_dict and placeholders w1 = tf.Variable(tf.random_normal(shape = [2]), name = 'w1') # name is very important in restoration w2 = tf.Variable(tf.random_normal(shape = [2]), name = 'w2') b1 = tf.Variable(2.0, name = 'bias1') feed_dict = {w1:[10,3], w2:[5,5]} # define a test operation that will be restored w3 = tf.add(w1, w2) # without name, w3 will not be stored w4 = tf.multiply(w3, b1, name = "op_to_restore") #saver = tf.train.Saver() saver = tf.train.Saver(max_to_keep = 4, keep_checkpoint_every_n_hours = 1) sess = tf.Session() sess.run(tf.global_variables_initializer()) print sess.run(w4, feed_dict) #saver.save(sess, 'my_test_model', global_step = 100) saver.save(sess, 'my_test_model') #saver.save(sess, 'my_test_model', global_step = 100, write_meta_graph = False)
1. 建立saver的時候可以指明要儲存的tensor,如果不指明,就會全部存下來。在這裡也可以指明最大儲存數量和checkpoint的記錄時間。具體細節看英文部落格。
2. saver.save()函式裡面可以設定global_step和write_meta_graph,meta儲存的是網路結構,只在開始執行程式的時候儲存一次即可,後續可以通過設定write_meta_graph = False加以限制。
3. 這個程式執行結束後,會在程式目錄下生成四個檔案,分別是.meta(儲存網路結構)、.data和.index(儲存訓練好的引數)、checkpoint(記錄最新的模型)。
下面是如何載入已經儲存的網路模型。這裡有兩種方法,第一種是saver.restore(sess, 'aaaa.ckpt'),這種方法的本質是讀取全部引數,並載入到已經定義好的網路結構上,因此相當於給網路的weights和biases賦值並執行tf.global_variables_initializer()。這種方法的缺點是使用前必須重寫網路結構,而且網路結構要和儲存的引數完全對上。第二種就比較高端了,直接把網路結構載入進來(.meta),上程式碼:
#!/usr/bin/env python
#-*- coding:utf-8 -*-
############################
#File Name: tut2_import.py
#Author: Wang
#Mail: [email protected]
#Created Time:2017-08-30 14:16:38
############################
import tensorflow as tf
sess = tf.Session()
new_saver = tf.train.import_meta_graph('my_test_model.meta')
new_saver.restore(sess, tf.train.latest_checkpoint('./'))
print sess.run('w1:0')
使用載入的模型,輸入新資料,計算輸出,還是直接上程式碼:
#!/usr/bin/env python
#-*- coding:utf-8 -*-
############################
#File Name: tut3_reuse.py
#Author: Wang
#Mail: [email protected]
#Created Time:2017-08-30 14:33:35
############################
import tensorflow as tf
sess = tf.Session()
# First, load meta graph and restore weights
saver = tf.train.import_meta_graph('my_test_model.meta')
saver.restore(sess, tf.train.latest_checkpoint('./'))
# Second, access and create placeholders variables and create feed_dict to feed new data
graph = tf.get_default_graph()
w1 = graph.get_tensor_by_name('w1:0')
w2 = graph.get_tensor_by_name('w2:0')
feed_dict = {w1:[-1,1], w2:[4,6]}
# Access the op that want to run
op_to_restore = graph.get_tensor_by_name('op_to_restore:0')
print sess.run(op_to_restore, feed_dict) # ouotput: [6. 14.]
在已經載入的網路後繼續加入新的網路層:
import tensorflow as tf
sess=tf.Session()
#First let's load meta graph and restore weights
saver = tf.train.import_meta_graph('my_test_model-1000.meta')
saver.restore(sess,tf.train.latest_checkpoint('./'))
# Now, let's access and create placeholders variables and
# create feed-dict to feed new data
graph = tf.get_default_graph()
w1 = graph.get_tensor_by_name("w1:0")
w2 = graph.get_tensor_by_name("w2:0")
feed_dict ={w1:13.0,w2:17.0}
#Now, access the op that you want to run.
op_to_restore = graph.get_tensor_by_name("op_to_restore:0")
#Add more to the current graph
add_on_op = tf.multiply(op_to_restore,2)
print sess.run(add_on_op,feed_dict)
#This will print 120.
對載入的網路進行區域性修改和處理(這個最麻煩,我還沒搞太明白,後續會繼續補充):
......
......
saver = tf.train.import_meta_graph('vgg.meta')
# Access the graph
graph = tf.get_default_graph()
## Prepare the feed_dict for feeding data for fine-tuning
#Access the appropriate output for fine-tuning
fc7= graph.get_tensor_by_name('fc7:0')
#use this if you only want to change gradients of the last layer
fc7 = tf.stop_gradient(fc7) # It's an identity function
fc7_shape= fc7.get_shape().as_list()
new_outputs=2
weights = tf.Variable(tf.truncated_normal([fc7_shape[3], num_outputs], stddev=0.05))
biases = tf.Variable(tf.constant(0.05, shape=[num_outputs]))
output = tf.matmul(fc7, weights) + biases
pred = tf.nn.softmax(output)
# Now, you run this with fine-tuning data in sess.run()
有了這樣的方法,無論是自行訓練、載入模型繼續訓練、使用經典模型還是finetune經典模型抑或是載入網路跑前項,效果都是槓槓的。
相關推薦
使用tensorflow儲存、載入和使用模型
使用Tensorflow進行深度學習訓練的時候,需要對訓練好的網路模型和各種引數進行儲存,以便在此基礎上繼續訓練或者使用。介紹這方面的部落格有很多,我發現寫的最好的是這一篇官方英文介紹: 我對這篇文章進行了整理和彙總。 首先是模型的儲存。直接上程式碼: #!/usr/bi
mnist LSTM 訓練、測試,模型儲存、載入和識別
MNIST 字元資料庫每個字元(0-9) 對應一張28x28的一通道圖片,可以將圖片的每一列(行)當作特徵,所有行(列)當做一個序列。那麼可以通過輸入大小為28,時間長度為28的RNN(lstm)對字元建模。對於同一個字元,比如0,其行與行之間的動態變化可以
TensorFlow儲存、載入模型引數 | 原理描述及踩坑經驗總結
寫在前面 我之前使用的LSTM計算單元是根據其前向傳播的計算公式手動實現的,這兩天想要和TensorFlow自帶的tf.nn.rnn_cell.BasicLSTMCell()比較一下,看看哪個訓練速度快一些。在使用tf.nn.rnn_cell.BasicLSTMCell()進行建模的時候,遇到了模型儲存、載入
tensorflow 儲存及其載入
https://blog.csdn.net/thriving_fcl/article/details/75213361 同一個模型圖,可以根據根據輸入輸出任意組成signature_def,使模型的任意組合使用方便, signature_def了一種組織方式! 一個圖標籤下,可以任意組合很多種signatur
TensorFlow常量、變數和資料型別
TensorFlow 用張量這種資料結構來表示所有的資料。一個張量有一個靜態型別和動態型別的維數,張量可以在圖中的節點之間流通。 (1)TensorFlow中建立常量的方法: hello=tf.constant('hello,TensorFlow!',dtype=tf
OpenStack社群元件-儲存、備份和恢復
OpenStack發展到今天,從最開始的A版開始,到2017年8月份剛釋出的P版,已經發布了很多個版本。 OpenStack社群把OpenStack各個元件進行了歸類,一共分成了9大類。
Keras 儲存與載入網路模型
遇到問題: keras使用預訓練模型做訓練時遇到的如下程式碼: from keras.utils.data_utils import get_file WEIGHTS_PATH = 'https://github.com/fchollet/deep-lea
ios開發-懶載入和模型的封裝
一. ios開發中的懶載入 什麼是懶載入: 就是在需要資料的時候,再去載入資料,可以理解為延遲載入. OC中懶載入的形式 首先在控制器中宣告一個數組 @property (nonatomic, strong) NSArray *
Assetbundle打包、載入和提取資源的方式
打包AB的方式 注意:unity2017和5.X的版本API已經不一樣了,一個AB包裡可以有多個資源> BuildPipeline.BuildAssetBundles 會把所有標記的資源打包到指定目錄 一般情況下,我們會打包到Stream
如何儲存及載入Keras模型
我們不推薦使用pickle或cPickle來儲存Keras模型 你可以使用model.save(filepath)將Keras模型和權重儲存在一個HDF5檔案中,該檔案將包含: . 模型的結構,以便重構該模型 . 模型的權重 . 訓練配置(損失函式,優化
devexpress gridview 儲存、載入佈局
一次在糾結devexpress gridview 列隱藏後顯示時順序亂了的問題,QQ群裡一個朋友提示的一些資訊,記錄下來。 1、隱藏列後visibleIndex的值為-1; 2、一種解決方案: private void gcPlan_Load(object sender,
常用資料結構-二叉樹的鏈式儲存、建立和遍歷
1. 鏈式二叉樹簡介 二叉樹是資料結構——樹的一種,一般定義的樹是指有一個根節點,由此根節點向下分出數個分支結點,以此類推以至產生一堆結點。樹是一個具有代表性的非線性資料結構,所謂的非
tensorflow儲存模型、載入模型和提取模型引數和特徵圖
1.tf.train.latest_checkpoint('./model_data/')這一句最終返回的是一個字串,比如'./model_data/model-99991'這個方法本身還會做相應的檢查,比如checkpoint中最新的模型model_checkpoint_p
TensorFlow儲存和載入訓練模型
儲存:使用saver.save()方法儲存 載入:使用saver.restore()方法載入 下面是個完整例子: 儲存: import tensorflow as tf W = tf.Variable([[1, 1, 1], [2, 2, 2]], dtype=tf.float
《TensorFlow:實戰Google深度學習框架》——5.4 模型持久化(模型儲存、模型載入)
目錄 1、持久化程式碼實現 2、載入儲存的TensorFlow模型 3、載入部分變數 4、載入變數時重新命名 1、持久化程式碼實現 TensorFlow提供了一個非常簡單的API來儲存和還原一個神經網路模型。這個API就是tf.train.Saver類。一下程式碼給出了儲
tensorflow儲存和載入模型
× TF 儲存和載入模型 <!-- 作者區域 --> <div class="author"> <a class="avatar" href="/u/ff5c
Tensorflow學習筆記:變數作用域、模型的載入與儲存、執行緒與佇列實現多執行緒讀取樣本
# tensorflow變數作用域 用上下文語句規定作用域 with tf.variable_scope("作用域_name") ......
tensorflow儲存和載入訓練好的模型
儲存模型 saver = tf.train.Saver()建立一個saver物件 saver.save(sess,'path')將模型儲存在指定的path路徑中 該路徑中生成的檔案有四個, checkpoint檔案儲存了一個錄下多有的模型檔案列表, model.ckpt.
tensorflow 儲存和載入模型 -2
1、 我們經常在訓練完一個模型之後希望儲存訓練的結果,這些結果指的是模型的引數,以便下次迭代的訓練或者用作測試。Tensorflow針對這一需求提供了Saver類。 Saver類提供了向checkpoints檔案儲存和從checkpoints檔案中恢復變數的相關方法。C
tensorflow中儲存模型、載入模型做預測(不需要再定義網路結構)
下面用一個線下回歸模型來記載儲存模型、載入模型做預測 參考文章: 訓練一個線下回歸模型並儲存看程式碼: import tensorflow as tfimport numpy as