pytorch 預訓練模型 最後幾層的修改方法
阿新 • • 發佈:2021-01-06
技術標籤:pytorch網路模型pytorch神經網路自動駕駛深度學習
pytorch 預訓練模型 最後幾層的修改方法
方法1
model = net() #自己定義的模型,但要保證前面儲存的層和自定義的模型中的層一致
state_dict = torch.load('xxx/xxx.pth')# 模型
keys = []
for k,v in state_dict.items():
if k.startswith('_fc'):
continue
keys. append(k) #將‘_fc’開頭的key過濾掉,留下需要的
new_dict = {k:state_dict[k] for k in keys} # 重新儲存需要的的模型權重
model.state_dict().update(new_dict) # 模型更新現有的dict
# 載入我們真正需要的state_dict
model.load_state_dict(model_dict)
print(model) #列印檢視
https://blog.csdn.net/qq_36076233/article/details/107793069
方法2
state_dict = torch.load('xxx/xxx.pth')#模型
# 直接丟棄不需要的模組
state_dict.pop('_fc.weight')
state_dict.pop('_fc.bias')
new_model = model.load_state_dict(state_dict, strict=False)
print(new_model)