1. 程式人生 > 程式設計 >Tensorflow實現部分引數梯度更新操作

Tensorflow實現部分引數梯度更新操作

在深度學習中,遷移學習經常被使用,在大資料集上預訓練的模型遷移到特定的任務,往往需要保持模型引數不變,而微調與任務相關的模型層。

本文主要介紹,使用tensorflow部分更新模型引數的方法。

1. 根據Variable scope剔除需要固定引數的變數

def get_variable_via_scope(scope_lst):
  vars = []
  for sc in scope_lst:
    sc_variable = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,scope=scope)
    vars.extend(sc_variable)
  return vars
 
trainable_vars = tf.trainable_variables()
no_change_scope = ['your_unchange_scope_name']
 
no_change_vars = get_variable_via_scope(no_change_scope)
 
for v in no_change_vars:
  trainable_vars.remove(v)
 
grads,_ = tf.gradients(loss,trainable_vars)
 
optimizer = tf.train.AdamOptimizer(lr)
 
train_op = optimizer.apply_gradient(zip(grads,trainable_vars),global_step=global_step)

2. 使用tf.stop_gradient()函式

在建立Graph過程中使用該函式,非常簡潔地避免了使用scope獲取引數

3. 一個矩陣中部分行或列引數更新

如果一個矩陣,只有部分行或列需要更新引數,其它保持不變,該場景很常見,例如word embedding中,一些預定義的領域相關詞保持不變(使用領域相關word embedding初始化),而另一些通用詞變化。

import tensorflow as tf
import numpy as np
 
def entry_stop_gradients(target,mask):
  mask_h = tf.abs(mask-1)
  return tf.stop_gradient(mask_h * target) + mask * target
 
mask = np.array([1.,1,1])
mask_h = np.abs(mask-1)
 
emb = tf.constant(np.ones([10,5]))
 
matrix = entry_stop_gradients(emb,tf.expand_dims(mask,1))
 
parm = np.random.randn(5,1)
t_parm = tf.constant(parm)
 
loss = tf.reduce_sum(tf.matmul(matrix,t_parm))
grad1 = tf.gradients(loss,emb)
grad2 = tf.gradients(loss,matrix)
print matrix
with tf.Session() as sess:
  print sess.run(loss)
  print sess.run([grad1,grad2])

以上這篇Tensorflow實現部分引數梯度更新操作就是小編分享給大家的全部內容了,希望能給大家一個參考,也希望大家多多支援我們。