1. 程式人生 > 程式設計 >tensorflow模型儲存、載入之變數重新命名例項

tensorflow模型儲存、載入之變數重新命名例項

話不多說,幹就完了。

變數重新命名的用處?

簡單定義:簡單來說就是將模型A中的引數parameter_A賦給模型B中的parameter_B

使用場景:當需要使用已經訓練好的模型引數,尤其是使用別人訓練好的模型引數時,往往別人模型中的引數命名方式與自己當前的命名方式不同,所以在載入模型引數時需要對引數進行重新命名,使得程式碼更簡潔易懂。

實現方法:

1)、模型儲存

import os
import tensorflow as tf
 
weights = tf.Variable(initial_value=tf.truncated_normal(shape=[1024,2],mean=0.0,stddev=0.1),dtype=tf.float32,name="weights")
biases = tf.Variable(initial_value=tf.zeros(shape=[2]),name="biases")
 
weights_2 = tf.Variable(initial_value=weights.initialized_value(),name="weights_2")
 
# saver checkpoint
if os.path.exists("checkpoints") is False:
  os.makedirs("checkpoints")
 
saver = tf.train.Saver()
with tf.Session() as sess:
  init_op = [tf.global_variables_initializer()]
  sess.run(init_op)
  saver.save(sess=sess,save_path="checkpoints/variable.ckpt")

2)、模型載入(變數名稱保持不變)

import tensorflow as tf
from matplotlib import pyplot as plt
import os
 
current_path = os.path.dirname(os.path.abspath(__file__))
 
def restore_variable(sess):
  # need not initilize variable,but need to define the same variable like checkpoint
  weights = tf.Variable(initial_value=tf.truncated_normal(shape=[1024,name="weights")
  biases = tf.Variable(initial_value=tf.zeros(shape=[2]),name="biases")
 
  weights_2 = tf.Variable(initial_value=weights.initialized_value(),name="weights_2")
 
  saver = tf.train.Saver()
 
  ckpt_path = os.path.join(current_path,"checkpoints","variable.ckpt")
  saver.restore(sess=sess,save_path=ckpt_path)
 
  weights_val,weights_2_val = sess.run(
    [
      tf.reshape(weights,shape=[2048]),tf.reshape(weights_2,shape=[2048])
    ]
  )
 
  plt.subplot(1,2,1)
  plt.scatter([i for i in range(len(weights_val))],weights_val)
  plt.subplot(1,2)
  plt.scatter([i for i in range(len(weights_2_val))],weights_2_val)
  plt.show()
 
 
if __name__ == '__main__':
  with tf.Session() as sess:
    restore_variable(sess)

3)、模型載入(變數重新命名)

import tensorflow as tf
from matplotlib import pyplot as plt
import os
 
current_path = os.path.dirname(os.path.abspath(__file__))
 
 
def restore_variable_renamed(sess):
  conv1_w = tf.Variable(initial_value=tf.truncated_normal(shape=[1024,name="conv1_w")
  conv1_b = tf.Variable(initial_value=tf.zeros(shape=[2]),name="conv1_b")
 
  conv2_w = tf.Variable(initial_value=conv1_w.initialized_value(),name="conv2_w")
 
  # variable named 'weights' in ckpt assigned to current variable conv1_w
  # variable named 'biases' in ckpt assigned to current variable conv1_b
  # variable named 'weights_2' in ckpt assigned to current variable conv2_w
  saver = tf.train.Saver({
    "weights": conv1_w,"biases": conv1_b,"weights_2": conv2_w
  })
 
  ckpt_path = os.path.join(current_path,save_path=ckpt_path)
 
  conv1_w__val,conv2_w__val = sess.run(
    [
      tf.reshape(conv1_w,tf.reshape(conv2_w,1)
  plt.scatter([i for i in range(len(conv1_w__val))],conv1_w__val)
  plt.subplot(1,2)
  plt.scatter([i for i in range(len(conv2_w__val))],conv2_w__val)
  plt.show()
 
 
if __name__ == '__main__':
  with tf.Session() as sess:
    restore_variable_renamed(sess)

總結:

# 之前模型中叫 'weights'的變數賦值給當前的conv1_w變數

# 之前模型中叫 'biases' 的變數賦值給當前的conv1_b變數

# 之前模型中叫 'weights_2'的變數賦值給當前的conv2_w變數

saver = tf.train.Saver({

"weights": conv1_w,

"biases": conv1_b,

"weights_2": conv2_w

})

以上這篇tensorflow模型儲存、載入之變數重新命名例項就是小編分享給大家的全部內容了,希望能給大家一個參考,也希望大家多多支援我們。