1. 程式人生 > >使用tensorflow儲存、載入和使用模型

使用tensorflow儲存、載入和使用模型

使用Tensorflow進行深度學習訓練的時候,需要對訓練好的網路模型和各種引數進行儲存,以便在此基礎上繼續訓練或者使用。介紹這方面的部落格有很多,我發現寫的最好的是這一篇官方英文介紹:

我對這篇文章進行了整理和彙總。

首先是模型的儲存。直接上程式碼:

#!/usr/bin/env python
#-*- coding:utf-8 -*-
############################
#File Name: tut1_save.py
#Author: Wang 
#Mail: [email protected]
#Created Time:2017-08-30 11:04:25
############################

import tensorflow as tf

# prepare to feed input, i.e. feed_dict and placeholders
w1 = tf.Variable(tf.random_normal(shape = [2]), name = 'w1')  # name is very important in restoration
w2 = tf.Variable(tf.random_normal(shape = [2]), name = 'w2')
b1 = tf.Variable(2.0, name = 'bias1')
feed_dict = {w1:[10,3], w2:[5,5]}

# define a test operation that will be restored
w3 = tf.add(w1, w2)  # without name, w3 will not be stored
w4 = tf.multiply(w3, b1, name = "op_to_restore")

#saver = tf.train.Saver()
saver = tf.train.Saver(max_to_keep = 4, keep_checkpoint_every_n_hours = 1)
sess = tf.Session()
sess.run(tf.global_variables_initializer())
print sess.run(w4, feed_dict)
#saver.save(sess, 'my_test_model', global_step = 100)
saver.save(sess, 'my_test_model')
#saver.save(sess, 'my_test_model', global_step = 100, write_meta_graph = False)
需要說明的有以下幾點:

1. 建立saver的時候可以指明要儲存的tensor,如果不指明,就會全部存下來。在這裡也可以指明最大儲存數量和checkpoint的記錄時間。具體細節看英文部落格。

2. saver.save()函式裡面可以設定global_step和write_meta_graph,meta儲存的是網路結構,只在開始執行程式的時候儲存一次即可,後續可以通過設定write_meta_graph = False加以限制。

3. 這個程式執行結束後,會在程式目錄下生成四個檔案,分別是.meta(儲存網路結構)、.data和.index(儲存訓練好的引數)、checkpoint(記錄最新的模型)。

下面是如何載入已經儲存的網路模型。這裡有兩種方法,第一種是saver.restore(sess, 'aaaa.ckpt'),這種方法的本質是讀取全部引數,並載入到已經定義好的網路結構上,因此相當於給網路的weights和biases賦值並執行tf.global_variables_initializer()。這種方法的缺點是使用前必須重寫網路結構,而且網路結構要和儲存的引數完全對上。第二種就比較高端了,直接把網路結構載入進來(.meta),上程式碼:

#!/usr/bin/env python
#-*- coding:utf-8 -*-
############################
#File Name: tut2_import.py
#Author: Wang 
#Mail: 
[email protected]
#Created Time:2017-08-30 14:16:38 ############################ import tensorflow as tf sess = tf.Session() new_saver = tf.train.import_meta_graph('my_test_model.meta') new_saver.restore(sess, tf.train.latest_checkpoint('./')) print sess.run('w1:0')

使用載入的模型,輸入新資料,計算輸出,還是直接上程式碼:
#!/usr/bin/env python
#-*- coding:utf-8 -*-
############################
#File Name: tut3_reuse.py
#Author: Wang
#Mail: [email protected]
#Created Time:2017-08-30 14:33:35
############################

import tensorflow as tf

sess = tf.Session()

# First, load meta graph and restore weights
saver = tf.train.import_meta_graph('my_test_model.meta')
saver.restore(sess, tf.train.latest_checkpoint('./'))

# Second, access and create placeholders variables and create feed_dict to feed new data
graph = tf.get_default_graph()
w1 = graph.get_tensor_by_name('w1:0')
w2 = graph.get_tensor_by_name('w2:0')
feed_dict = {w1:[-1,1], w2:[4,6]}

# Access the op that want to run
op_to_restore = graph.get_tensor_by_name('op_to_restore:0')

