1. 程式人生 > >tensorflow saver 儲存和恢復指定 tensor

tensorflow saver 儲存和恢復指定 tensor

在實踐中經常會遇到這樣的情況:

1, 用簡單的模型預訓練引數

2, 把預訓練的引數匯入複雜的模型後訓練複雜的模型

這時就產生一個問題:

                如何載入預訓練的引數。

下面就是我的總結。

為了方便說明,做一個假設:               簡單的模型只有一個卷基層,複雜模型有兩個。

                卷積層的實現程式碼如下:

import tensorflow as tf
# PS:本篇的重擔是saver,不過為了方便閱讀還是說明下引數
# 引數
# name:建立卷基層的程式碼這麼多,必須要函式化,而為了防止變數衝突就需要用tf.name_scope
# input_data:輸入資料
# width, high:卷積小視窗的寬、高
# deep_before, deep_after:卷積前後的神經元數量
# stride:卷積小視窗的移動步長
def make_conv(name, input_data, width, high, deep_before,deep_after, stride, padding_type='SAME'):
    global parameters
    with tf.name_scope(name) asscope:
        weights =tf.Variable(tf.truncated_normal([width, high, deep_before, deep_after],
            dtype=tf.float32,stddev=0.01), trainable=True, name='weights')
        biases =tf.Variable(tf.constant(0.1, shape=[deep_after]), trainable=True, name='biases')
        conv =tf.nn.conv2d(input_data, weights, [1, stride, stride, 1], padding=padding_type)
        bias = tf.add(conv,biases)
        bias = batch_norm(bias,deep_after, 1) # batch_norm是自己寫的batchnorm函式
        conv =tf.maximum(0.1*bias, bias)
        return conv

簡單的預訓練模型就下面一句話

conv1 =make_conv('simple-conv1', images, 3, 3, 3, 32, 1)

複雜的模型是兩個卷基層,如下:

conv1 = make_conv('complex-conv1',images, 3, 3, 3, 32, 1)
pool1= make_max_pool('layer1-pool1', conv1, 2, 2)
conv2= make_conv('complex-conv2', pool1, 3, 3, 32, 64, 1)

這時簡簡單單的在預訓練模型中:

saver = tf.train.Saver()
with tf.Session() as sess:
    saver.save(sess,'model.ckpt')

就不行了,因為:

    1,如果你在預訓練模型中使用下面的話列印所有tensor

all_v =tf.global_variables()
for i in all_v:  print  i

    會發現tensor的名字不是weights和biases,而是'simple-conv1/weights和'simple-conv1/biases,如下:

        <tf.Variable'simple-conv1/weights:0' shape=(3, 3, 3, 32) dtype=float32_ref>

        <tf.Variable'simple-conv1/biases:0' shape=(32,) dtype=float32_ref>

        <tf.Variable 'simple-conv1/Variable:0' shape=(32,)dtype=float32_ref>

        <tf.Variable 'simple-conv1/Variable_1:0' shape=(32,)dtype=float32_ref>

        <tf.Variable 'simple-conv1/Variable_2:0' shape=(32,)dtype=float32_ref>

        <tf.Variable 'simple-conv1/Variable_3:0' shape=(32,)dtype=float32_ref>

    同理,在複雜模型中就是complex-conv1/weights和complex-conv1/biases,這是對不上號的。

     2,預訓練模型中只有1個卷積層,而複雜模型中有兩個,而tensorflow預設會從模型檔案('model.ckpt')中找所有的“可訓練的”tensor,找不到會報錯。

解決方法:

    1,在預訓練模型中定義全域性變數

parm_dict={}

    並在“return conv”上面新增下面兩行

parm_dict['complex-conv1/weights']= weights
parm_dict['complex-conv1/']= biases

    然後在定義saver時使用下面這句話:

saver= tf.train.Saver(parm_dict)

    這樣儲存後的模型檔案就對應到複雜模型上了。

    2,在複雜模型中定義全域性變數

parameters= []

    並在“return conv”上面新增下面行

parameters+= [weights, biases]

    然後判斷如果是第二個卷積層就不更新parameters。

    接著在定義saver時使用下面這句話:

saver= tf.train.Saver(parameters)

    這樣就可以告訴saver,只需要從模型檔案中找weights和biases,而那些什麼complex-conv1/Variable~ complex-conv1/Variable_3統統滾一邊去(上面紅色部分)。

    最後使用下面的程式碼載入就可以了                              

with tf.Session() as sess:
    ckpt= tf.train.get_checkpoint_state('.')
    if ckpt and ckpt.model_checkpoint_path:
        saver.restore(sess,ckpt.model_checkpoint_path)
    else:
        print '  no saver.'
        exit()