
torch.save / torch.load: loading and saving models

https://pytorch123.com/ThirdSection/SaveModel/ is a very detailed reference on this topic!

1. Saving:

# save the entire network
torch.save(net, PATH)

# save only the network's parameters; faster and takes less disk space
torch.save(net.state_dict(), PATH)

#--------------------------------------------------

# For the two saving methods above, the corresponding loading methods are:

# load the entire network
model = torch.load(PATH)

# load only the parameters into an already constructed model
model.load_state_dict(torch.load(PATH))
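A minimal end-to-end sketch of both approaches, for illustration only (the tiny two-layer network and the file names here are assumptions, not part of the original post):

import torch
import torch.nn as nn

# illustrative network; in practice use your own model class
net = nn.Sequential(nn.Linear(10, 5), nn.ReLU(), nn.Linear(5, 2))

# approach 1: save / load the entire network object
torch.save(net, 'net_full.pth')
model = torch.load('net_full.pth')        # recent PyTorch versions may need weights_only=False here

# approach 2: save / load only the parameters (state_dict)
torch.save(net.state_dict(), 'net_params.pth')
model = nn.Sequential(nn.Linear(10, 5), nn.ReLU(), nn.Linear(5, 2))  # rebuild the same architecture first
model.load_state_dict(torch.load('net_params.pth'))
model.eval()                              # switch to eval mode before inference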

2. However, in experiments you often need to save more information, such as the optimizer's parameters. In that case you can use the approach shown below.
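For reference, a printout like the one below can be produced by iterating over the two state_dicts (a minimal sketch; model and optimizer are assumed to already be defined, e.g. a small CNN and an SGD optimizer matching the output shown):

print("Model's state_dict:")
for param_tensor in model.state_dict():
    print(param_tensor, "\t", model.state_dict()[param_tensor].size())

print("Optimizer's state_dict:")
for var_name in optimizer.state_dict():
    print(var_name, "\t", optimizer.state_dict()[var_name])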

 

Model's state_dict:
conv1.weight     torch.Size([6, 3, 5, 5])
conv1.bias       torch.Size([6])
conv2.weight     torch.Size([16, 6, 5, 5])
conv2.bias       torch.Size([16])
fc1.weight       torch.Size([120, 400])
fc1.bias         torch.Size([120])
fc2.weight       torch.Size([84, 120])
fc2.bias         torch.Size([84])
fc3.weight       torch.Size([10, 84])
fc3.bias         torch.Size([10])

Optimizer's state_dict:
state            {}
param_groups     [{'lr': 0.001, 'momentum': 0.9, 'dampening': 0, 'weight_decay': 0, 'nesterov': False, 'params': [4675713712, 4675713784, 4675714000, 4675714072, 4675714216, 4675714288, 4675714432, 4675714504, 4675714648, 4675714720]}]
torch.save({'epoch': epochID + 1,
            'state_dict': model.state_dict(),
            'best_loss': lossMIN,
            'optimizer': optimizer.state_dict(),
            'alpha': loss.alpha,
            'gamma': loss.gamma},
           checkpoint_path + '/m-' + launchTimestamp + '-' + str("%.4f" % lossMIN) + '.pth.tar')
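The snippet above only saves the checkpoint. A hedged sketch of loading it again to resume training could look like the following; the keys mirror the dictionary above, and loss is assumed to be the custom loss object whose alpha and gamma were saved:

checkpoint = torch.load(checkpoint_path + '/m-' + launchTimestamp + '-' + str("%.4f" % lossMIN) + '.pth.tar')

model.load_state_dict(checkpoint['state_dict'])
optimizer.load_state_dict(checkpoint['optimizer'])
start_epoch = checkpoint['epoch']           # continue training from the saved epoch
best_loss = checkpoint['best_loss']
loss.alpha = checkpoint['alpha']            # assumption: the custom loss exposes writable alpha / gamma
loss.gamma = checkpoint['gamma']

model.train()                               # or model.eval() if the checkpoint is only used for inference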