【pytorch】模型的搭建儲存載入
阿新 • • 發佈:2019-01-11
使用pytorch進行網路模型的搭建、儲存與載入,是非常快速、方便的。
那麼,如何進行引數初始化呢?使用 torch.nn.init ,如下:
搭建ConvNet
所有的網路都要繼承torch.nn.Module,然後在建構函式中使用torch.nn中的提供的介面定義layer的屬性,最後,在forward函式中將各個layer連線起來。
下面,以LeNet為例:
class LeNet(nn.Module): def __init__(self): super(LeNet, self).__init__() self.conv1 = nn.Conv2d(3, 6, 5) self.conv2 = nn.Conv2d(6, 16, 5) self.fc1 = nn.Linear(16*5*5, 120) self.fc2 = nn.Linear(120, 84) self.fc3 = nn.Linear(84, 10) def forward(self, x): x = F.relu(self.conv1(x)) x = F.max_pool2d(x, 2) x = F.relu(self.conv2(x)) x = F.max_pool2d(x, 2) x = x.view(x.size(0), -1) x = F.relu(self.fc1(x)) x = F.relu(self.fc2(x)) out = self.fc3(x) return out
這樣一來,我們就搭建好了網路模型,是不是很簡潔明瞭呢?此外,還可以使用torch.nn.Sequential,更方便進行模組化的定義,如下:
class LeNetSeq(nn.Module): def __init__(self): super(LeNetSeq, self).__init__() self.conv = nn.Sequential( nn.Conv2d(3, 6, 5), nn.ReLU(), nn.MaxPool2d(2), nn.Conv2d(6, 16, 5), nn.ReLU(), nn.MaxPool2d(2), ) self.fc = nn.Sequential( nn.Linear(16*5*5, 120), nn.ReLU(), nn.Linear(120, 84), nn.ReLU(), nn.Linear(84, 10) ) def forward(self, x): x = self.conv(x) x = out.view(x.size(0), -1) out = self.fc(x) return out
Module有很多屬性,可以檢視權重、引數等等;如下:
net = lenet.LeNet() print(net) for param in net.parameters(): print(type(param.data), param.size()) print(list(param.data)) print(net.state_dict().keys()) #引數的keys for key in net.state_dict():#模型引數 print key, 'corresponds to', list(net.state_dict()[key])
那麼,如何進行引數初始化呢?使用 torch.nn.init ,如下:
def initNetParams(net):
'''Init net parameters.'''
for m in net.modules():
if isinstance(m, nn.Conv2d):
init.xavier_uniform(m.weight)
if m.bias:
init.constant(m.bias, 0)
elif isinstance(m, nn.BatchNorm2d):
init.constant(m.weight, 1)
init.constant(m.bias, 0)
elif isinstance(m, nn.Linear):
init.normal(m.weight, std=1e-3)
if m.bias:
init.constant(m.bias, 0)
initNetParams(net)
儲存ConvNet
使用torch.save()對網路結構和模型引數的儲存,有兩種儲存方式:
- 儲存整個神經網路的的結構資訊和模型引數資訊,save的物件是網路net;
- 儲存神經網路的訓練模型引數,save的物件是net.state_dict()。
torch.save(net1, 'net.pkl') # 儲存整個神經網路的結構和模型引數
torch.save(net1.state_dict(), 'net_params.pkl') # 只儲存神經網路的模型引數
載入ConvNet
對應上面兩種儲存方式,過載方式也有兩種。
- 對應第一種完整網路結構資訊,過載的時候通過torch.load(‘.pth’)直接初始化新的神經網路物件即可。
- 對應第二種只儲存模型引數資訊,需要首先匯入對應的網路,通過net.load_state_dict(torch.load('.pth'))完成模型引數的過載。
在網路比較大的時候,第一種方法會花費較多的時間,所佔的儲存空間也比較大。
# 儲存和載入整個模型
torch.save(model_object, 'model.pth')
model = torch.load('model.pth')
# 僅儲存和載入模型引數
torch.save(model_object.state_dict(), 'params.pth')
model_object.load_state_dict(torch.load('params.pth'))