【pytorch】載入模型出現的bug
阿新 • • 發佈:2018-12-14
在模型訓練完後再進行測試載入模型後出現bug,顯示如下錯誤
據瞭解是由於pytorch版本導致的錯誤,可能與自己訓練階段保持的模型方式有關,訓練階段儲存方式如下:
解決方案如下:
方法一:
generator.load_state_dict({k.replace('module.', ''): v for k, v in torch.load(generator_1_10.pth).items()})
實際上就是將load進行的權重的有序字典裡面的鍵值前面的的7個字元’module.'去掉。載入進行的權重有序字典如下圖所示:
鍵就是每層的權重或者 bias 的名稱,value就是其具體的張量值。
方法二:重新新建個有序字典:
from collections import OrderedDict # new_state_dict = OrderedDict() # for k, v in a.items(): # name=k[7:] # reduce `module.` # new_state_dict[name] = v # # load params # # model.load_state_dict(new_state_dict) # model.load_state_dict(new_state_dict)
顯然方法一更簡潔明瞭。