Pytorch儲存模型用於測試和用於繼續訓練的區別詳解
阿新 • • 發佈:2020-01-13
儲存模型
儲存模型僅僅是為了測試的時候,只需要
torch.save(model.state_dict,path)
path 為儲存的路徑
但是有時候模型及資料太多,難以一次性訓練完的時候,而且用的還是 Adam優化器的時候,一定要儲存好訓練的優化器引數以及epoch
state = { 'model': model.state_dict(),'optimizer':optimizer.state_dict(),'epoch': epoch } torch.save(state,path)
因為這裡
def adjust_learning_rate(optimizer,epoch): lr_t = lr lr_t = lr_t * (0.3 ** (epoch // 2)) for param_group in optimizer.param_groups: param_group['lr'] = lr_t
學習率是根據epoch變化的,如果不儲存epoch的話,基本上每次都從epoch為0開始訓練,這樣學習率就相當於不變了!!
恢復模型
恢復模型只用於測試的時候,
model.load_state_dict(torch.load(path))
path為之前儲存模型時的路徑
但是如果是用於繼續訓練的話,
checkpoint = torch.load(path) model.load_state_dict(checkpoint['model']) optimizer.load_state_dict(checkpoint['optimizer']) start_epoch = checkpoint['epoch']+1
依次恢復出模型 優化器引數以及epoch
以上這篇Pytorch儲存模型用於測試和用於繼續訓練的區別詳解就是小編分享給大家的全部內容了,希望能給大家一個參考,也希望大家多多支援我們。