
Reading the PyTorch Source: torchvision.models

This post looks at torchvision.models. The torchvision.models package contains commonly used network architectures such as alexnet, densenet, inception, resnet, squeezenet, and vgg, and it also provides pretrained models, so both the network structure and the pretrained weights can be loaded with a simple call.

Usage example:

import torchvision
model = torchvision.models.resnet50(pretrained=True)

This imports a pretrained ResNet-50. If you only need the network structure and do not want to initialize it with the pretrained parameters, use:

model = torchvision.models.resnet50(pretrained=False)

Importing a DenseNet model works the same way. For example, to import densenet169 without pretrained weights:

model = torchvision.models.densenet169(pretrained=False)

Since the pretrained argument defaults to False, this is equivalent to:

model = torchvision.models.densenet169()

For clarity, though, it is better to pass the argument explicitly.

Next, taking resnet50 as the example, let's walk through the source code that runs when a model is imported. Executing model = torchvision.models.resnet50(pretrained=True) goes through the resnet.py script under the models package; its source is shown below.

First come the necessary imports. model_zoo is the package used to fetch pretrained models, and the __all__ variable lists the function and class names that can be imported from outside, which is why torchvision.models.resnet50() can be called as shown earlier. The model_urls dictionary holds the download URLs of the pretrained models.

import torch.nn as nn
import math
import torch.utils.model_zoo as model_zoo

__all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101',
           'resnet152']

model_urls = {
    'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
    'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth',
    'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
    'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
    'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth',
}

Next comes the resnet50 function itself, whose pretrained argument defaults to False. First, model = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs) builds the network structure. Bottleneck is a separate class used to build the bottleneck sub-structure; the ResNet architecture contains many repeated sub-structures, and they are all constructed through the Bottleneck class, which is covered later. Then, if pretrained is True, the load_url function in model_zoo.py downloads (or loads from cache) the corresponding pretrained model according to the model_urls dictionary. Finally, the model's load_state_dict method initializes the network you just built with the pretrained parameters; this is the standard PyTorch way of initializing the layers of one model with the parameters of another. load_state_dict also has an important strict argument, which defaults to True and requires that the layers of the pretrained model exactly match the layers of your network (for example in names and dimensions); a sketch of loading with strict=False is shown after the function below.

def resnet50(pretrained=False, **kwargs):
    """Constructs a ResNet-50 model.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
    """
    model = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs)
    if pretrained:
        model.load_state_dict(model_zoo.load_url(model_urls['resnet50']))
    return model
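As mentioned above, load_state_dict accepts a strict argument. Below is a minimal sketch (not part of the torchvision source) of using strict=False when the classification head no longer matches the checkpoint, e.g. after changing num_classes; filtering out the fc.* keys is just one way to handle the mismatch.

import torch.utils.model_zoo as model_zoo
import torchvision

# Build a ResNet-50 with a different classification head (10 classes).
model = torchvision.models.resnet50(pretrained=False, num_classes=10)

# Fetch the ImageNet checkpoint, drop the fc.* entries whose shapes no
# longer match, then load the remaining parameters with strict=False.
state_dict = model_zoo.load_url(
    'https://download.pytorch.org/models/resnet50-19c8e357.pth')
state_dict = {k: v for k, v in state_dict.items() if not k.startswith('fc.')}
model.load_state_dict(state_dict, strict=False)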

The other functions such as resnet18 and resnet101 are basically the same as resnet50; the differences are: 1. the block counts passed when building the network, e.g. [2, 2, 2, 2] for resnet18 versus [3, 4, 23, 3] for resnet101; 2. the block class used: resnet50, resnet101, and resnet152 use the Bottleneck class, while resnet18 and resnet34 use the BasicBlock class, the two differing mainly in the number of convolution layers inside the residual structure, which is tied to the network architecture and is detailed later; 3. the key looked up in the model_urls dictionary, which points to the corresponding pretrained model. A quick sanity check of these differences follows the two functions below. After that, let's see how the network structure is built and how the pretrained model is loaded.

def resnet18(pretrained=False, **kwargs):
    """Constructs a ResNet-18 model.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
    """
    model = ResNet(BasicBlock, [2, 2, 2, 2], **kwargs)
    if pretrained:
        model.load_state_dict(model_zoo.load_url(model_urls['resnet18']))
    return model

def resnet101(pretrained=False, **kwargs):
    """Constructs a ResNet-101 model.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
    """
    model = ResNet(Bottleneck, [3, 4, 23, 3], **kwargs)
    if pretrained:
        model.load_state_dict(model_zoo.load_url(model_urls['resnet101']))
    return model
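A quick sanity check of the differences listed above (just an illustrative sketch): each layerN attribute is an nn.Sequential, so its length is the number of residual blocks in that stage.

import torchvision

r18 = torchvision.models.resnet18(pretrained=False)
r50 = torchvision.models.resnet50(pretrained=False)

# [2, 2, 2, 2] for resnet18 vs [3, 4, 6, 3] for resnet50
print([len(getattr(r18, name)) for name in ['layer1', 'layer2', 'layer3', 'layer4']])
print([len(getattr(r50, name)) for name in ['layer1', 'layer2', 'layer3', 'layer4']])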

The ResNet network is built through the ResNet class. As usual, it inherits from PyTorch's base network class torch.nn.Module and mainly overrides the __init__ and forward methods. __init__ defines the layers and their parameters; forward defines how data flows between the layers, i.e. how they are connected. You can also define private methods in the class to modularize certain operations; here, _make_layer builds the 4 stages of blocks in ResNet. The first argument of _make_layer, block, is the Bottleneck or BasicBlock class; the second is the output channel count of that stage; the third is how many residual sub-structures the stage contains, so the layers list is exactly the [3, 4, 6, 3] we saw in resnet50.
The two key lines in _make_layer are: 1. layers.append(block(self.inplanes, planes, stride, downsample)), which stores the first residual structure of each stage in the layers list; 2. for i in range(1, blocks): layers.append(block(self.inplanes, planes)), which stores the remaining residual structures of the stage in the layers list, completing the construction of one stage. Both lines build each residual unit through the Bottleneck class, which is introduced after the ResNet class; a short sketch of running the assembled network also follows the class definition below.

class ResNet(nn.Module):

    def __init__(self, block, layers, num_classes=1000):
        self.inplanes = 64
        super(ResNet, self).__init__()
        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3,
                               bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
        self.avgpool = nn.AvgPool2d(7, stride=1)
        self.fc = nn.Linear(512 * block.expansion, num_classes)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()

    def _make_layer(self, block, planes, blocks, stride=1):
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(self.inplanes, planes * block.expansion,
                          kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(planes * block.expansion),
            )

        layers = []
        layers.append(block(self.inplanes, planes, stride, downsample))
        self.inplanes = planes * block.expansion
        for i in range(1, blocks):
            layers.append(block(self.inplanes, planes))

        return nn.Sequential(*layers)

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        x = self.avgpool(x)
        x = x.view(x.size(0), -1)
        x = self.fc(x)

        return x
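A minimal sketch of exercising the class directly (assuming the standard 224x224 ImageNet input size, since avgpool is a fixed 7x7 pooling; Bottleneck is the class shown below):

import torch

model = ResNet(Bottleneck, [3, 4, 6, 3], num_classes=1000)
x = torch.randn(1, 3, 224, 224)   # one RGB image
out = model(x)
print(out.size())                 # torch.Size([1, 1000])

On very old PyTorch versions the input may need to be wrapped in torch.autograd.Variable, but the flow through forward is the same.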

As the ResNet class shows, the most important piece when constructing a ResNet is the Bottleneck class, because ResNet is composed of residual structures and Bottleneck is what builds them. Bottleneck likewise inherits from torch.nn.Module and overrides __init__ and forward. From forward you can see that a bottleneck is the familiar stack of 3 main convolution layers with BN layers and activations, and the final out += residual is the element-wise add. A small usage sketch follows the class below.

class Bottleneck(nn.Module):
    expansion = 4

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(Bottleneck, self).__init__()
        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
                               padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(planes * 4)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual
        out = self.relu(out)

        return out
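A small sketch of using a single Bottleneck: because expansion = 4, a block with planes=64 outputs 64 * 4 = 256 channels, so when the input has only 64 channels a 1x1 downsample branch is needed for the element-wise add to be valid (this is exactly what _make_layer constructs):

import torch
import torch.nn as nn

downsample = nn.Sequential(
    nn.Conv2d(64, 256, kernel_size=1, stride=1, bias=False),
    nn.BatchNorm2d(256),
)
block = Bottleneck(64, 64, stride=1, downsample=downsample)

x = torch.randn(1, 64, 56, 56)
print(block(x).size())   # torch.Size([1, 256, 56, 56])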

The BasicBlock class is similar to Bottleneck; it is mainly used to build ResNet-18 and ResNet-34, whose residual structures contain only two convolution layers and have no bottleneck. Accordingly, the first convolution layer in this class uses kernel_size=3, as the conv3x3 function shows; a brief sketch follows the class below.

def conv3x3(in_planes, out_planes, stride=1):
    """3x3 convolution with padding"""
    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
                     padding=1, bias=False)

class BasicBlock(nn.Module):
    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(BasicBlock, self).__init__()
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = nn.BatchNorm2d(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = nn.BatchNorm2d(planes)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual
        out = self.relu(out)

        return out
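By contrast, BasicBlock has expansion = 1 and keeps the channel count, so with stride=1 and matching input/output channels the identity shortcut needs no downsample (a sketch):

import torch

block = BasicBlock(64, 64)      # stride=1, downsample=None
x = torch.randn(1, 64, 56, 56)
print(block(x).size())          # torch.Size([1, 64, 56, 56])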

Having covered how the network is built, the next question is how the pretrained model is obtained. Recall the line if pretrained: model.load_state_dict(model_zoo.load_url(model_urls['resnet50'])); it loads the corresponding pretrained model via the load_url function in model_zoo.py according to the model_urls dictionary. The model_zoo.py script is on GitHub at: https://github.com/pytorch/pytorch/blob/master/torch/utils/model_zoo.py
The source of load_url is below. model_dir is where the downloaded model is stored; if it is not specified, the model is saved under ~/.torch/models in your home directory (it is best to specify it explicitly). cached_file is the storage path plus the model file name. The if not os.path.exists(cached_file) check tests whether the model already exists in that directory: if it does, torch.load loads it directly; if not, it is downloaded via _download_url_to_file(url, cached_file, hash_prefix, progress=progress), which we won't go into. The key point is that the model is always loaded through the torch.load() interface, whether it was just downloaded or already present locally. A usage sketch with an explicit model_dir follows the function.

def load_url(url, model_dir=None, map_location=None, progress=True):
    r"""Loads the Torch serialized object at the given URL.

    If the object is already present in `model_dir`, it's deserialized and
    returned. The filename part of the URL should follow the naming convention
    ``filename-<sha256>.ext`` where ``<sha256>`` is the first eight or more
    digits of the SHA256 hash of the contents of the file. The hash is used to
    ensure unique names and to verify the contents of the file.

    The default value of `model_dir` is ``$TORCH_HOME/models`` where
    ``$TORCH_HOME`` defaults to ``~/.torch``. The default directory can be
    overriden with the ``$TORCH_MODEL_ZOO`` environment variable.

    Args:
        url (string): URL of the object to download
        model_dir (string, optional): directory in which to save the object
        map_location (optional): a function or a dict specifying how to remap storage locations (see torch.load)
        progress (bool, optional): whether or not to display a progress bar to stderr

    Example:
        >>> state_dict = torch.utils.model_zoo.load_url('https://s3.amazonaws.com/pytorch/models/resnet18-5c106cde.pth')

    """
    if model_dir is None:
        torch_home = os.path.expanduser(os.getenv('TORCH_HOME', '~/.torch'))
        model_dir = os.getenv('TORCH_MODEL_ZOO', os.path.join(torch_home, 'models'))
    if not os.path.exists(model_dir):
        os.makedirs(model_dir)
    parts = urlparse(url)
    filename = os.path.basename(parts.path)
    cached_file = os.path.join(model_dir, filename)
    if not os.path.exists(cached_file):
        sys.stderr.write('Downloading: "{}" to {}\n'.format(url, cached_file))
        hash_prefix = HASH_REGEX.search(filename).group(1)
        _download_url_to_file(url, cached_file, hash_prefix, progress=progress)
    return torch.load(cached_file, map_location=map_location)
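As noted above, it is usually better to pass model_dir explicitly. A minimal usage sketch (the ./pretrained directory name is just an example) of calling load_url directly and then loading the weights into a model:

import torchvision
import torch.utils.model_zoo as model_zoo

model = torchvision.models.resnet50(pretrained=False)
# Cache the checkpoint under ./pretrained instead of ~/.torch/models
state_dict = model_zoo.load_url(
    'https://download.pytorch.org/models/resnet50-19c8e357.pth',
    model_dir='./pretrained')
model.load_state_dict(state_dict)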