Pytorch 快速入門(七)載入預訓練模型初始化網路引數
阿新 • • 發佈:2019-01-26
在預訓練網路的基礎上,修改部分層得到自己的網路,通常我們需要解決的問題包括:
1. 從預訓練的模型載入引數
在PyTorch中,每個Variable資料含有兩個flag(requires_grad和volatile)用於指示是否計算此Variable的梯度。設定requires_grad = False,或者設定volatile=True,即可指示不計算此Variable的梯度
1. 從預訓練的模型載入引數
2. 對新網路兩部分設定不同的學習率,主要訓練自己新增的層
PyTorch提供的預訓練模型
PyTorch定義了幾個常用模型,並且提供了預訓練版本:
- AlexNet: AlexNet variant from the “One weird trick” paper.
- VGG: VGG-11, VGG-13, VGG-16, VGG-19 (with and without batch normalization)
- ResNet: ResNet-18, ResNet-34, ResNet-50, ResNet-101, ResNet-152
- SqueezeNet: SqueezeNet 1.0, and SqueezeNet 1.1
預訓練模型可以通過設定pretrained=True來構建:
eg:
import torchvision.models as models
resnet18 = models.resnet18(pretrained=True)
vgg16 = models.vgg16(pretrained=True)
alexnet = models.alexnet(pretrained=True)
squeezenet = models.squeezenet1_0(pretrained=True)
預訓練模型期望的輸入是RGB影象的mini-batch:(batch_size, 3, H, W),並且H和W不能低於224。影象的畫素值必須在範圍[0,1]間,並且用均值mean=[0.485, 0.456, 0.406]和方差std=[0.229, 0.224, 0.225]進行歸一化。
載入預訓練模型
載入引數可以參考apaszke推薦的做法,即刪除與當前model不匹配的key。torch.nn.Module物件有函式static_dict()用於返回包含模組所有狀態的字典,包括引數和快取。鍵是引數名稱或者快取名稱。
函式Module::load_state_dict(state_dict)用state_dict中的狀態值更新模組的狀態值。static_dict中的鍵應該和函式static_dict()返回的字典中的鍵完全一樣。
下面給出載入預訓練的模型的示例:
vgg16 = models.vgg16(pretrained=True) pretrained_dict = vgg16.state_dict() model_dict = model.state_dict() # 1. filter out unnecessary keys pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict} # 2. overwrite entries in the existing state dict model_dict.update(pretrained_dict) # 3. load the new state dict model.load_state_dict(model_dict)
不同層設定不同學習率的方法
此部分主要參考PyTorch教程的Autograd machnics部分在PyTorch中,每個Variable資料含有兩個flag(requires_grad和volatile)用於指示是否計算此Variable的梯度。設定requires_grad = False,或者設定volatile=True,即可指示不計算此Variable的梯度
for param in model.parameters():
param.requires_grad = False
注意,在模型測試時,對input_data設定volatile=True,可以節省測試時的視訊記憶體