pytorch筆記(二)——模型的儲存與載入
阿新 • • 發佈:2021-01-08
技術標籤:pytorch
1.儲存和載入模型
# 模型儲存
model = ModelClass(*args, **kwargs)
torch.save(model, 'model.ckpt')
# 模型載入
model = torch.load(PATH)
儲存整個網路的的結構資訊和模型引數資訊,save的物件是網路net。載入時無需再定義網路。
2.儲存和載入模型引數
# 模型引數儲存
model = ModelClass(*args, **kwargs)
torch.save(model.state_dict(), 'params.ckpt')
# 模型引數載入
model = ModelClass(*args, **kwargs)
model.load_state_dict(torch.load('params.ckpt'))
只儲存神經網路的訓練模型引數,save的物件是net.state_dict()。載入模型引數前需要自己定義網路,並且其中的引數名稱與結構要與儲存的模型中的一致。
使用該方式常見問題:
1.模型引數引數名稱不一致
在上圖中儲存的模型引數比現在的模型多了flat_w這一部分,如果直接載入儲存的模型引數就會報下面的錯誤。
RuntimeError: Error(s) in loading state_dict for CNNMnist:
Unexpected key(s) in state_dict: "flat_w".
解決辦法:
建立一個新的字典物件OrderedDict(),將需要的引數賦值到新建的字典物件中,然後載入新建的字典物件
print("現在的模型引數名稱")
model = CNNMnist(args=args).to(args.device)
for (k, v) in model.state_dict().items():
print(k)
print("儲存的模型引數名稱")
params = torch.load("./model/model0.ckpt" )
for (k, v) in params.items():
print(k)
new_state_dict = OrderedDict()
for i, (k, v) in enumerate(params.items()):
if i != 0:
name = k
new_state_dict[name] = v
model.load_state_dict(new_state_dict)
2.儲存引數和載入引數的torch版本不同
如果儲存模型時,torch的版本為1.6,而載入引數時的torch版本小於。那麼載入引數是會出現以下錯誤
raise RuntimeError("{} is a zip archive (did you mean to use torch.jit.load()?)".format(f.name))
這是因為在torch1.6版本中torch.save儲存的引數是zip格式的,所以載入時出現錯誤。
解決辦法:
1.在torch1.6使用torch.save儲存引數是加上_use_new_zipfile_serialization=False這個引數,即
torch.save(model.state_dict(), 'params.ckpt', _use_new_zipfile_serialization=False)
2.將載入引數時的torch版本升到1.6