print sess.run(op_to_restore, feed_dict)     # ouotput: [6. 14.]

在已經載入的網路後繼續加入新的網路層:
import tensorflow as tf

sess=tf.Session()    
#First let's load meta graph and restore weights
saver = tf.train.import_meta_graph('my_test_model-1000.meta')
saver.restore(sess,tf.train.latest_checkpoint('./'))


# Now, let's access and create placeholders variables and
# create feed-dict to feed new data

graph = tf.get_default_graph()
w1 = graph.get_tensor_by_name("w1:0")
w2 = graph.get_tensor_by_name("w2:0")
feed_dict ={w1:13.0,w2:17.0}

#Now, access the op that you want to run. 
op_to_restore = graph.get_tensor_by_name("op_to_restore:0")

#Add more to the current graph
add_on_op = tf.multiply(op_to_restore,2)

print sess.run(add_on_op,feed_dict)
#This will print 120.

對載入的網路進行區域性修改和處理(這個最麻煩,我還沒搞太明白,後續會繼續補充):
......
......
saver = tf.train.import_meta_graph('vgg.meta')
# Access the graph
graph = tf.get_default_graph()
## Prepare the feed_dict for feeding data for fine-tuning 

#Access the appropriate output for fine-tuning
fc7= graph.get_tensor_by_name('fc7:0')

#use this if you only want to change gradients of the last layer
fc7 = tf.stop_gradient(fc7) # It's an identity function
fc7_shape= fc7.get_shape().as_list()

new_outputs=2
weights = tf.Variable(tf.truncated_normal([fc7_shape[3], num_outputs], stddev=0.05))
biases = tf.Variable(tf.constant(0.05, shape=[num_outputs]))
output = tf.matmul(fc7, weights) + biases
pred = tf.nn.softmax(output)

# Now, you run this with fine-tuning data in sess.run()

有了這樣的方法,無論是自行訓練、載入模型繼續訓練、使用經典模型還是finetune經典模型抑或是載入網路跑前項,效果都是槓槓的。

相關推薦

使用tensorflow儲存載入使用模型

使用Tensorflow進行深度學習訓練的時候,需要對訓練好的網路模型和各種引數進行儲存,以便在此基礎上繼續訓練或者使用。介紹這方面的部落格有很多,我發現寫的最好的是這一篇官方英文介紹: 我對這篇文章進行了整理和彙總。 首先是模型的儲存。直接上程式碼: #!/usr/bi

mnist LSTM 訓練測試,模型儲存載入識別

MNIST 字元資料庫每個字元(0-9) 對應一張28x28的一通道圖片,可以將圖片的每一列(行)當作特徵,所有行(列)當做一個序列。那麼可以通過輸入大小為28,時間長度為28的RNN(lstm)對字元建模。對於同一個字元,比如0,其行與行之間的動態變化可以

TensorFlow儲存載入模型引數 | 原理描述及踩坑經驗總結

寫在前面 我之前使用的LSTM計算單元是根據其前向傳播的計算公式手動實現的,這兩天想要和TensorFlow自帶的tf.nn.rnn_cell.BasicLSTMCell()比較一下,看看哪個訓練速度快一些。在使用tf.nn.rnn_cell.BasicLSTMCell()進行建模的時候,遇到了模型儲存、載入

tensorflow 儲存及其載入

https://blog.csdn.net/thriving_fcl/article/details/75213361 同一個模型圖,可以根據根據輸入輸出任意組成signature_def,使模型的任意組合使用方便, signature_def了一種組織方式! 一個圖標籤下,可以任意組合很多種signatur

TensorFlow常量變數資料型別

