tensorflow基礎(1)變數的建立、初始化、儲存與載入
阿新 • • 發佈:2019-02-01
廢話就不多說了,直接上乾貨。
1.變數的建立
tensoflow建立變數使用tf.Variable();需要指明變數的形狀
b = tf.Variable(tf.zeros([1]))
W = tf.Variable(tf.random_uniform([1, 2], -1.0, 1.0))
如這裡的w,b就是所要建立的變數。
2.初始化
變數的初始化,需要在變數操作執行前執行。
# 初始化變數
init = tf.global_variables_initializer()
.....
sess.run(init)#執行初始化操作
tf.global_variables_initializer()函式初始化了所有的變數。
3.變數的儲存與載入
這裡以一個例項為具體的模板進行講解。
#coding=utf-8
import tensorflow as tf
import numpy as np
import os
#判斷模型儲存路徑是否存在,不存在就建立
#if not os.path.exists('tem/'):
# os.mkdir('tem/')
# 使用 NumPy 生成假資料(phony data), 總共 100 個點.
x_data = np.float32(np.random.rand(2, 100)) # 隨機輸入
y_data = np.dot([0.100, 0.200], x_data) + 0.300
# 構造一個線性模型
#
b = tf.Variable(tf.zeros([1]))
W = tf.Variable(tf.random_uniform([1, 2], -1.0, 1.0))
y = tf.matmul(W, x_data) + b
# 最小化方差
loss = tf.reduce_mean(tf.square(y - y_data))
optimizer = tf.train.GradientDescentOptimizer(0.5)
train_op = optimizer.minimize(loss)
# 初始化變數
#merged_summary_op = tf.summary.merge_all()
init = tf.global_variables_initializer()
saver = tf.train.Saver()
# 啟動圖 (graph)
with tf.Session() as sess:
#summary_writer = #tf.summary.FileWriter('tem/mnist_logs', sess.graph)
#if p
#sess.run(init)
# 擬合平面
#path =os.path.join("", "tem/model.ckpt")
ckpt = tf.train.get_checkpoint_state('tem')
if ckpt != None:
print(ckpt.model_checkpoint_path)
saver.restore(sess, ckpt.model_checkpoint_path)
else:
sess.run(init)
#saver.restore(sess,path)
path = os.path.join("", "tem/model.ckpt")
for step in range(0, 201):
sess.run(train_op)
#summary_str = #sess.run(merged_summary_op,feed_dict=#{x_data:x_data,y_data:y_data})
#summary_writer.add_summary(summary_str, step)
#summary_writer.flush()
if step % 20 == 0:
print (step, sess.run(W), sess.run(b))
saver.save(sess, path,global_step=step)
儲存和恢復模型的方法是使用tf.train.Saver物件,預設儲存所有變數,但可以手動傳入要儲存的變數。
saver.save(session,path,global_step)是儲存模型,傳入的是sess,儲存的路徑,以及global_step=step,且必須先建立tem資料夾。
恢復模型的方法類似。saver.restore();
統一的框架,用於解決要麼存在模型,要麼沒有模型進行執行初始化。
ckpt = tf.train.get_checkpoint_state('tem')
if ckpt != None:
print(ckpt.model_checkpoint_path)
saver.restore(sess, ckpt.model_checkpoint_path)
else:
sess.run(init)
恩,純屬個人見解,寫的不好,請給予批評指正。