1. 程式人生 > 其它 >tensorflow沒有這個引數_關於TensorFlow動態設定trainable的問題

tensorflow沒有這個引數_關於TensorFlow動態設定trainable的問題

技術標籤:tensorflow沒有這個引數

這個問題是在專案中遇到的一個問題,即“如何在訓練過程中動態的控制哪些引數更新,哪些不更新”。我們知道tensorflow的計算圖是靜態的,tensorflow中定義的tf.Variable時,可以通過trainable屬性控制這個變數是否可以被優化器更新。但是,tf.Variable的trainable屬性是隻讀的,我們無法動態更改這個只讀屬性。在定義tf.Variable時,如果在定義變數時制定了trainable=True,那麼只要這個變數被初始化後,這個trainable就沒法更改了,即使使用tf.placeholder(tf.bool)在訓練時給這個變數傳遞一個引數試圖改變該變數的trainable屬性也是不可以的,會報錯。

那麼如何在訓練時動態的選擇需要更新的和不需要更新的引數呢?我在這裡提供一個思路,這個思路也是在stackoverflow上看到的(連結);另外我還看到一個辦法,但是我沒有嘗試,如果有人有興趣可以試一下是否可行(連結)。

tensorflow將可以更新的引數存在TRAINABLE_VARIABLES中,所以我們只要定義兩個不同的優化器就可以了。每個優化器指定當前更新哪些引數,這樣我們就可以交替更新引數了。

P.S. 這篇部落格也提供了部分思路,但是我覺得他的實現方式較為複雜。

下面我給出實現程式碼(核心程式碼):

x = tf.placeholder(shape=[None,5],dtype=tf.float32,name='x')
y = tf.placeholder(shape=[None,1],dtype=tf.float32,name='y')
with tf.variable_scope('z'):
    z = tf.Variable(tf.zeros(shape=[3,1],dtype=tf.float32),name='z',trainable=True) # 定義中間隱引數z

def dnn(x,z):
    with tf.variable_scope('parameter'):
        w1 = tf.Variable(tf.truncated_normal([5, 3]), dtype=tf.float32, trainable=True, name='w1')
        b1 = tf.Variable(tf.constant(0.001,shape=[3],dtype=tf.float32),trainable=True,name='b1')
        w2 = tf.Variable(tf.truncated_normal([3, 1]), dtype=tf.float32, trainable=True, name='w2')
        b2 = tf.Variable(tf.constant(0.001, shape=[1], dtype=tf.float32), trainable=True, name='b2')
    d1 = tf.add(tf.matmul(x,w1),b1)
    d2 = tf.add(tf.matmul(d1,w2),b2)
    output = tf.add(d2,z)
    return output

y_ = dnn(x,z)
loss = tf.abs(y-y_)
optimzer = tf.train.AdamOptimizer(0.001)
trainable_var_p = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,'parameter')
trainable_var_z = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,'z')
train_op_p = optimzer.minimize(loss,var_list=trainable_var_p)
train_op_z = optimzer.minimize(loss,var_list=trainable_var_z)

with tf.Session() as sess:
    sess.run(init)
    for i in range(100):
        print('第',i,'輪')
        if i%10<5 :
            print("z不更新")
            sess.run(train_op_p,feed_dict={x:batch_x,y:batch_y})
        else:
            print("z更新")
            sess.run(train_op_z, feed_dict={x: batch_x, y: batch_y})