
Tensorflow Model Persistence

Methods of tf model persistence
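If our neural network is complex and the training data is large, training the model takes a long time. If some unexpected error makes the training stop partway through, all that work is lost. To avoid this, we can use model persistence (saving in CKPT format) to store intermediate results during training.
If the trained model is to be given to users for offline prediction, only the forward pass is needed, because we just want the predicted values. In this case we can use model persistence (saving in PB format) to store only the variables needed by the forward pass and fix their values as constants. The user then only has to provide an input, and the model returns an output.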

1. Saving the model in CKPT format
Define the computation
Declare and obtain a Saver
Save the model with Saver.save

# coding=UTF-8  (allows non-ASCII characters such as Chinese in the source)
import tensorflow as tf
import shutil
import os.path

MODEL_DIR = "model/ckpt"
MODEL_NAME = "model.ckpt"
# if os.path.exists(MODEL_DIR):  # delete the directory
#     shutil.rmtree(MODEL_DIR)
if not tf.gfile.Exists(MODEL_DIR):  # create the directory
    tf.gfile.MakeDirs(MODEL_DIR)

# Replace the computation below with your own training process (CNN, RNN, ...); here it is just a simple formula.
input_holder = tf.placeholder(tf.float32, shape=[1], name="input_holder")  # named input placeholder; loading the model later may need the name
W1 = tf.Variable(tf.constant(5.0, shape=[1]), name="W1")
B1 = tf.Variable(tf.constant(1.0, shape=[1]), name="B1")
_y = (input_holder * W1) + B1
predictions = tf.greater(_y, 50, name="predictions")  # output node; its name is needed when loading the model. True if > 50, else False

init = tf.global_variables_initializer()
saver = tf.train.Saver()  # declare a Saver used to save the model

with tf.Session() as sess:
    sess.run(init)
    print("predictions : ", sess.run(predictions, feed_dict={input_holder: [10.0]}))  # feed one value as a test
    saver.save(sess, os.path.join(MODEL_DIR, MODEL_NAME))  # save the model
    print("%d ops in the final graph." % len(tf.get_default_graph().as_graph_def().node))  # number of op nodes in the graph
    for op in tf.get_default_graph().get_operations():  # print each node's information
        print(op.name, op.values())
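Before comparing against the real `sess.run` output, it helps to know what the graph should compute. The plain-Python sketch below (no TensorFlow needed) mirrors the formula above with the same initial values W1 = 5.0 and B1 = 1.0; the function name `reference_predictions` is hypothetical.

```python
# Plain-Python mirror of the toy graph (not TensorFlow code).
W1_INIT = 5.0  # initial value of the W1 variable
B1_INIT = 1.0  # initial value of the B1 variable


def reference_predictions(inputs, threshold=50.0):
    """What `predictions` should evaluate to: (x * W1 + B1) > threshold, element-wise."""
    return [(x * W1_INIT + B1_INIT) > threshold for x in inputs]


print(reference_predictions([10.0]))  # 10 * 5 + 1 = 51, which is > 50 -> [True]
```

So for the input 10.0 the script should report True.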

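After running the script, the following files are generated:
checkpoint : records the list of model checkpoint files in the directory
model.ckpt.data-00000-of-00001 : stores the value of every variable in the model
model.ckpt.index : records where each variable's value is stored in the data file
model.ckpt.meta : stores the structure of the whole computation graph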

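2. Saving the model in PB format

Define the computation
Get the node information of the current graph with get_default_graph().as_graph_def()
Fix (freeze) the values of the relevant nodes with graph_util.convert_variables_to_constants
Persist the model with tf.gfile.GFile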

# coding=UTF-8
import tensorflow as tf
import shutil
import os.path
from tensorflow.python.framework import graph_util

MODEL_DIR = "model/pb"

# if os.path.exists(MODEL_DIR):  # delete the directory
#     shutil.rmtree(MODEL_DIR)

if not tf.gfile.Exists(MODEL_DIR):  # create the directory; GFile cannot write into a missing one
    tf.gfile.MakeDirs(MODEL_DIR)

output_graph = os.path.join(MODEL_DIR, "add_model.pb")

# Replace the computation below with your own training process (CNN, RNN, ...); here it is just a simple formula.
input_holder = tf.placeholder(tf.float32, shape=[1], name="input_holder")
W1 = tf.Variable(tf.constant(5.0, shape=[1]), name="W1")
B1 = tf.Variable(tf.constant(1.0, shape=[1]), name="B1")
_y = (input_holder * W1) + B1
# predictions = tf.greater(_y, 50, name="predictions")  # True if greater than 50, otherwise False
predictions = tf.add(_y, 10, name="predictions")  # a simple addition

init = tf.global_variables_initializer()

with tf.Session() as sess:
    sess.run(init)
    print("predictions : ", sess.run(predictions, feed_dict={input_holder: [10.0]}))
    graph_def = tf.get_default_graph().as_graph_def()  # GraphDef of the current graph; it describes the computation from the input layer to the output layer

    output_graph_def = graph_util.convert_variables_to_constants(  # freeze the model: replace variables with constants holding their values
        sess,
        graph_def,
        ["predictions"]  # names of the output nodes to keep
    )
    with tf.gfile.GFile(output_graph, "wb") as f:  # save the model
        f.write(output_graph_def.SerializeToString())  # serialize and write
    print("%d ops in the final graph." % len(output_graph_def.node))
    print (predictions)

# for op in tf.get_default_graph().get_operations():  # print node information
#     print(op.name)

*GraphDef: this records the information of the nodes in the Tensorflow computation graph.

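After running, the model/pb directory contains the following files:
add_model.pb
frozen_model.pb (written by the CKPT-to-PB conversion script in section 3, not by the script above)

add_model.pb : stores the computation graph from the input layer to the output layer together with the related variable values, frozen as constants. Once we have this model, we can feed it an input and get a predicted output.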
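Conceptually, `convert_variables_to_constants` does two things to the GraphDef: it keeps only the nodes that the listed output nodes depend on, and it replaces each variable with a constant holding its current value. The toy sketch below models that on a dict-based stand-in for the add_model graph. The node layout, the `Variable` op label and the `freeze`/`evaluate` helpers are illustrative assumptions and far simpler than a real GraphDef (where, for example, each variable also has `Assign` and `read` nodes).

```python
# Toy stand-in for a GraphDef: node name -> op type, input node names, optional value.
TOY_GRAPH = {
    "input_holder": {"op": "Placeholder", "inputs": []},
    "W1": {"op": "Variable", "inputs": []},
    "B1": {"op": "Variable", "inputs": []},
    "mul": {"op": "Mul", "inputs": ["input_holder", "W1"]},
    "add": {"op": "Add", "inputs": ["mul", "B1"]},
    "predictions/y": {"op": "Const", "inputs": [], "value": 10.0},
    "predictions": {"op": "Add", "inputs": ["add", "predictions/y"]},
    "init": {"op": "NoOp", "inputs": ["W1", "B1"]},  # only needed for training
}


def freeze(graph, variable_values, output_names):
    """Prune to what output_names depend on and turn Variables into Consts."""
    needed, stack = set(), list(output_names)
    while stack:  # walk backwards from the outputs along the input edges
        name = stack.pop()
        if name not in needed:
            needed.add(name)
            stack.extend(graph[name]["inputs"])
    frozen = {}
    for name in needed:
        node = graph[name]
        if node["op"] == "Variable":  # bake the current value in as a constant
            frozen[name] = {"op": "Const", "inputs": [], "value": variable_values[name]}
        else:
            frozen[name] = dict(node)
    return frozen


def evaluate(graph, name, feeds):
    """Recursively compute node `name`; placeholder values come from `feeds`."""
    node = graph[name]
    if node["op"] == "Placeholder":
        return feeds[name]
    if node["op"] == "Const":
        return node["value"]
    args = [evaluate(graph, i, feeds) for i in node["inputs"]]
    if node["op"] == "Mul":
        return args[0] * args[1]
    if node["op"] == "Add":
        return args[0] + args[1]
    raise ValueError("cannot evaluate op " + node["op"])


frozen = freeze(TOY_GRAPH, {"W1": 5.0, "B1": 1.0}, ["predictions"])
print("init" in frozen)  # False: pruned, not needed for inference
print(evaluate(frozen, "predictions", {"input_holder": 10.0}))  # 10*5 + 1 + 10 = 61.0
```

The frozen toy graph no longer needs any variable values from outside, which is why a PB file can be shipped on its own for prediction.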

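3. Converting CKPT to PB format

Get the graph and variable data of the model from the path of the CKPT model
Import the model's graph with import_meta_graph
Restore the values of every variable in the graph with saver.restore
Freeze and persist the model with graph_util.convert_variables_to_constants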

# coding=UTF-8
import tensorflow as tf
import os.path
import argparse
from tensorflow.python.framework import graph_util

MODEL_DIR = "model/pb"
MODEL_NAME = "frozen_model.pb"

if not tf.gfile.Exists(MODEL_DIR): #建立目錄
    tf.gfile.MakeDirs(MODEL_DIR)

def freeze_graph(model_folder):
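    checkpoint = tf.train.get_checkpoint_state(model_folder)  # read the checkpoint state of the directory (None if there is no usable checkpoint)
    input_checkpoint = checkpoint.model_checkpoint_path  # path prefix of the latest ckpt
    output_graph = os.path.join(MODEL_DIR, MODEL_NAME)  # where to save the PB model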

    output_node_names = "predictions"  # name of the output op in the original model
    saver = tf.train.import_meta_graph(input_checkpoint + '.meta', clear_devices=True)  # import the graph; clear_devices: whether or not to clear the device field for an `Operation` or `Tensor` during import

    graph = tf.get_default_graph()  # get the default graph
    input_graph_def = graph.as_graph_def()  # serialized GraphDef representing the current graph

    with tf.Session() as sess:
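        saver.restore(sess, input_checkpoint)  # restore the variable values from the checkpoint

        # Test that the loaded model is correct. Note that we pass the tensor names of the
        # output and input nodes ("name:0"), not the names of the operations.
        print("predictions : ", sess.run("predictions:0", feed_dict={"input_holder:0": [10.0]}))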

        output_graph_def = graph_util.convert_variables_to_constants(  # freeze: replace variables with constants
            sess,
            input_graph_def,
            output_node_names.split(",")  # multiple output nodes are separated by commas
        )
        with tf.gfile.GFile(output_graph, "wb") as f:  # save the model
            f.write(output_graph_def.SerializeToString())  # serialize and write
        print("%d ops in the final graph." % len(output_graph_def.node))  # number of op nodes in the frozen graph

        for op in graph.get_operations():
            print(op.name, op.values())

if __name__ == '__main__':
    parser = argparse.ArgumentParser()
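    parser.add_argument("model_folder", type=str, help="input ckpt model dir")  # command-line parsing: help is the help text, type is the argument type
    # The ckpt model directory must be given when running the script, otherwise argparse
    # reports an error (Python 2: "error: too few arguments").
    args = parser.parse_args()
    freeze_graph(args.model_folder)
    # freeze_graph("model/ckpt")  # model directory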
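The command-line handling can be checked without any checkpoint by feeding the parser an explicit argument list. In this sketch `build_parser` is a hypothetical wrapper around the same `add_argument` call. Under Python 3 a missing argument produces `error: the following arguments are required: model_folder` instead of Python 2's `too few arguments`, and argparse exits with status 2.

```python
import argparse


def build_parser():
    """Same positional argument as the freeze_graph script above."""
    parser = argparse.ArgumentParser()
    parser.add_argument("model_folder", type=str, help="input ckpt model dir")
    return parser


# Equivalent to running: python script.py model/ckpt
args = build_parser().parse_args(["model/ckpt"])
print(args.model_folder)  # model/ckpt
```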

In this Tensorflow tutorial, I shall explain:

What does a Tensorflow model look like?
How to save a Tensorflow model?
How to restore a Tensorflow model for prediction/transfer learning?
How to work with imported pretrained models for fine-tuning and modification
This tutorial assumes that you have some idea about training a neural network. Otherwise, please follow this tutorial and come back here.

1. What is a Tensorflow model?
After you have trained a neural network, you would want to save it for future use and for deployment to production. So, what is a Tensorflow model? A Tensorflow model primarily contains the network design, or graph, and the values of the network parameters that we have trained. Hence, a Tensorflow model has two main files:

a) Meta graph:

This is a protocol buffer which saves the complete Tensorflow graph; i.e. all variables, operations, collections etc. This file has .meta extension.

b) Checkpoint file:

This is a binary file which contains the values of the weights, biases and all the other saved variables (including optimizer state such as momentum accumulators). This file has the .ckpt extension. However, Tensorflow changed this in version 0.11: instead of a single .ckpt file, we now have two files:
mymodel.data-00000-of-00001
mymodel.index

The .data file contains the values of our training variables, and it is the file we are after. The .index file records where each variable's value is stored in the .data file(s).

Along with this, Tensorflow also has a file named checkpoint which simply keeps a record of latest checkpoint files saved.
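The checkpoint file is a small text-format protocol buffer with lines such as `model_checkpoint_path: "..."` and `all_model_checkpoint_paths: "..."`; `tf.train.latest_checkpoint` and `tf.train.get_checkpoint_state` read it for you. The plain-Python sketch below only illustrates what those functions extract. `parse_checkpoint_state` and `latest_checkpoint_prefix` are hypothetical names, and unlike the real functions the sketch does not check that the referenced files exist.

```python
import os


def parse_checkpoint_state(text):
    """Parse the text of a `checkpoint` file into (latest_path, all_paths)."""
    latest, all_paths = None, []
    for line in text.splitlines():
        key, sep, value = line.partition(":")  # split at the first ":" only
        if not sep:
            continue  # skip blank or malformed lines
        value = value.strip().strip('"')
        if key.strip() == "model_checkpoint_path":
            latest = value
        elif key.strip() == "all_model_checkpoint_paths":
            all_paths.append(value)
    return latest, all_paths


def latest_checkpoint_prefix(checkpoint_dir, text):
    """Resolve the latest prefix; relative paths are taken relative to checkpoint_dir."""
    latest, _ = parse_checkpoint_state(text)
    if latest is None:
        return None
    return latest if os.path.isabs(latest) else os.path.join(checkpoint_dir, latest)


SAMPLE = '''model_checkpoint_path: "my_test_model-2000"
all_model_checkpoint_paths: "my_test_model-1000"
all_model_checkpoint_paths: "my_test_model-2000"
'''
print(latest_checkpoint_prefix("./", SAMPLE))  # ./my_test_model-2000
```

The prefix it returns (without any extension) is what `saver.restore` expects, and appending `.meta` gives the file for `import_meta_graph`.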

So, to summarize, a Tensorflow model saved with version 0.11 or later consists of a .meta file, an .index file, a .data file (or several shards) and the checkpoint file, while a Tensorflow model before 0.11 contained only three files:

inception_v1.meta
inception_v1.ckpt
checkpoint
Now that we know what a Tensorflow model looks like, let’s learn how to save the model.

2. Saving a Tensorflow model:
Let’s say you are training a convolutional neural network for image classification. As a standard practice, you keep a watch on the loss and accuracy numbers. Once you see that the network has converged, you can stop the training manually, or you can run the training for a fixed number of epochs. After the training is done, we want to save all the variables and the network graph to a file for future use. To save the graph and the values of all the parameters in Tensorflow, we create an instance of the tf.train.Saver() class.
saver = tf.train.Saver()

Remember that Tensorflow variables are only alive inside a session. So, you have to save the model inside a session by calling the save method on the saver object you just created.

saver.save(sess, 'my-test-model')

Here, sess is the session object, while ‘my-test-model’ is the name you want to give your model. Let’s see a complete example:

import tensorflow as tf
w1 = tf.Variable(tf.random_normal(shape=[2]), name='w1')
w2 = tf.Variable(tf.random_normal(shape=[5]), name='w2')
saver = tf.train.Saver()
sess = tf.Session()
sess.run(tf.global_variables_initializer())
saver.save(sess, 'my_test_model')

# This will save the following files in Tensorflow v >= 0.11
# my_test_model.data-00000-of-00001
# my_test_model.index
# my_test_model.meta
# checkpoint

If we are saving the model after 1000 iterations, we shall call save by passing the step count:

saver.save(sess, 'my_test_model',global_step=1000)

This will just append ‘-1000’ to the model name, and the following files will be created:

my_test_model-1000.index
my_test_model-1000.meta
my_test_model-1000.data-00000-of-00001
checkpoint

Let’s say that, during training, we save the model every 1000 iterations. The .meta file is created the first time (at the 1000th iteration), and since the graph does not change, we don’t need to recreate it each time (at the 2000th, 3000th, ... iterations). We only need to save the variable values for further iterations. So, when we don’t want to write the meta-graph, we use this:

saver.save(sess, 'my-model', global_step=step,write_meta_graph=False)
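The naming rules above fit in a few lines of plain Python. The helper `expected_checkpoint_files` is hypothetical and assumes the default non-sharded V2 checkpoint format, which writes a single `.data-00000-of-00001` shard.

```python
import os


def expected_checkpoint_files(prefix, global_step=None, write_meta_graph=True):
    """File names one Saver.save() call is expected to produce (V2 format, single shard)."""
    if global_step is not None:
        prefix = "%s-%d" % (prefix, global_step)  # save() appends "-<global_step>"
    files = [prefix + ".index", prefix + ".data-00000-of-00001"]
    if write_meta_graph:
        files.append(prefix + ".meta")
    # The bookkeeping file lives next to the checkpoint files and is updated on each save.
    files.append(os.path.join(os.path.dirname(prefix), "checkpoint"))
    return files


print(expected_checkpoint_files("my_test_model", global_step=1000))
print(expected_checkpoint_files("my-model", global_step=2000, write_meta_graph=False))
```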

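If you want to keep only the 4 latest checkpoints and, in addition, preserve one checkpoint for every 2 hours of training, you can use max_to_keep and keep_checkpoint_every_n_hours like this. Note that keep_checkpoint_every_n_hours does not trigger any saving by itself; it only decides which older checkpoints are spared from deletion.

#keeps the 4 latest checkpoints, plus one checkpoint per 2 hours of training.
saver = tf.train.Saver(max_to_keep=4, keep_checkpoint_every_n_hours=2)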
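The interplay of the two arguments is easy to misread, so here is a simplified plain-Python simulation of the retention rule, modelled on the TF 1.x `tf.train.Saver` behaviour. After each save, the oldest checkpoint beyond `max_to_keep` is evicted. An evicted checkpoint is preserved instead of deleted if it was saved after the current keep deadline, and the deadline then advances by `keep_checkpoint_every_n_hours`. Times are in hours since the Saver was created; the function name and this simplification are assumptions, not the library's code.

```python
def simulate_retention(save_times_h, max_to_keep, keep_every_n_hours, start_h=0.0):
    """Simplified model of which checkpoints survive (times in hours).

    Returns (preserved, recent): checkpoints spared by keep_every_n_hours,
    and the rolling window of the max_to_keep most recent ones.
    """
    next_keep = start_h + keep_every_n_hours  # first deadline after Saver creation
    recent, preserved = [], []
    for t in save_times_h:  # one save() call per entry
        recent.append(t)
        while len(recent) > max_to_keep:
            oldest = recent.pop(0)  # evict the oldest checkpoint in the window
            if oldest > next_keep:  # saved after the deadline: keep its files
                preserved.append(oldest)
                next_keep += keep_every_n_hours
            # otherwise its files would be deleted


    return preserved, recent


# One save per hour for 8 hours with max_to_keep=4, keep_checkpoint_every_n_hours=2
print(simulate_retention([1, 2, 3, 4, 5, 6, 7, 8], max_to_keep=4, keep_every_n_hours=2))
# ([3], [5, 6, 7, 8])
```

In this run the checkpoint from hour 3 survives even after leaving the 4-item window, while those from hours 1, 2 and 4 are deleted.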

Note that if we don’t pass anything to tf.train.Saver(), it saves all the variables. What if we don’t want to save all the variables, just some of them? We can specify the variables we want to save: while creating the tf.train.Saver instance, we pass it a list or a dictionary of the variables that we want to save. Let’s look at an example:

import tensorflow as tf
w1 = tf.Variable(tf.random_normal(shape=[2]), name='w1')
w2 = tf.Variable(tf.random_normal(shape=[5]), name='w2')
saver = tf.train.Saver([w1,w2])
sess = tf.Session()
sess.run(tf.global_variables_initializer())
saver.save(sess, 'my_test_model',global_step=1000)
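tf.train.Saver accepts either a list of variables, in which case each is stored under its own variable name, or a dictionary whose keys become the names used inside the checkpoint (handy when restoring into a model whose variables are named differently). The toy function below models only that name mapping with plain dicts; `select_for_saving` and the dict-of-values stand-in for the graph's variables are assumptions.

```python
def select_for_saving(all_variables, var_list=None):
    """Return {checkpoint_key: value} for the variables a Saver would write.

    all_variables: {variable_name: value}, standing in for the graph's variables.
    var_list: None (save everything), a list of variable names, or a dict
              {checkpoint_key: variable_name} that renames entries in the checkpoint.
    """
    if var_list is None:
        return dict(all_variables)
    if isinstance(var_list, dict):
        return {key: all_variables[name] for key, name in var_list.items()}
    return {name: all_variables[name] for name in var_list}


variables = {"w1": [0.1, 0.2], "w2": [1.0] * 5, "bias": 2.0}
print(select_for_saving(variables, ["w1"]))          # only w1, under its own name
print(select_for_saving(variables, {"v1": "w1"}))    # w1 stored under the key "v1"
```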
3. Importing a pre-trained model:
If you want to use someone else’s pre-trained model for fine-tuning, there are two things you need to do:

a) Create the network:

You can create the network by writing Python code that recreates each and every layer manually, as in the original model. However, since we saved the network in the .meta file, we can recreate it with the tf.train.import_meta_graph() function like this: saver = tf.train.import_meta_graph('my_test_model-1000.meta')

Remember, import_meta_graph appends the network defined in .meta file to the current graph. So, this will create the graph/network for you but we still need to load the value of the parameters that we had trained on this graph.

b) Load the parameters:

We can restore the parameters of the network by calling restore on this saver which is an instance of tf.train.Saver() class.

with tf.Session() as sess:
  new_saver = tf.train.import_meta_graph('my_test_model-1000.meta')
  new_saver.restore(sess, tf.train.latest_checkpoint('./'))

After this, the values of tensors like w1 and w2 have been restored and can be accessed:

with tf.Session() as sess:    
    saver = tf.train.import_meta_graph('my_test_model-1000.meta')
    saver.restore(sess,tf.train.latest_checkpoint('./'))
    print(sess.run('w1:0'))
##Model has been restored. Above statement will print the saved value of w1.

So, now you have understood how saving and importing works for a Tensorflow model. In the next section, I have described a practical usage of above to load any pre-trained model.

4. Working with restored models
Now that you have understood how to save and restore Tensorflow models, let’s develop a practical guide to restore any pre-trained model and use it for prediction, fine-tuning or further training. Whenever you work with Tensorflow, you define a graph that is fed examples (training data) and some hyperparameters like learning rate, global step etc. It’s standard practice to feed all the training data and hyperparameters using placeholders. Let’s build a small network using placeholders and save it. Note that when the network is saved, the values of the placeholders are not saved.
import tensorflow as tf

#Prepare to feed input, i.e. feed_dict and placeholders
w1 = tf.placeholder("float", name="w1")
w2 = tf.placeholder("float", name="w2")
b1= tf.Variable(2.0,name="bias")
feed_dict ={w1:4,w2:8}

#Define a test operation that we will restore
w3 = tf.add(w1,w2)
w4 = tf.multiply(w3,b1,name="op_to_restore")
sess = tf.Session()
sess.run(tf.global_variables_initializer())

#Create a saver object which will save all the variables
saver = tf.train.Saver()

#Run the operation by feeding input
print(sess.run(w4, feed_dict))
#Prints 24.0, which is (w1+w2)*b1 = (4+8)*2

#Now, save the graph
saver.save(sess, 'my_test_model',global_step=1000)

Now, when we want to restore it, we not only have to restore the graph and weights, but also prepare a new feed_dict that will feed the new training data to the network. We can get references to these saved operations and placeholder variables via the graph.get_tensor_by_name() method.

#How to access saved variable/Tensor/placeholders
graph = tf.get_default_graph()
w1 = graph.get_tensor_by_name("w1:0")

## How to access saved operation
op_to_restore = graph.get_tensor_by_name("op_to_restore:0")
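The `:0` suffix is the output index: an operation can produce several tensors, and `get_tensor_by_name` needs a name of the form `<op_name>:<output_index>`, while `get_operation_by_name` takes the bare op name. The hypothetical helper below just splits such a name in plain Python to make the convention explicit; rejecting a bare op name mirrors the error `get_tensor_by_name` raises in that case.

```python
def split_tensor_name(tensor_name):
    """Split "<op_name>:<output_index>" into (op_name, output_index)."""
    op_name, sep, index = tensor_name.rpartition(":")
    if not sep or not index.isdigit():
        # Like get_tensor_by_name, refuse a bare operation name such as "w1".
        raise ValueError("tensor names must look like '<op_name>:<output_index>', got %r"
                         % tensor_name)
    return op_name, int(index)


print(split_tensor_name("op_to_restore:0"))  # ('op_to_restore', 0)
```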

If we just want to run the same network with different data, we can simply pass the new data to the network via feed_dict.

import tensorflow as tf

sess=tf.Session()    
#First let's load meta graph and restore weights
saver = tf.train.import_meta_graph('my_test_model-1000.meta')
saver.restore(sess,tf.train.latest_checkpoint('./'))


# Now, let's access and create placeholders variables and
# create feed-dict to feed new data

graph = tf.get_default_graph()
w1 = graph.get_tensor_by_name("w1:0")
w2 = graph.get_tensor_by_name("w2:0")
feed_dict ={w1:13.0,w2:17.0}

#Now, access the op that you want to run. 
op_to_restore = graph.get_tensor_by_name("op_to_restore:0")

print(sess.run(op_to_restore, feed_dict))
#This will print 60 which is calculated 
#using new values of w1 and w2 and saved value of b1. 

What if you want to add more operations to the graph by adding more layers and then train it? Of course you can do that too. See here:

import tensorflow as tf

sess=tf.Session()    
#First let's load meta graph and restore weights
saver = tf.train.import_meta_graph('my_test_model-1000.meta')
saver.restore(sess,tf.train.latest_checkpoint('./'))


# Now, let's access and create placeholders variables and
# create feed-dict to feed new data

graph = tf.get_default_graph()
w1 = graph.get_tensor_by_name("w1:0")
w2 = graph.get_tensor_by_name("w2:0")
feed_dict ={w1:13.0,w2:17.0}

#Now, access the op that you want to run. 
op_to_restore = graph.get_tensor_by_name("op_to_restore:0")

#Add more to the current graph
add_on_op = tf.multiply(op_to_restore,2)

print(sess.run(add_on_op, feed_dict))
#This will print 120.
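The numbers printed in the three runs follow directly from the formula (w1 + w2) * b1 with the saved bias b1 = 2.0. This plain-Python mirror (the function names are hypothetical) reproduces them, which is a quick way to confirm that a restore picked up the saved bias rather than some other value.

```python
BIAS = 2.0  # value of the saved "bias" variable


def op_to_restore(w1, w2, bias=BIAS):
    """Mirror of w4 = (w1 + w2) * b1 from the saved graph."""
    return (w1 + w2) * bias


def add_on_op(w1, w2):
    """Mirror of the extra node tf.multiply(op_to_restore, 2)."""
    return op_to_restore(w1, w2) * 2


print(op_to_restore(4, 8))        # 24.0: first run, before saving
print(op_to_restore(13.0, 17.0))  # 60.0: after restoring, with the new feed_dict
print(add_on_op(13.0, 17.0))      # 120.0: with the added layer
```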

But can you restore part of the old graph and add on to it for fine-tuning? Of course you can: just access the appropriate operation with the graph.get_tensor_by_name() method and build the graph on top of it. Here is a real-world example. Here we load a VGG pre-trained network using its meta graph and change the number of outputs in the last layer to 2 for fine-tuning with new data.

......
......
saver = tf.train.import_meta_graph('vgg.meta')
# Access the graph
graph = tf.get_default_graph()
## Prepare the feed_dict for feeding data for fine-tuning 

#Access the appropriate output for fine-tuning
fc7= graph.get_tensor_by_name('fc7:0')

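#use this if you only want to train the new last layer: no gradients flow back into fc7 and below
fc7 = tf.stop_gradient(fc7) # identity in the forward pass
fc7_shape= fc7.get_shape().as_list()

num_outputs=2
# tf.matmul needs a 2-D input; this assumes fc7 is [batch, 4096] or [batch, 1, 1, 4096]
fc7_flat = tf.reshape(fc7, [-1, fc7_shape[-1]])
weights = tf.Variable(tf.truncated_normal([fc7_shape[-1], num_outputs], stddev=0.05))
biases = tf.Variable(tf.constant(0.05, shape=[num_outputs]))
output = tf.matmul(fc7_flat, weights) + biases
pred = tf.nn.softmax(output)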
pred = tf.nn.softmax(output)

# Now, you run this with fine-tuning data in sess.run()
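To see what the new head computes, here is a framework-free sketch of one forward pass: the flattened fc7 activations of a single image go through a freshly initialized dense layer with two outputs and a softmax. The 4096-dimensional random feature vector and the function name `new_output_layer` are assumptions for illustration; in the real graph TensorFlow does this with `tf.matmul` and `tf.nn.softmax` on a whole batch.

```python
import math
import random


def new_output_layer(features, num_outputs=2, stddev=0.05, seed=0):
    """Dense layer + softmax for one flattened feature vector (e.g. fc7 activations)."""
    rng = random.Random(seed)
    # Fresh weights drawn from a plain normal distribution; tf.truncated_normal
    # would additionally re-draw values more than 2 stddev from the mean.
    weights = [[rng.gauss(0.0, stddev) for _ in range(num_outputs)] for _ in features]
    biases = [0.05] * num_outputs  # same constant bias as in the TF snippet
    logits = [sum(x * weights[i][j] for i, x in enumerate(features)) + biases[j]
              for j in range(num_outputs)]
    peak = max(logits)  # subtract the max before exp() for numerical stability
    exps = [math.exp(z - peak) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


feature_rng = random.Random(1)
fc7_features = [feature_rng.random() for _ in range(4096)]  # stand-in for real fc7 activations
probs = new_output_layer(fc7_features)
print(probs)  # two class probabilities that sum to 1
```

Only `weights` and `biases` are new here; with `tf.stop_gradient` in place, training updates just these two while the pre-trained layers stay fixed.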