【4】TensorFlow光速入門-儲存模型及載入模型並使用
本文地址:https://www.cnblogs.com/tujia/p/13862360.html
系列文章:
【1】TensorFlow光速入門-tensorflow開發基本流程
【2】TensorFlow光速入門-資料預處理(得到資料集)
【4】TensorFlow光速入門-儲存模型及載入模型並使用
【6】TensorFlow光速入門-python模型轉換為tfjs模型並使用
一、儲存模型
建立一個目錄
!mkdir /tf/saved_model
注:jupyter程式碼塊前面加一個!號表示,這是shell命令,不是程式碼;
儲存模型
model.save('/tf/saved_model/wnw')
儲存模型的其他引數及操作,看這裡https://tensorflow.google.cn/api_docs/python/tf/keras/Model#save
二、載入模型
import tensorflow as tf from tensorflow import keras import numpy as np from IPython import display import random # 載入模型 model = keras.models.load_model('/tf/saved_model/wnw') # 看一下模型的結構 model.summary() # 隨便找點圖片 all_image_paths = [] data_root = pathlib.Path('/tf/datasets/wnw') for item in data_root.rglob('*.jpg'): all_image_paths.append(str(item)) print(len(all_image_paths)) # 隨機選取一張圖片 img_path = random.choice(all_image_paths) print(img_path) # 把圖片處理成需要的tensorimage = tf.io.read_file(img_path) image = tf.image.decode_image(image, channels=1) image = tf.image.resize(image, (100, 100)) image /= 255 print(image.shape) # 預測只支援批量操作,我們給單張圖片再加一維 images = (np.expand_dims(image, 0)) print(images.shape) # 預測 predictions = model.predict(images) # 列印結果 label_names = ['other', 'watch'] label = np.argmax(predictions[0]) print(label_names[label]) # 把圖片也打印出來,看一下預測效果對不對 display.display(display.Image(img_path, width=200, height=200))
注:
用於預測的圖片資料要和訓練的圖片資料保持一致:
簡單來說,訓練不一定要100*100的灰圖,我可以是80*80的灰圖或彩圖,都沒關係。
重要的是,用使用模型的時候,要先把預測資料轉換成訓練集資料一樣的格式
重點:
model.save https://tensorflow.google.cn/api_docs/python/tf/keras/Model#save
keras.models.load_model https://tensorflow.google.cn/api_docs/python/tf/keras/models/load_model
至此,我們已經可以載入並使用模型了。我們可以用python封裝程式成web服務api,以供呼叫。不過像圖片分類這一類,頻繁的拍照上傳圖片呼叫api也不好。
下一節,我們先整理一下圖片分類的完整程式碼,然後下下節,我們再說一下怎樣使用tfjs直接載入模型(不需要調python服務)
【6】TensorFlow光速入門-python模型轉換為tfjs模型並使用
本文連結:https://www.cnblogs.com/tujia/p/13862360.html
完。