PyTorch 介紹 | 儲存和載入模型
阿新 • • 發佈:2022-02-10
本節我們將會看到如何儲存模型狀態、載入和執行模型預測
import torch
import torchvision.models as models
儲存和載入模型權重
PyTorch模型在一個稱為 state_dict
的內部狀態字典內儲存了學習的引數,可以通過 torch.save
實現這一過程。
model = models.vgg16(pretrained=True)
torch.save(model.state_dict(), 'model_weights.pth')
為了載入模型引數,你需要首先建立一個相同模型的實體,然後使用 load_state_dict()
載入引數。
model = models.vgg16() # we do not specify pretrained=True, i.e. do not load default weights model.load_state_dict(torch.load('model_weights.pth')) model.eval()
注意:在推理前,確保呼叫 model.eval()
設定dropout和batch normalization層是評估模式,否則將產生不一致的推斷結果。
使用Shapes儲存和載入模型
當載入模型權重時,我們需要首先初始化模型類,因為該類定義了網路結構。我們可能想將模型權重和該類的結構儲存在一起,在這種情況下,可以將 model
(而不是model.state_dict()
)傳入儲存函式。
torch.save(model, 'model.pth')
載入
model = torch.load('model.pth')
注意:這種方法在序列化模型時使用Python pickle模組,因此,它依賴於載入模型時可用的實際類的定義。