1. 程式人生 > >PyTorch-網路的建立,預訓練模型的載入

PyTorch-網路的建立,預訓練模型的載入

本文是PyTorch使用過程中的的一些總結,有以下內容:

  • 構建網路模型的方法
  • 網路層的遍歷
  • 各層引數的遍歷
  • 模型的儲存與載入
  • 從預訓練模型為網路引數賦值

主要涉及到以下函式的使用

  • add_module,ModulesList,Sequential 模型建立
  • modules(),named_modules(),children(),named_children() 訪問模型的各個子模組
  • parameters(),named_parameters() 網路引數的遍歷
  • save(),load()state_dict() 模型的儲存與載入

構建網路

torch.nn.Module是所有網路的基類,在Pytorch實現的Model都要繼承該類。而且,Module

是可以包含其他的Module的,以樹形的結構來表示一個網路結構。

簡單的定義一個網路Model

class Model(nn.Module):
    def __init__(self):
        super(Model,self).__init__()
        self.conv1 = nn.Conv2d(3,64,3)
        self.conv2 = nn.Conv2d(64,64,3)

    def forward(self,x):
        x = self.conv1(x)
        x = self.conv2(x)
        return x

Model中兩個屬性conv1conv2是兩個卷積層,在正向傳播的過程中,再依次呼叫這兩個卷積層。

除了使用Model的屬性來為網路新增層外,還可以使用add_module將網路層新增到網路中。

class Model(nn.Module):
    def __init__(self):
        super(Model,self).__init__()
        self.conv1 = nn.Conv2d(3,64,3)
        self.conv2 = nn.Conv2d(64,64,3)

        self.add_module("maxpool1",nn.MaxPool2d(2,2))
        self.add_module("covn3",nn.Conv2d(64,128,3))
        self.add_module("conv4",nn.Conv2d(128,128,3))

    def forward(self,x):
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.maxpool1(x)
        x = self.conv3(x)
        x = self.conv4(x)
        return x

add_module(name,layer)在正向傳播的過程中可以使用新增時的name來訪問改layer。

這樣一個個的新增layer,在簡單的網路中還行,但是對於負責的網路層很多的網路來說就需要敲很多重複的程式碼了。 這就需要使用到torch.nn.ModuleListtorch.nn.Sequential

使用ModuleListSequential可以方便新增子網路到網路中,但是這兩者還是有所不同的。

ModuleList

ModuleList是以list的形式儲存sub-modules或者網路層,這樣就可以先將網路需要的layer構建好儲存到一個list,然後通過ModuleList方法新增到網路中。

class MyModule(nn.Module):
    def __init__(self):
        super(MyModule,self).__init__()

        # 構建layer的list
        self.linears = nn.ModuleList([nn.Linear(10, 10) for i in range(10)])

    def forward(self,x):

        # 正向傳播,使用遍歷每個Layer
        for i, l in enumerate(self.linears):
            x = self.linears[i // 2](x) + l(x)

        return x

使用[nn.Linear(10, 10) for i in range(10)]構建要給Layer的list,然後使用ModuleList新增到網路中,在正向傳播的過程中,遍歷該list

更為方便的是,可以提前配置後,所需要的各個Layer的屬性,然後讀取配置建立list,然後使用ModuleList將配置好的網路層新增到網路中。 以VGG為例:

vgg_cfg = [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'C', 512, 512, 512, 'M',
           512, 512, 512, 'M']

def vgg(cfg, i, batch_norm=False):
    layers = []
    in_channels = i
    for v in cfg:
        if v == 'M':
            layers += [nn.MaxPool2d(kernel_size=2, stride=2)]
        elif v == 'C':
            layers += [nn.MaxPool2d(kernel_size=2, stride=2, ceil_mode=True)]
        else:
            conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=1)
            if batch_norm:
                layers += [conv2d, nn.BatchNorm2d(v), nn.ReLU(inplace=True)]
            else:
                layers += [conv2d, nn.ReLU(inplace=True)]
            in_channels = v
    return layers

class Model1(nn.Module):
    def __init__(self):
        super(Model1,self).__init__()

        self.vgg = nn.ModuleList(vgg(vgg_cfg,3))

    def forward(self,x):

        for l in self.vgg:
            x = l(x)
m1 = Model1()
print(m1)

讀取配置好的網路結構vgg_cfg然後,建立相應的Layer List,使用ModuleList加入到網路中。這樣就可以很靈活的建立不同的網路。

這裡需要注意的是,ModuleList是將Module加入網路中,需要自己手動的遍歷進行每一個Moduleforward

Sequential

