pytorch(四):神經網路的儲存與提取
阿新 • • 發佈:2019-02-04
將神經網路訓練好之後,如何儲存它呢,儲存它之後有如何提取它呢?
如下圖所示,net1是訓練好的神經網路,有兩種方式儲存它:1.儲存整個訓練好的神經網路,2.儲存神經網路的最終引數
net2是根據第1種方式儲存的。net2是根據第2種方式儲存的
原始碼:
# 引入模組 import torch import torch.nn.functional as f from torch.autograd import Variable import matplotlib.pyplot as plt # 生成一些假資料 x = torch.unsqueeze(torch.linspace(-1, 1, 100), dim=1) # 神經網路只能接受二維資料的輸入 y = pow(x, 2) + 0.2*torch.rand(x.size()) # 後半部分製造噪音 x, y = Variable(x), Variable(y) # 訓練神經網路時只能接受Variable形式輸入 # 定義儲存函式 def save(): net1 = torch.nn.Sequential( torch.nn.Linear(1, 10), torch.nn.ReLU(), torch.nn.Linear(10, 1) ) optimizer = torch.optim.SGD(net1.parameters(), lr=0.5) loss_func = torch.nn.MSELoss() for i in range(100): prediction = net1(x) # 喂資料x給net1 loss = loss_func(prediction, y) optimizer.zero_grad() # 將上面運算過程中的grad清零 loss.backward() # 誤差反向傳遞 optimizer.step() # 將新引數作用於神經網路 # 繪圖 plt.figure(figsize=(10, 3)) # 設定影象的大小 plt.subplot(131) plt.title('net1', color='red', size=20) plt.scatter(x.data.numpy(), y.data.numpy()) plt.plot(x.data.numpy(), prediction.data.numpy(), 'r-', lw=5) plt.text(0.3, 0, 'loss=%.4f' % loss, fontdict={'color': 'red', 'size': 10}) # 儲存net1的兩種方式 torch.save(net1, 'net1.pkl') # 方式1:儲存整個神經網路 torch.save(net1.state_dict(), 'net1_parameters.pkl') # 方式2:儲存神經網路的引數 # 定義提取整個神經網路的函式 def restore_net(): net2 = torch.load('net1.pkl') # 載入檔案net1.pkl, 將其內容賦值給net2 prediction = net2(x) loss_func = torch.nn.MSELoss() loss = loss_func(prediction, y) # 繪製net2結果圖形 plt.subplot(132) plt.title('net2', color='red', size=20) plt.scatter(x.data.numpy(), y.data.numpy()) plt.plot(x.data.numpy(), prediction.data.numpy(), 'r-', lw=5) plt.text(0.3, 0, 'loss=%.4f' % loss, fontdict={'color': 'red', 'size': 10}) # 定義提取神經網路狀態引數的函式 def restore_net_parameters(): net3 = torch.nn.Sequential( torch.nn.Linear(1, 10), torch.nn.ReLU(), torch.nn.Linear(10, 1) ) # 構造net3的基本框架 net3.load_state_dict(torch.load('net1_parameters.pkl')) # 提取net1的狀態引數,將狀態引數給net3 prediction = net3(x) loss_func = torch.nn.MSELoss() loss = loss_func(prediction, y) # 繪製net3結果圖形 plt.subplot(133) plt.title('net3', color='red', size=20) plt.scatter(x.data.numpy(), y.data.numpy()) plt.plot(x.data.numpy(), prediction.data.numpy(), 'r-', lw=5) plt.text(0.3, 0, 'loss=%.4f' % loss, fontdict={'color': 'red', 'size': 10}) # 呼叫函式 save() restore_net() restore_net_parameters() plt.show() # 將三個函式繪製的圖形顯示出來
注意:將plt.show()放置在最後,能顯示出三幅影象連在一起的。若在每個定義的函式的後面均加上plt.show(),三幅影象是分開顯示的,無法連成一個整體。