儲存和恢復神經網路
阿新 • • 發佈:2018-12-26
轉自莫煩大神,轉載原因是想把所有相關內容收集到自己的部落格中,方便系統的學習。
兩種儲存方法,1是儲存整個神經網路;2是隻儲存神經網路的所有引數。
一、儲存神經網路
1儲存整個神經網路。
torch.save(net1,"net1.pkl")
net1為我想要儲存的網路,net1.pkl為檔名,儲存的格式只能是.pkl
2,儲存神經網路引數
torch.save(net1.state_dict(),"net1_parmaer.pkl")
二、恢復神經網路
1恢復完整神經網路(直接load())
net2=torch.load("net1.pkl")
2.從引數中恢復神經網路
需先構建與所要恢復的神經網路相同結構,再load引數。
3,完整程式如下
import torch import matplotlib.pyplot as plt x = torch.unsqueeze(torch.linspace(-1, 1, 100), dim=1) # x data (tensor), shape=(100, 1) y = x.pow(2) + 0.2*torch.rand(x.size()) # noisy y data (tensor), shape=(100, 1) def save(): net1 = torch.nn.Sequential( torch.nn.Linear(1, 10), # 一層神經層 torch.nn.ReLU(), # 加激勵函式,relu相當於類 torch.nn.Linear(10, 1), ) optimizer=torch.optim.SGD(net1.parameters(),lr=0.5) loss_func=torch.nn.MSELoss() for t in range(100): prediction=net1(x) loss=loss_func(prediction,y) optimizer.zero_grad() loss.backward() optimizer.step()#畫圖 plt.figure(1, figsize=(10, 3)) plt.subplot(131) plt.title('Net1') plt.scatter(x.data.numpy(), y.data.numpy()) #實際資料 plt.plot(x.data.numpy(), prediction.data.numpy(), 'r-', lw=5) #迴歸曲線 torch.save(net1,"net1.pkl") #儲存整個神經網路 torch.save(net1.state_dict(),"net1_parmaer.pkl") #儲存神經網路中的所有引數 def restore_net(): net2=torch.load("net1.pkl") prediction2=net2(x) plt.subplot(132) plt.title('Net2') plt.scatter(x.data.numpy(), y.data.numpy()) #實際資料 plt.plot(x.data.numpy(), prediction2.data.numpy(), 'r-', lw=5) #迴歸曲線 def restore_paramers(): net3=torch.nn.Sequential( torch.nn.Linear(1, 10), # 一層神經層 torch.nn.ReLU(), # 加激勵函式,relu相當於類 torch.nn.Linear(10, 1), ) net3.load_state_dict(torch.load("net1_parmaer.pkl")) #先構建網路在,再載入引數 prediction3 = net3(x) plt.subplot(133) plt.title('Net3') plt.scatter(x.data.numpy(), y.data.numpy()) # 實際資料 plt.plot(x.data.numpy(), prediction3.data.numpy(), 'r-', lw=5) # 迴歸曲線 plt.show() save() restore_net()
執行結果: