PyTorch—torchvision.models匯入預訓練模型與殘差網路講解
文章目錄
PyTorch框架中torchvision模組下有:torchvision.datasets、torchvision.models、torchvision.transforms這3個子包。
關於詳情請參考官網:
具體程式碼可以參考github: https://github.com/pytorch/vision/tree/master/torchvision。
torchvision.models
此模組下有常用的 alexnet、densenet、inception、resnet、squeezenet、vgg(關於網路詳情請檢視)等常用的網路結構,並且提供了預訓練模型,我們可以通過簡單呼叫來讀取網路結構和預訓練模型,同時使用fine tuning(微調)來使用。
關於 fine tuning 可以檢視
1. 模組呼叫
import torchvision
"""
如果你需要用預訓練模型,設定pretrained=True
如果你不需要用預訓練模型,設定pretrained=False,預設是False,你可以不寫
"""
model = torchvision.models.resnet50(pretrained=True)
model = torchvision.models.resnet50()
# 你也可以匯入densenet模型。且不需要是預訓練的模型
model = torchvision.models.densenet169(pretrained=False)
2. 原始碼解析
以匯入resnet50為例,介紹具體匯入模型時候的原始碼。
執行 model = torchvision.models.resnet50(pretrained=True)
的時候,是通過models包下的resnet.py指令碼進行的,原始碼如下:
首先是匯入必要的庫,其中model_zoo是和匯入預訓練模型相關的包,另外all變數定義了可以從外部import的函式名或類名。這也是前面為什麼可以用torchvision.models.resnet50()來呼叫的原因。
model_urls這個字典是預訓練模型的下載地址。
import torch.nn as nn
import math
import torch.utils.model_zoo as model_zoo
__all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101',
'resnet152']
model_urls = {
'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth',
'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth',
}
接下來就是resnet50這個函數了,引數pretrained預設是False。
model = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs)
是構建網路結構,Bottleneck是另外一個構建bottleneck的類,在ResNet網路結構的構建中有很多重複的子結構,這些子結構就是通過Bottleneck類來構建的,後面會介紹。- 如果引數pretrained是True,那麼就會通過model_zoo.py中的load_url函式根據model_urls字典下載或匯入相應的預訓練模型。
- 通過呼叫model的
load_state_dict
方法用預訓練的模型引數來初始化你構建的網路結構,這個方法就是PyTorch中通用的用一個模型的引數初始化另一個模型的層的操作。load_state_dict方法還有一個重要的引數是strict,該引數預設是True,表示預訓練模型的層和你的網路結構層嚴格對應相等(比如層名和維度)。
def resnet50(pretrained=False, **kwargs):
model = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs)
if pretrained:
model.load_state_dict(model_zoo.load_url(model_urls['resnet50']))
return model
其他resnet18、resnet101等函式和resnet50基本類似。
差別主要是在:
1、構建網路結構的時候block的引數不一樣,比如resnet18中是[2, 2, 2, 2],resnet101中是[3, 4, 23, 3]。
2、呼叫的block類不一樣,比如在resnet50、resnet101、resnet152中呼叫的是Bottleneck類,而在resnet18和resnet34中呼叫的是BasicBlock類,這兩個類的區別主要是在residual結果中卷積層的數量不同,這個是和網路結構相關的,後面會詳細介紹。
3、如果下載預訓練模型的話,model_urls字典的鍵不一樣,對應不同的預訓練模型。因此接下來分別看看如何構建網路結構和如何匯入預訓練模型。
# pretrained (bool): If True, returns a model pre-trained on ImageNet
def resnet18(pretrained=False, **kwargs):
model = ResNet(BasicBlock, [2, 2, 2, 2], **kwargs)
if pretrained:
model.load_state_dict(model_zoo.load_url(model_urls['resnet18']))
return model
def resnet101(pretrained=False, **kwargs):
model = ResNet(Bottleneck, [3, 4, 23, 3], **kwargs)
if pretrained:
model.load_state_dict(model_zoo.load_url(model_urls['resnet101']))
return model
3. ResNet類
繼承PyTorch中網路的基類:torch.nn.Module :
- 構建ResNet網路是通過ResNet這個類進行的。
- 其次主要的是重寫初始化
__init__()
和forward()
。
__init __()
中主要是定義一些層的引數。
forward()
中主要是定義資料在層之間的流動順序,也就是層的連線順序。
另外還可以在類中定義其他私有方法用來模組化一些操作,比如這裡的_make_layer()
是用來構建ResNet網路中的4個blocks。
_make_layer()
:
第一個輸入block是Bottleneck或BasicBlock類,
第二個輸入是該blocks的輸出channel,
第三個輸入是每個blocks中包含多少個residual子結構,因此layers這個列表就是前面resnet50的[3, 4, 6, 3]。
_make_layer()
方法中比較重要的兩行程式碼是:
1、layers.append(block(self.inplanes, planes, stride, downsample))
,該部分是將每個blocks的第一個residual結構儲存在layers列表中。
2、for i in range(1, blocks): layers.append(block(self.inplanes, planes)),
該部分是將每個blocks的剩下residual 結構儲存在layers列表中,這樣就完成了一個blocks的構造。
這兩行程式碼中都是通過Bottleneck這個類來完成每個residual的構建,接下來介紹Bottleneck類。
class ResNet(nn.Module):
def __init__(self, block, layers, num_classes=1000):
self.inplanes = 64
super(ResNet, self).__init__()
self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3,bias=False)
self.bn1 = nn.BatchNorm2d(64)
self.relu = nn.ReLU(inplace=True)
self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
self.layer1 = self._make_layer(block, 64, layers[0])
self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
self.avgpool = nn.AvgPool2d(7, stride=1)
self.fc = nn.Linear(512 * block.expansion, num_classes)
for m in self.modules():
if isinstance(m, nn.Conv2d):
n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
m.weight.data.normal_(0, math.sqrt(2. / n))
elif isinstance(m, nn.BatchNorm2d):
m.weight.data.fill_(1)
m.bias.data.zero_()
def _make_layer(self, block, planes, blocks, stride=1):
downsample = None
if stride != 1 or self.inplanes != planes * block.expansion:
downsample = nn.Sequential(
nn.Conv2d(self.inplanes, planes * block.expansion,
kernel_size=1, stride=stride, bias=False),
nn.BatchNorm2d(planes * block.expansion),
)
layers = []
layers.append(block(self.inplanes, planes, stride, downsample))
self.inplanes = planes * block.expansion
for i in range(1, blocks):
layers.append(block(self.inplanes, planes))
return nn.Sequential(*layers)
def forward(self, x):
x = self.conv1(x)
x = self.bn1(x)
x = self.relu(x)
x = self.maxpool(x)
x = self.layer1(x)
x = self.layer2(x)
x = self.layer3(x)
x = self.layer4(x)
x = self.avgpool(x)
x = x.view(x.size(0), -1)
x = self.fc(x)
return x
4. Bottlenect類
從前面的ResNet類可以看出,在構造ResNet網路的時候,最重要的是Bottleneck這個類,因為ResNet是由residual結構組成的,而Bottleneck類就是完成residual結構的構建。同樣Bottlenect還是繼承了torch.nn.Module類,且重寫了__init__和forward方法。從forward方法可以看出,bottleneck 就是我們熟悉的3個主要的卷積層、BN層和啟用層,最後的out += residual就是element-wise add的操作。
class Bottleneck(nn.Module):
expansion = 4
def __init__(self, inplanes, planes, stride=1, downsample=None):
super(Bottleneck, self).__init__()
self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
self.bn1 = nn.BatchNorm2d(planes)
self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
padding=1, bias=False)
self.bn2 = nn.BatchNorm2d(planes)
self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)
self.bn3 = nn.BatchNorm2d(planes * 4)
self.relu = nn.ReLU(inplace=True)
self.downsample = downsample
self.stride = stride
def forward(self, x):
residual = x
out = self.conv1(x)
out = self.bn1(out)
out = self.relu(out)
out = self.conv2(out)
out = self.bn2(out)
out = self.relu(out)
out = self.conv3(out)
out = self.bn3(out)
if self.downsample is not None:
residual = self.downsample(x)
out += residual
out = self.relu(out)
return out
5. BasicBlock類
BasicBlock類和Bottleneck類類似,BasicBlock類主要是用來構建ResNet18和ResNet34網路,因為這兩個網路的residual結構只包含兩個卷積層,沒有Bottleneck類中的bottleneck概念。因此在該類中,第一個卷積層採用的是kernel_size=3的卷積,如conv3x3函式所示。
def conv3x3(in_planes, out_planes, stride=1):
"""3x3 convolution with padding"""
return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
padding=1, bias=False)
class BasicBlock(nn.Module):
expansion = 1
def __init__(self, inplanes, planes, stride=1, downsample=None):
super(BasicBlock, self).__init__()
self.conv1 = conv3x3(inplanes, planes, stride)
self.bn1 = nn.BatchNorm2d(planes)
self.relu = nn.ReLU(inplace=True)
self.conv2 = conv3x3(planes, planes)
self.bn2 = nn.BatchNorm2d(planes)
self.downsample = downsample
self.stride = stride
def forward(self, x):
residual = x
out = self.conv1(x)
out = self.bn1(out)
out = self.relu(out)
out = self.conv2(out)
out = self.bn2(out)
if self.downsample is not None:
residual = self.downsample(x)
out += residual
out = self.relu(out)
return out
6. 獲取預訓練模型
前面提到這一行程式碼:
if pretrained: model.load_state_dict(model_zoo.load_url(model_urls['resnet50']))
,主要就是通過model_zoo.py中的load_url函式根據model_urls字典匯入相應的預訓練模型,models_zoo.py指令碼的github地址:https://github.com/pytorch/pytorch/blob/master/torch/utils/model_zoo.py。
load_url函式原始碼如下。
- 首先model_dir是下載模型儲存地址,如果沒有指定則儲存在專案的.torch目錄下,最好指定。cached_file是儲存模型的路徑加上模型名稱。
- 接下來的 if not os.path.exists(cached_file)語句用來判斷是否指定目錄下已經存在要下載模型,如果已經存在,就直接呼叫torch.load介面匯入模型,如果不存在,則從網上下載。
- 下載是通過
_download_url_to_file(url, cached_file, hash_prefix, progress=progress)
進行的,不再細講。重點在於模型匯入是通過torch.load()介面來進行的,不管你的模型是從網上下載的還是本地已有的。
def load_url(url, model_dir=None, map_location=None, progress=True):
"""
Args:
url (string): URL of the object to download
model_dir (string, optional): directory in which to save the object
map_location (optional): a function or a dict specifying how to remap storage locations (see torch.load)
progress (bool, optional): whether or not to display a progress bar to stderr
Example:
>>> state_dict = torch.utils.model_zoo.load_url('https://s3.amazonaws.com/pytorch/models/resnet18-5c106cde.pth')
"""
if model_dir is None:
torch_home = os.path.expanduser(os.getenv('TORCH_HOME', '~/.torch'))
model_dir = os.getenv('TORCH_MODEL_ZOO', os.path.join(torch_home, 'models'))
if not os.path.exists(model_dir):
os.makedirs(model_dir)
parts = urlparse(url)
filename = os.path.basename(parts.path)
cached_file = os.path.join(model_dir, filename)
if not os.path.exists(cached_file):
sys.stderr.write('Downloading: "{}" to {}\n'.format(url, cached_file))
hash_prefix = HASH_REGEX.search(filename).group(1)
_download_url_to_file(url, cached_file, hash_prefix, progress=progress)
return torch.load(cached_file, map_location=map_location)
鳴謝
https://blog.csdn.net/u014380165/article/details/79119664