1. 程式人生 > >【pytorch】模型的搭建儲存載入

【pytorch】模型的搭建儲存載入

使用pytorch進行網路模型的搭建、儲存與載入,是非常快速、方便的。

搭建ConvNet

所有的網路都要繼承torch.nn.Module,然後在建構函式中使用torch.nn中的提供的介面定義layer的屬性,最後,在forward函式中將各個layer連線起來。

下面,以LeNet為例:

class LeNet(nn.Module):
    def __init__(self):
        super(LeNet, self).__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1   = nn.Linear(16*5*5, 120)
        self.fc2   = nn.Linear(120, 84)
        self.fc3   = nn.Linear(84, 10)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = F.max_pool2d(x, 2)
        x = F.relu(self.conv2(x))
        x = F.max_pool2d(x, 2)
        x = x.view(x.size(0), -1)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        out = self.fc3(x)
        return out

這樣一來,我們就搭建好了網路模型,是不是很簡潔明瞭呢?此外,還可以使用torch.nn.Sequential,更方便進行模組化的定義,如下:

class LeNetSeq(nn.Module):
    def __init__(self):
        super(LeNetSeq, self).__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(3, 6, 5),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Conv2d(6, 16, 5),
            nn.ReLU(),
            nn.MaxPool2d(2),
        )

        self.fc = nn.Sequential(
            nn.Linear(16*5*5, 120),
            nn.ReLU(),
            nn.Linear(120, 84),
            nn.ReLU(),
            nn.Linear(84, 10)
        )
        
    def forward(self, x):
        x = self.conv(x)
        x = out.view(x.size(0), -1)
        out = self.fc(x)
        return out

Module有很多屬性,可以檢視權重、引數等等;如下:

net = lenet.LeNet()
print(net)

for param in net.parameters():
     print(type(param.data), param.size())
     print(list(param.data)) 

print(net.state_dict().keys())
#引數的keys

for key in net.state_dict():#模型引數
    print key, 'corresponds to', list(net.state_dict()[key])

那麼,如何進行引數初始化呢?使用 torch.nn.init ,如下:
def initNetParams(net):
    '''Init net parameters.'''
    for m in net.modules():
        if isinstance(m, nn.Conv2d):
            init.xavier_uniform(m.weight)
            if m.bias:
                init.constant(m.bias, 0)
        elif isinstance(m, nn.BatchNorm2d):
            init.constant(m.weight, 1)
            init.constant(m.bias, 0)
        elif isinstance(m, nn.Linear):
            init.normal(m.weight, std=1e-3)
            if m.bias:
                init.constant(m.bias, 0)

initNetParams(net)


儲存ConvNet

使用torch.save()對網路結構和模型引數的儲存,有兩種儲存方式:

  • 儲存整個神經網路的的結構資訊和模型引數資訊,save的物件是網路net;
  • 儲存神經網路的訓練模型引數,save的物件是net.state_dict()。
torch.save(net1, 'net.pkl')  # 儲存整個神經網路的結構和模型引數    
torch.save(net1.state_dict(), 'net_params.pkl') # 只儲存神經網路的模型引數    


載入ConvNet

對應上面兩種儲存方式,過載方式也有兩種。

  • 對應第一種完整網路結構資訊,過載的時候通過torch.load(‘.pth’)直接初始化新的神經網路物件即可。
  • 對應第二種只儲存模型引數資訊,需要首先匯入對應的網路,通過net.load_state_dict(torch.load('.pth'))完成模型引數的過載。

在網路比較大的時候,第一種方法會花費較多的時間,所佔的儲存空間也比較大。

# 儲存和載入整個模型  
torch.save(model_object, 'model.pth')  
model = torch.load('model.pth')  

# 僅儲存和載入模型引數  
torch.save(model_object.state_dict(), 'params.pth')  
model_object.load_state_dict(torch.load('params.pth'))