1. 程式人生 > 程式設計 >tensorflow獲取預訓練模型某層引數並賦值到當前網路指定層方式

tensorflow獲取預訓練模型某層引數並賦值到當前網路指定層方式

已經有了一個預訓練的模型,我需要從其中取出某一層,把該層的weights和biases賦值到新的網路結構中,可以使用tensorflow中的pywrap_tensorflow(用來讀取預訓練模型的引數值)結合Session.assign()進行操作。

這種需求即預訓練模型可能為單分支網路,當前網路為多分支,我需要把單分支A複用到到多個分支去(B,C,D)。

先匯入對應的工具包

from tensorflow.python import pywrap_tensorflow

接下來的操作在一個tf.Session中進行

reader = pywrap_tensorflow.NewCheckpointReader(pre_train_model_path)

# 獲取當前圖可訓練變數
trainable_variables = tf.trainable_variables()
# 需要賦值的當前網路層變數,這裡只是隨便起的名字。
restore_v_target_name = "fc_target"
# 需要的預訓練模型中的某層的名字
restore_v_source_name = "fc_source"
for v in trainable_variables:
  if restore_v_target_name == v.name:
   # 回覆weights和biases
    sess.run(
      tf.assign(v,reader.get_tensor(restore_v_source_name + "/weights"))) if "weights" in v.name else sess.run(
      tf.assign(v,reader.get_tensor(restore_v_source_name + "/biases")))

以上這篇tensorflow獲取預訓練模型某層引數並賦值到當前網路指定層方式就是小編分享給大家的全部內容了,希望能給大家一個參考,也希望大家多多支援我們。