TensorFlow 用張量這種資料結構來表示所有的資料。一個張量有一個靜態型別和動態型別的維數,張量可以在圖中的節點之間流通。 (1)TensorFlow中建立常量的方法: hello=tf.constant('hello,TensorFlow!',dtype=tf

OpenStack社群元件-儲存備份恢復

    OpenStack發展到今天,從最開始的A版開始,到2017年8月份剛釋出的P版,已經發布了很多個版本。    OpenStack社群把OpenStack各個元件進行了歸類,一共分成了9大類。 

Keras 儲存載入網路模型

遇到問題: keras使用預訓練模型做訓練時遇到的如下程式碼: from keras.utils.data_utils import get_file WEIGHTS_PATH = 'https://github.com/fchollet/deep-lea

ios開發-懶載入模型的封裝

一. ios開發中的懶載入 什麼是懶載入: 就是在需要資料的時候,再去載入資料,可以理解為延遲載入. OC中懶載入的形式 首先在控制器中宣告一個數組 @property (nonatomic, strong) NSArray *

Assetbundle打包載入提取資源的方式

打包AB的方式 注意:unity2017和5.X的版本API已經不一樣了,一個AB包裡可以有多個資源> BuildPipeline.BuildAssetBundles 會把所有標記的資源打包到指定目錄 一般情況下,我們會打包到Stream

如何儲存載入Keras模型

我們不推薦使用pickle或cPickle來儲存Keras模型 你可以使用model.save(filepath)將Keras模型和權重儲存在一個HDF5檔案中,該檔案將包含: . 模型的結構,以便重構該模型 . 模型的權重 . 訓練配置(損失函式,優化

devexpress gridview 儲存載入佈局

一次在糾結devexpress gridview 列隱藏後顯示時順序亂了的問題,QQ群裡一個朋友提示的一些資訊,記錄下來。 1、隱藏列後visibleIndex的值為-1; 2、一種解決方案: private void gcPlan_Load(object sender,

常用資料結構-二叉樹的鏈式儲存建立遍歷

1. 鏈式二叉樹簡介         二叉樹是資料結構——樹的一種,一般定義的樹是指有一個根節點,由此根節點向下分出數個分支結點,以此類推以至產生一堆結點。樹是一個具有代表性的非線性資料結構,所謂的非

tensorflow儲存模型載入模型提取模型引數特徵圖

1.tf.train.latest_checkpoint('./model_data/')這一句最終返回的是一個字串,比如'./model_data/model-99991'這個方法本身還會做相應的檢查,比如checkpoint中最新的模型model_checkpoint_p

TensorFlow儲存載入訓練模型

儲存:使用saver.save()方法儲存 載入:使用saver.restore()方法載入 下面是個完整例子: 儲存: import tensorflow as tf W = tf.Variable([[1, 1, 1], [2, 2, 2]], dtype=tf.float

TensorFlow:實戰Google深度學習框架》——5.4 模型持久化(模型儲存模型載入

目錄 1、持久化程式碼實現 2、載入儲存的TensorFlow模型 3、載入部分變數 4、載入變數時重新命名 1、持久化程式碼實現 TensorFlow提供了一個非常簡單的API來儲存和還原一個神經網路模型。這個API就是tf.train.Saver類。一下程式碼給出了儲

tensorflow儲存載入模型

× TF 儲存和載入模型 <!-- 作者區域 --> <div class="author"> <a class="avatar" href="/u/ff5c

Tensorflow學習筆記:變數作用域模型載入儲存執行緒與佇列實現多執行緒讀取樣本

# tensorflow變數作用域     用上下文語句規定作用域     with tf.variable_scope("作用域_name")         ......

tensorflow儲存載入訓練好的模型

儲存模型 saver = tf.train.Saver()建立一個saver物件 saver.save(sess,'path')將模型儲存在指定的path路徑中 該路徑中生成的檔案有四個, checkpoint檔案儲存了一個錄下多有的模型檔案列表, model.ckpt.

tensorflow 儲存載入模型 -2

1、 我們經常在訓練完一個模型之後希望儲存訓練的結果,這些結果指的是模型的引數,以便下次迭代的訓練或者用作測試。Tensorflow針對這一需求提供了Saver類。 Saver類提供了向checkpoints檔案儲存和從checkpoints檔案中恢復變數的相關方法。C

tensorflow儲存模型載入模型做預測(不需要再定義網路結構)

下面用一個線下回歸模型來記載儲存模型、載入模型做預測 參考文章: 訓練一個線下回歸模型並儲存看程式碼: import tensorflow as tfimport numpy as