網絡訓練細節
阿新 • • 發佈:2019-03-09
res cnblogs 嚴格 調用 更新 mod tor item alexnet
https://www.cnblogs.com/wmlj/p/9917827.html
經典網絡的加載和初始化:pytorch中自帶幾種常用的深度學習網絡預訓練模型,torchvision.models包中包含alexnet、densenet、inception、resnet、squeezenet、vgg等常用網絡結構,並且提供了預訓練模型,可通過調用來讀取網絡結構和預訓練模型(模型參數)。往往為了加快學習進度,訓練的初期直接加載pretrain模型中預先訓練好的參數。加載model如下所示:
import torchvision.models as models #加載網絡結構和預訓練參數#參數pretrained在默認情況下是False,表示只加載網絡結構而不加載預訓練參數來初始化 resnet34 = models.resnet34(pretrained=True) #打印網絡結構 print(resnet34) resnet18.load_state_dict(torch.load(path_params.pkl))#其中,path_params.pkl為預訓練模型參數的保存路徑。加載預先下載好的預訓練參數到resnet18,用預訓練模型的參數初始化resnet18的層,此時resnet18發生了改變。調用model的load_state_dict方法用預訓練的模型參數來初始化自己定義的新網絡結構,這個方法就是PyTorch中通用的用一個模型的參數初始化另一個模型的層的操作。load_state_dict方法還有一個重要的參數是strict,該參數默認是True,表示預訓練模型的層和自己定義的網絡結構層嚴格對應相等(比如層名和維度)。故,當新定義的網絡(model_dict)和預訓練網絡(pretrained_dict)的層名不嚴格相等時,需要先將pretrained_dict裏不屬於model_dict的鍵剔除掉 :pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict} ,再用預訓練模型參數更新model_dict,最後用load_state_dict方法初始化自己定義的新網絡結構。 print resnet18 #打印的還是網絡結構 註意: cnn = resnet18.load_state_dict(torch.load( path_params.pkl )) #是錯誤的,這樣cnn將是nonetype pre_dict = resnet18.state_dict() #按鍵值對將模型參數加載到pre_dictprint for k, v in pre_dict.items(): # 打印模型參數 for k, v in pre_dict.items(): print k #打印模型每層命名 #model是自己定義好的新網絡模型,將pretrained_dict和model_dict中命名一致的層加入pretrained_dict(包括參數)。 pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
網絡訓練細節