一個時序容器。Modules 會以他們傳入的順序被新增到容器中。當然,也可以傳入一個OrderedDict一個時序容器。Modules 會以他們傳入的順序被新增到容器中。當然,也可以傳入一個OrderedDict
Sequential也是一次加入多個Module到網路中中,和ModuleList不同的是,它接受多個Module依次加入到網路中,還可以接受字典作為引數,例如:

# Example of using Sequential
        model = nn.Sequential(
                  nn.Conv2d(1,20,5),
                  nn.ReLU(),
                  nn.Conv2d(20,64,5),
                  nn.ReLU()
                )

# Example of using Sequential with OrderedDict
model = nn.Sequential(OrderedDict([
    ('conv1', nn.Conv2d(1,20,5)),
    ('relu1', nn.ReLU()),
    ('conv2', nn.Conv2d(20,64,5)),
    ('relu2', nn.ReLU())
    ]))

另一個是,Sequential中實現了新增Module的forward,不需要手動的迴圈呼叫了。這點相比ModuleList較為方便。

總結

常見的有三種方法來新增子Module到網路中

  • 單獨新增一個Module,可以使用屬性或者add_module方法。
  • ModuleList可以將一個Module的List加入到網路中,自由度較高,但是需要手動的遍歷ModuleList進行forward
  • Sequential按照順序將將Module加入到網路中,也可以處理字典。 相比於ModuleList不需要自己實現forward

遍歷網路結構

可以使用以下2對4個方法來訪問網路層所有的Modules

  • modules()named_modules()
  • children()named_children()

modules方法

簡單的定義一個如下網路:

class Model(nn.Module):
    def __init__(self):
        super(Model,self).__init__()
        self.conv1 = nn.Conv2d(in_channels=3,out_channels=64,kernel_size=3)
        self.conv2 = nn.Conv2d(64,64,3)
        self.maxpool1 = nn.MaxPool2d(2,2)

        self.features = nn.Sequential(OrderedDict([
            ('conv3', nn.Conv2d(64,128,3)),
            ('conv4', nn.Conv2d(128,128,3)),
            ('relu1', nn.ReLU())
        ]))

    def forward(self,x):
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.maxpool1(x)
        x = self.features(x)

        return x

modules()方法,返回一個包含當前模型所有模組的迭代器,這個是遞迴的返回網路中的所有Module。使用如下語句

    m = Model()
    for idx,m in enumerate(m.modules()):
        print(idx,"-",m)

其結果為:

0 - Model(
  (conv1): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1))
  (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1))
  (maxpool1): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (features): Sequential(
    (conv3): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1))
    (conv4): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1))
    (relu1): ReLU()
  )
)
1 - Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1))
2 - Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1))
3 - MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
4 - Sequential(
  (conv3): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1))
  (conv4): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1))
  (relu1): ReLU()
)
5 - Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1))
6 - Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1))
7 - ReLU()

輸出結果解析:

  • 0-Model 整個網路模組
  • 1-2-3-4 為網路的4個子模組,注意4 - Sequential仍然包含有子模組
  • 5-6-7為模組4 - Sequential的子模組

可以看出modules()是遞迴的返回網路的各個module,從最頂層直到最後的葉子module。

named_modules()的功能和modules()的功能類似,不同的是它返回內容有兩部分:module的名稱以及module。

children()方法

modules()不同,children()只返回當前模組的子模組,不會遞迴子模組。

    for idx,m in enumerate(m.children()):
        print(idx,"-",m)

其輸出為:

0 - Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1))
1 - Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1))
2 - MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
3 - Sequential(
  (conv3): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1))
  (conv4): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1))
  (relu1): ReLU()
)

子模組3-Sequential仍然有子模組,children()沒有遞迴的返回。
named_children()children()的功能類似,不同的是其返回兩部分內容:模組的名稱以及模組本身。

網路的引數

方法parameters()返回一個包含模型所有引數的迭代器。一般用來當作optimizer的引數。

    for p in m.parameters():
        print(type(p.data),p.size())

其輸出為:

<class 'torch.Tensor'> torch.Size([128, 64, 3, 3])
<class 'torch.Tensor'> torch.Size([128])
<class 'torch.Tensor'> torch.Size([128, 128, 3, 3])
<class 'torch.Tensor'> torch.Size([128])

包含網路中的所有的權值矩陣引數以及偏置引數。 對網路進行訓練時需要將parameters()作為優化器optimizer的引數。

optimizer = torch.optim.SGD(m1.parameters(),lr = args.lr,momentum=args.momentum,weight_decay=args.weight_decay)

parameters()返回網路的所有引數,主要是提供給optimizer用的。而要取得網路某一層的引數或者引數進行一些特殊的處理(如fine-tuning),則使用named_parameters()更為方便些。

