tensorflow模型儲存、載入之變數重新命名例項
阿新 • • 發佈:2020-01-23
話不多說,幹就完了。
變數重新命名的用處?
簡單定義:簡單來說就是將模型A中的引數parameter_A賦給模型B中的parameter_B
使用場景:當需要使用已經訓練好的模型引數,尤其是使用別人訓練好的模型引數時,往往別人模型中的引數命名方式與自己當前的命名方式不同,所以在載入模型引數時需要對引數進行重新命名,使得程式碼更簡潔易懂。
實現方法:
1)、模型儲存
import os import tensorflow as tf weights = tf.Variable(initial_value=tf.truncated_normal(shape=[1024,2],mean=0.0,stddev=0.1),dtype=tf.float32,name="weights") biases = tf.Variable(initial_value=tf.zeros(shape=[2]),name="biases") weights_2 = tf.Variable(initial_value=weights.initialized_value(),name="weights_2") # saver checkpoint if os.path.exists("checkpoints") is False: os.makedirs("checkpoints") saver = tf.train.Saver() with tf.Session() as sess: init_op = [tf.global_variables_initializer()] sess.run(init_op) saver.save(sess=sess,save_path="checkpoints/variable.ckpt")
2)、模型載入(變數名稱保持不變)
import tensorflow as tf from matplotlib import pyplot as plt import os current_path = os.path.dirname(os.path.abspath(__file__)) def restore_variable(sess): # need not initilize variable,but need to define the same variable like checkpoint weights = tf.Variable(initial_value=tf.truncated_normal(shape=[1024,name="weights") biases = tf.Variable(initial_value=tf.zeros(shape=[2]),name="biases") weights_2 = tf.Variable(initial_value=weights.initialized_value(),name="weights_2") saver = tf.train.Saver() ckpt_path = os.path.join(current_path,"checkpoints","variable.ckpt") saver.restore(sess=sess,save_path=ckpt_path) weights_val,weights_2_val = sess.run( [ tf.reshape(weights,shape=[2048]),tf.reshape(weights_2,shape=[2048]) ] ) plt.subplot(1,2,1) plt.scatter([i for i in range(len(weights_val))],weights_val) plt.subplot(1,2) plt.scatter([i for i in range(len(weights_2_val))],weights_2_val) plt.show() if __name__ == '__main__': with tf.Session() as sess: restore_variable(sess)
3)、模型載入(變數重新命名)
import tensorflow as tf from matplotlib import pyplot as plt import os current_path = os.path.dirname(os.path.abspath(__file__)) def restore_variable_renamed(sess): conv1_w = tf.Variable(initial_value=tf.truncated_normal(shape=[1024,name="conv1_w") conv1_b = tf.Variable(initial_value=tf.zeros(shape=[2]),name="conv1_b") conv2_w = tf.Variable(initial_value=conv1_w.initialized_value(),name="conv2_w") # variable named 'weights' in ckpt assigned to current variable conv1_w # variable named 'biases' in ckpt assigned to current variable conv1_b # variable named 'weights_2' in ckpt assigned to current variable conv2_w saver = tf.train.Saver({ "weights": conv1_w,"biases": conv1_b,"weights_2": conv2_w }) ckpt_path = os.path.join(current_path,save_path=ckpt_path) conv1_w__val,conv2_w__val = sess.run( [ tf.reshape(conv1_w,tf.reshape(conv2_w,1) plt.scatter([i for i in range(len(conv1_w__val))],conv1_w__val) plt.subplot(1,2) plt.scatter([i for i in range(len(conv2_w__val))],conv2_w__val) plt.show() if __name__ == '__main__': with tf.Session() as sess: restore_variable_renamed(sess)
總結:
# 之前模型中叫 'weights'的變數賦值給當前的conv1_w變數
# 之前模型中叫 'biases' 的變數賦值給當前的conv1_b變數
# 之前模型中叫 'weights_2'的變數賦值給當前的conv2_w變數
saver = tf.train.Saver({
"weights": conv1_w,
"biases": conv1_b,
"weights_2": conv2_w
})
以上這篇tensorflow模型儲存、載入之變數重新命名例項就是小編分享給大家的全部內容了,希望能給大家一個參考,也希望大家多多支援我們。