動態匯入模組,載入預訓練模型,nn.Sequential函式裡面必須是a Module subclass,不能是一個列表或者是其他的迭代器、生成器,雖然這裡麵包含了Module的子類
阿新 • • 發佈:2018-12-01
class RES(nn.Module): def __init__(self): super(RES, self).__init__() self.conv1=nn.Conv2d(3,64,kernel_size=7,stride=2,padding=3,bias=False) self.bn1=nn.BatchNorm2d(64) self.relu=nn.ReLU(inplace=True) self.maxpool=nn.MaxPool2d(kernel_size=3,stride=2,padding=1) self.conv2=nn.Conv2d(64,128,kernel_size=7,stride=2,padding=3,bias=False) self.bn2=nn.BatchNorm2d(128) def forward(self,x): x=self.conv1(x) x=self.bn1(x) x=self.relu(x) x=self.maxpool(x) x=self.conv2(x) x=self.bn2(x) return x model=RES() glb = nn.Sequential(*list(model.children())[:4])
有兩點資料的說明:這個類繼承了Module一定要用super函式
nn.Sequential函式裡面的引數一定是Module的子類,而list:list is not a Module subclass。所以不能當做引數,當然model.children()也是一樣:Module.children is not a Module subclass。這裡的*就起了作用,將list或者children的內容迭代的一個一個的傳進去,效果如下:
當然,我們還可以像最上面的那樣,選取裡面的幾個Module,例如[:4]也就是第0個到第3個.
動態匯入模組,使用importlib.import_module函式實際上是import了一個叫做resnet的檔案,下面的語句相當於 import xxx as resnet
當然這裡的xxx是該檔案的實際路徑
import importlib
resnet = importlib.import_module("torchvision.models.resnet")
resnet18=resnet.resnet18()
resnet34=resnet.resnet34()
resnet50=resnet.resnet50()
resnet101=resnet.resnet101()
resnet152=resnet.resnet152()
其他的模組有:
""" alexnet檔案 """ alexnet=importlib.import_module("torchvision.models.alexnet") alexnet=alexnet.alexnet() nn.Sequential(*alexnet.children()) """ vgg檔案 """ vgg=importlib.import_module("torchvision.models.vgg") vgg16=vgg.vgg16() # vgg11=vgg.vgg11(),vgg19=vgg.vgg19(),vgg13=vgg.vgg13()以及他們的bn形式 # vgg16_bn=vgg.vgg16_bn(),vgg11_bn=vgg.vgg11_bn(),vgg19_bn=vgg.vgg19_bn(),vgg13_bn=vgg.vgg13_bn() nn.Sequential(*vgg16.children()) """ densenet檔案 """ densenet=importlib.import_module("torchvision.models.densenet") densenet121=densenet.densenet121() # densenet169=densenet.densenet169(),densenet201=densenet.densenet201(),densenet161=densenet.densenet161() nn.Sequential(*densenet121.children()) """ inception檔案 """ inception=importlib.import_module("torchvision.models.inception") inception_v3=inception.inception_v3() nn.Sequential(*inception_v3.children()) """ squeezenet檔案 """ squeezenet=importlib.import_module("torchvision.models.squeezenet") squeezenet1_0=inception.squeezenet1_0() # squeezenet1_0=inception.squeezenet1_1() nn.Sequential(*squeezenet1_0.children())
還有一種匯入方式,是比較常用的,推薦的:
import torchvision.models as models
models.squeezenet1_0()
"""
models後面直接接的是網路
models的__init__檔案如下
"""
from .alexnet import *
from .resnet import *
from .vgg import *
from .squeezenet import *
from .inception import *
from .densenet import *
"""
可以看出來,匯入的是這5個檔案裡面的函式(類)
*代表想對應檔案的__all__,下面是各個檔案的該屬性以及訓練好的權重
"""
# alexnet
__all__ = ['AlexNet', 'alexnet']
model_urls = {
'alexnet': 'https://download.pytorch.org/models/alexnet-owt-4df8aa71.pth',
}
# resnet
__all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101',
'resnet152']
model_urls = {
'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth',
'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth',
}
# vgg
__all__ = [
'VGG', 'vgg11', 'vgg11_bn', 'vgg13', 'vgg13_bn', 'vgg16', 'vgg16_bn',
'vgg19_bn', 'vgg19',]
model_urls = {
'vgg11': 'https://download.pytorch.org/models/vgg11-bbd30ac9.pth',
'vgg13': 'https://download.pytorch.org/models/vgg13-c768596a.pth',
'vgg16': 'https://download.pytorch.org/models/vgg16-397923af.pth',
'vgg19': 'https://download.pytorch.org/models/vgg19-dcbb9e9d.pth',
'vgg11_bn': 'https://download.pytorch.org/models/vgg11_bn-6002323d.pth',
'vgg13_bn': 'https://download.pytorch.org/models/vgg13_bn-abd245e5.pth',
'vgg16_bn': 'https://download.pytorch.org/models/vgg16_bn-6c64b313.pth',
'vgg19_bn': 'https://download.pytorch.org/models/vgg19_bn-c79401a0.pth',
}
# squeezenet
__all__ = ['SqueezeNet', 'squeezenet1_0', 'squeezenet1_1']
model_urls = {
'squeezenet1_0': 'https://download.pytorch.org/models/squeezenet1_0-a815701f.pth',
'squeezenet1_1': 'https://download.pytorch.org/models/squeezenet1_1-f364aa15.pth',
}
# inception
__all__ = ['Inception3', 'inception_v3']
model_urls = {
# Inception v3 ported from TensorFlow
'inception_v3_google': 'https://download.pytorch.org/models/inception_v3_google-1a9a5a14.pth',
}
# densenet
__all__ = ['DenseNet', 'densenet121', 'densenet169', 'densenet201', 'densenet161']
model_urls = {
'densenet121': 'https://download.pytorch.org/models/densenet121-a639ec97.pth',
'densenet169': 'https://download.pytorch.org/models/densenet169-b2777c0a.pth',
'densenet201': 'https://download.pytorch.org/models/densenet201-c1103571.pth',
'densenet161': 'https://download.pytorch.org/models/densenet161-8d451a50.pth',
}
所有的模型預設都是不載入預訓練模型引數的,怎麼載入預訓練模型引數呢?很簡單,就在括號裡面的pretrained設定成True,如果僅僅是需要該結構而不需要預訓練模型引數作為初始化,那麼pretrained=False。
resnet50 = models.resnet50(pretrained=True)
推薦!這裡有一篇比較綜合https://blog.csdn.net/weixin_41278720/article/details/80759933
其中可以補充一點就是將引數進行下載,相比載入模型來說更加的節省資源
import torch.utils.model_zoo as model_zoo
def _load_pretrained_model(self):
pretrain_dict = model_zoo.load_url('https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
'/home/zzp/SSD_ping/my-root-path/My-core-python/PretrainedWeights')
model_dict = {}
state_dict = self.state_dict()
for k, v in pretrain_dict.items():
if k in state_dict:
model_dict[k] = v
state_dict.update(model_dict)
self.load_state_dict(state_dict)