pytorch-模型儲存和載入
阿新 • • 發佈:2021-06-23
pytorch-模型儲存和載入
目錄載入模型引數和選擇是由儲存的模型資料結構決定,故先要確定儲存模型模型的方法和資料結構
儲存模型
# 模型權重引數 model.state_dict() '''首先說一下 model.state_dict() pytorch 中的 model.state_dict 是一個簡單的python的字典物件,將每一層與它的對應引數建立對映關係.(如model的每一層的weights及偏置等等) 只有那些引數可以訓練的layer才會被儲存到模型的state_dict中,如卷積層,線性層等 state_dict是在定義了model或optimizer之後pytorch自動生成的 ''' # model.state_dict() 其實返回的是一個OrderDict,儲存了網路結構的名字和對應的引數 class Net(nn.Module): def __init__(self): super(Net, self).__init__() self.linear1 = nn.Linear(1, 2) self.linear2 = nn.Linear(2, 1) def forward(self, x): x = self.linear1(x) x = self.linear2(x) return x mode = Net() print(mode.state_dict()) """ OrderedDict([('linear1.weight', tensor([[ 0.8108],[-0.7968]])), ('linear1.bias', tensor([ 0.2680, -0.4772])), ('linear2.weight', tensor([[-0.7066, -0.3334]])), ('linear2.bias', tensor([0.4819]))]) """ print(mode.state_dict().keys()) """ odict_keys(['linear1.weight', 'linear1.bias', 'linear2.weight', 'linear2.bias']) """ for param_tensor in model.state_dict(): #列印 key value字典 print(param_tensor,'\t',model.state_dict()[param_tensor].size()) """ linear1.weight torch.Size([2, 1]) linear1.bias torch.Size([2]) linear2.weight torch.Size([1, 2]) linear2.bias torch.Size([1]) """
# 儲存模型 torch.save(obj, f, pickle_module,pickle_protocol ) """輸入引數 obj 可以是單個值也可以字典、物件 f 要儲存引數的檔案路徑 pickle_module pickle_protocol """ # 1、自定義儲存-工程實踐中常常使用---推薦 state = {'model': model.state_dict(), 'optimizer': optimizer.state_dict(), 'epoch': epoch } torch.save(model_object, './model.pt') # 2、僅僅是儲存模型權重引數 torch.save(model.state_dict(), PATH) # 3、直接儲存整個模型和模型結構 torch.save(Net,PATH)
載入模型
引數的儲存
torch.save(model_object.state_dict(), 'params.pth') # 模型的載入有模型儲存的資料結構決定 ckpt = torch.load(f, map_location=None) """輸入引數 f file模型檔案 map_location torch.device, 動態地進行記憶體重對映,從不同的裝置上讀取檔案 pickle_module 用於unpickling元資料和物件的模組 pickle_load_args 傳遞給pickle_module.load() 註釋: 如果多塊顯示卡,map_location={'cuda:0':"cuda:1"},指定在2號顯示卡,不使用1號顯示卡 返回引數 字典d 由載入檔案定義 預設情況,dict_keys(['epoch', 'state_dict', 'optimizer', 'best_pred']) """ # 1、針對第一種儲存模型的載入方式 # 載入模型 model=Net() # 載入模型引數 model_CKPT = torch.load(checkpoint_PATH) # 引數各個屬性f model.load_state_dict(model_CKPT['model']) optimizer.load_state_dict(model_CKPT['optimizer']) # 2、針對第二種儲存模型的載入方式 model=Net() # 例項化網路 model_CKPT = torch.load(checkpoint_PATH) # 載入模型引數 model.load_state_dict(model_CKPT) # 針對第三種儲存整個模型的載入方式 model = torch.load(mode_PATH)
部分權重的載入
# 關鍵自定義函式
def intersect_dicts(da, db, exclude=()):
"""輸入引數
da (state_dict) 載入權重的 state_dict
db (state_dict) 載入模型的 state_dict
exclude (list) 不想要的權重 keys()
返回引數
載入的部分權重 (state_dict)
"""
'''
print("exclude",exclude)
for k, v in da.items():
for x in exclude:
if x in k:
print('@ ',x ,k)
if v.shape != db[k].shape:
print('# ', x, k)
'''
return {k: v for k, v in da.items() if k in db and not any(x in k for x in exclude) and v.shape == db[k].shape}
案例
# 載入模型
model = Net()
# 載入權重
ckpt=torch.load(weights, map_location=device)
state_dict=ckpt.state_dict()
# state_dict 是一個字典
# state_dict.keys()
# odict_keys(['0.model.0.conv.conv.weight', '0.model.0.conv.conv.bias', '0.model.1.conv.weight', .....])
# 權重取捨處理
state_dict=intersect_dicts(state_dict, model.state_dict(), exclude=exclude)
# 模型載入權重
model.load_state_dict(state_dict, strict=False)
# 最後可以輸出載入了多少個
print('Transferred %g/%g items from %s' % (len(state_dict), len(model.state_dict()), weights))
# output >>> Transferred 498/506 items from yolov5m.pt