1. 程式人生 > >Pytorch:lr_schedule恢復訓練的注意事項

Pytorch:lr_schedule恢復訓練的注意事項

  在訓練過程中我們一般會使用pytorch已有的學習率調整策略,如:

import torch
import torch.optim as optim
from torchvision.models.resnet import resnet50
net = resnet50(num_classes=1000)
optimizer = optim.Adam(net.parameters(), lr=1e-3)
scheduler = optim.lr_scheduler.MultiStepLR(optimizer, [20, 30, 40, 50], 0.1)
for epoch in range(num_epoch):
    scheduler.step()
    train()
    valid()
    ...

  有時候會因為不可抗拒的外界因素導致訓練被中斷,在pytorch中恢復訓練的方法就是把最近儲存的模型重新載入,然後重新訓練即可。假設我們從epoch10開始恢復訓練,可以利用lr_scheduler的一個引數:

scheduler = optim.lr_scheduler.MultiStepLR(optimizer, [20, 30, 40, 50], 0.1, last_epoch=10)

  這樣就不需要手動地去改[20, 30, 40, 50]->[10, 20, 30, 40] 。需要注意的是在optimizer定義引數組的時候需要加入’initial_lr’,不然會報錯:

"param 'initial_lr' is not specified in param_groups[*] when resuming an optimizer"

舉個粟子:

import torch
import torch.optim as optim
from torchvision.models.resnet import resnet50
net = torch.load('resnet50_epoch10.pth')
optimizer = optim.Adam([{'params': net.parameters(), 'initial_lr': 1e-3}], lr=1e-3)
scheduler = optim.lr_scheduler.MultiStepLR(optimizer, [20, 30, 40, 50], 0.1, last_epoch=10)
for epoch in range(11, num_epoch):
    scheduler.step()
    train()
    valid()
    ...