1. 程式人生 > 其它 >PyTorch 介紹 | 儲存和載入模型

PyTorch 介紹 | 儲存和載入模型

本節我們將會看到如何儲存模型狀態、載入和執行模型預測

import torch
import torchvision.models as models

儲存和載入模型權重

PyTorch模型在一個稱為 state_dict 的內部狀態字典內儲存了學習的引數,可以通過 torch.save實現這一過程。

model = models.vgg16(pretrained=True)
torch.save(model.state_dict(), 'model_weights.pth')

為了載入模型引數,你需要首先建立一個相同模型的實體,然後使用 load_state_dict()載入引數。

model = models.vgg16() # we do not specify pretrained=True, i.e. do not load default weights
model.load_state_dict(torch.load('model_weights.pth'))
model.eval()

注意:在推理前,確保呼叫 model.eval() 設定dropout和batch normalization層是評估模式,否則將產生不一致的推斷結果。

使用Shapes儲存和載入模型

當載入模型權重時,我們需要首先初始化模型類,因為該類定義了網路結構。我們可能想將模型權重和該類的結構儲存在一起,在這種情況下,可以將 model (而不是model.state_dict())傳入儲存函式。

torch.save(model, 'model.pth')

載入

model = torch.load('model.pth')

注意:這種方法在序列化模型時使用Python pickle模組,因此,它依賴於載入模型時可用的實際類的定義。

相關教程

Saving and Loading a General Checkpoint in PyTorch