Tensorflow如何儲存、讀取model (即利用訓練好的模型測試新資料的準確度)
阿新 • • 發佈:2018-11-20
目標:
cnn2d.py | cnn2d_test.py |
訓練網路,並儲存網路模型 | 讀取網路,用測試集測試準確度 |
直接貼程式碼:(只貼了相關部分,瀏覽完整程式碼請到GitHub)
1. cnn2d.py
import tensorflow as tf import numpy as np from sklearn import metrics print("### Process1 --- data load ###") # 讀取資料 print("### Process2 --- data spilt ###") # 形成訓練集和驗證集 # 定義 # ··· X = tf.placeholder(tf.float32, (None, seg_height, seg_len, num_channels), name='X') Y = tf.placeholder(tf.float32, (None, num_labels), name='Y') # 注意name='X'和name='Y' # 網路結構 # ··· # loss training等定義 # ··· accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32), name="accuracy") # 注意name="accuracy" with tf.Session() as session: # ··· saver = tf.train.Saver() tf.train.Saver().save(session, "./model/HAR-UCI_model_1") # Epoch: 100 Training Loss: 0.11656471 Training Accuracy: 0.9664903 # Epoch: 100 Valid Accuracy: 0.96754116 # Epoch: 100 Test Accuracy: 0.9321344 # ### Save model_1 successfully ###
雖然只給X, Y, accuracy命名,但網路其餘結構、引數均自動分配了名字並儲存在./model/中。
./model/中儲存的檔案:
只給X, Y, accuracy命名是因為在下面一個程式中只用到了這三個引數。
2. cnn2d_test.py
import tensorflow as tf import numpy as np # 讀取測試集test_x和test_y # ··· saver = tf.train.import_meta_graph("./model/HAR-UCI_model_1.meta") with tf.Session() as session: saver.restore(session, tf.train.latest_checkpoint("./model/")) graph = tf.get_default_graph() feed_dict = {"X:0": test_x, "Y:0": test_y} acc = graph.get_tensor_by_name("accuracy:0") test_acc = session.run(acc, feed_dict=feed_dict) print("Test Accuracy:", test_acc) # Test Accuracy: 0.9321344