1. 程式人生 > 其它 >pytorch筆記(二)——模型的儲存與載入

pytorch筆記(二)——模型的儲存與載入

技術標籤:pytorch

1.儲存和載入模型

# 模型儲存
model = ModelClass(*args, **kwargs)
torch.save(model, 'model.ckpt')

# 模型載入
model = torch.load(PATH)

  儲存整個網路的的結構資訊和模型引數資訊,save的物件是網路net。載入時無需再定義網路。

2.儲存和載入模型引數

# 模型引數儲存
model = ModelClass(*args, **kwargs)
torch.save(model.state_dict(), 'params.ckpt')

# 模型引數載入
model =
ModelClass(*args, **kwargs) model.load_state_dict(torch.load('params.ckpt'))

  只儲存神經網路的訓練模型引數,save的物件是net.state_dict()。載入模型引數前需要自己定義網路,並且其中的引數名稱與結構要與儲存的模型中的一致。

使用該方式常見問題:
1.模型引數引數名稱不一致
在這裡插入圖片描述
  在上圖中儲存的模型引數比現在的模型多了flat_w這一部分,如果直接載入儲存的模型引數就會報下面的錯誤。

RuntimeError: Error(s) in loading state_dict for CNNMnist:
Unexpected key(s) in state_dict: "flat_w".

解決辦法:
  建立一個新的字典物件OrderedDict(),將需要的引數賦值到新建的字典物件中,然後載入新建的字典物件

print("現在的模型引數名稱")
model = CNNMnist(args=args).to(args.device)
for (k, v) in model.state_dict().items():
    print(k)

print("儲存的模型引數名稱")
params = torch.load("./model/model0.ckpt"
) for (k, v) in params.items(): print(k) new_state_dict = OrderedDict() for i, (k, v) in enumerate(params.items()): if i != 0: name = k new_state_dict[name] = v model.load_state_dict(new_state_dict)

2.儲存引數和載入引數的torch版本不同
  如果儲存模型時,torch的版本為1.6,而載入引數時的torch版本小於。那麼載入引數是會出現以下錯誤

raise RuntimeError("{} is a zip archive (did you mean to use torch.jit.load()?)".format(f.name))

  這是因為在torch1.6版本中torch.save儲存的引數是zip格式的,所以載入時出現錯誤。
解決辦法:
1.在torch1.6使用torch.save儲存引數是加上_use_new_zipfile_serialization=False這個引數,即

torch.save(model.state_dict(), 'params.ckpt', _use_new_zipfile_serialization=False)

2.將載入引數時的torch版本升到1.6