named_parameters()返回引數的名稱及引數本身,可以按照引數名對一些引數進行處理。

以上面的vgg網路為例:

for k,v in m1.named_parameters():
    print(k,v.size())

named_parameters返回的是鍵值對,k為引數的名稱 ,v為引數本身。輸出結果為:

vgg.0.weight torch.Size([64, 3, 3, 3])
vgg.0.bias torch.Size([64])
vgg.2.weight torch.Size([64, 64, 3, 3])
vgg.2.bias torch.Size([64])
vgg.5.weight torch.Size([128, 64, 3, 3])
vgg.5.bias torch.Size([128])
vgg.7.weight torch.Size([128, 128, 3, 3])
vgg.7.bias torch.Size([128])
vgg.10.weight torch.Size([256, 128, 3, 3])
vgg.10.bias torch.Size([256])
vgg.12.weight torch.Size([256, 256, 3, 3])
vgg.12.bias torch.Size([256])
vgg.14.weight torch.Size([256, 256, 3, 3])
vgg.14.bias torch.Size([256])
vgg.17.weight torch.Size([512, 256, 3, 3])
vgg.17.bias torch.Size([512])
vgg.19.weight torch.Size([512, 512, 3, 3])
vgg.19.bias torch.Size([512])
vgg.21.weight torch.Size([512, 512, 3, 3])
vgg.21.bias torch.Size([512])
vgg.24.weight torch.Size([512, 512, 3, 3])
vgg.24.bias torch.Size([512])
vgg.26.weight torch.Size([512, 512, 3, 3])
vgg.26.bias torch.Size([512])
vgg.28.weight torch.Size([512, 512, 3, 3])
vgg.28.bias torch.Size([512])

引數名的命名規則屬性名稱.引數屬於的層的編號.weight/bias。 這在fine-tuning的時候,給一些特定的層的引數賦值是非常方便的,這點在後面在載入預訓練模型時會看到。

模型的儲存與載入

PyTorch使用torch.savetorch.load方法來儲存和載入網路,而且網路結構和引數可以分開的儲存和載入。

  • 儲存網路結構及其引數
torch.save(model,'model.pth') # 儲存
model = torch.load("model.pth") # 載入
  • 只加載模型引數,網路結構從程式碼中建立
torch.save(model.state_dict(),"model.pth") # 儲存引數
model = model() # 程式碼中建立網路結構
params = torch.load("model.pth") # 載入引數
model.load_state_dict(params) # 應用到網路結構中

載入預訓練模型

PyTorch中的torchvision裡有很多常用網路的預訓練模型,例如:vgg,resnet,googlenet等,可以方便的使用這些預訓練模型進行微調。

# PyTorch中的torchvision裡有很多常用的模型,可以直接呼叫:
import torchvision.models as models
 
resnet101 = models.resnet18(pretrained=True)
alexnet = models.alexnet()
squeezenet = models.squeezenet1_0()

有時候只需要載入預訓練模型的部分引數,可以使用引數名作為過濾條件,如下

resnet152 = models.resnet152(pretrained=True)
pretrained_dict = resnet152.state_dict()
"""載入torchvision中的預訓練模型和引數後通過state_dict()方法提取引數
   也可以直接從官方model_zoo下載:
   pretrained_dict = model_zoo.load_url(model_urls['resnet152'])"""
model_dict = model.state_dict()
# 將pretrained_dict裡不屬於model_dict的鍵剔除掉
pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
# 更新現有的model_dict
model_dict.update(pretrained_dict)
# 載入我們真正需要的state_dict
model.load_state_dict(model_dict)

model.state_dict()返回一個python的字典物件,將每一層與它的對應引數建立對映關係(如model的每一層的weights及偏置等等)。注意,只有有引數訓練的層才會被儲存。

上述的載入方式,是按照引數名類匹配過濾的,但是對於一些引數名稱無法完全匹配,或者在預訓練模型的基礎上新新增的一些層,這些層無法從預訓練模型中獲取引數,需要初始化。

仍然以上述的vgg為例,在標準的vgg16的特徵提取後面,新新增兩個卷積層,這兩個卷積層的引數需要進行初始化。

vgg = torch.load("vgg.pth") # 載入預訓練模型

for k,v in m1.vgg.named_parameters():
    k = "features.{}".format(k) # 引數名稱
    if k in vgg.keys():
        v.data = vgg[k].data # 直接載入預訓練引數
    else:
        if k.find("weight") >= 0:
            nn.init.xavier_normal_(v.data) # 沒有預訓練,則使用xavier初始化
        else:
            nn.init.constant_(v.data,0) # bias 初始化為0