1. 程式人生 > >Tensorflow如何儲存、讀取model (即利用訓練好的模型測試新資料的準確度)

Tensorflow如何儲存、讀取model (即利用訓練好的模型測試新資料的準確度)

目標:

cnn2d.py cnn2d_test.py
訓練網路,並儲存網路模型 讀取網路,用測試集測試準確度

直接貼程式碼:(只貼了相關部分,瀏覽完整程式碼請到GitHub

1. cnn2d.py

import tensorflow as tf
import numpy as np
from sklearn import metrics

print("### Process1 --- data load ###")
# 讀取資料
print("### Process2 --- data spilt ###")
# 形成訓練集和驗證集

# 定義
# ···
X = tf.placeholder(tf.float32, (None, seg_height, seg_len, num_channels), name='X')
Y = tf.placeholder(tf.float32, (None, num_labels), name='Y')
# 注意name='X'和name='Y'

# 網路結構
# ···

# loss training等定義
# ···

accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32), name="accuracy")
# 注意name="accuracy"

with tf.Session() as session:
    # ···
    saver = tf.train.Saver()
    tf.train.Saver().save(session, "./model/HAR-UCI_model_1")

# Epoch:  100  Training Loss:  0.11656471  Training Accuracy:  0.9664903
# Epoch:  100 Valid Accuracy: 0.96754116
# Epoch:  100 Test Accuracy: 0.9321344
# ### Save model_1 successfully ###

雖然只給X, Y, accuracy命名,但網路其餘結構、引數均自動分配了名字並儲存在./model/中。

./model/中儲存的檔案:

只給X, Y, accuracy命名是因為在下面一個程式中只用到了這三個引數。

2. cnn2d_test.py

import tensorflow as tf
import numpy as np

# 讀取測試集test_x和test_y
# ···

saver = tf.train.import_meta_graph("./model/HAR-UCI_model_1.meta")
with tf.Session() as session:
    saver.restore(session, tf.train.latest_checkpoint("./model/"))
    graph = tf.get_default_graph()
    feed_dict = {"X:0": test_x, "Y:0": test_y}
    acc = graph.get_tensor_by_name("accuracy:0")
    test_acc = session.run(acc, feed_dict=feed_dict)
    print("Test Accuracy:", test_acc)

# Test Accuracy: 0.9321344