pytorch學習筆記(十):learning rate decay(學習率衰減)
阿新 • • 發佈:2019-01-25
pytorch learning rate decay
本文主要是介紹在pytorch
中如何使用learning rate decay
.
先上程式碼:
def adjust_learning_rate(optimizer, decay_rate=.9):
for param_group in optimizer.param_groups:
param_group['lr'] = param_group['lr'] * decay_rate
什麼是param_groups
?
optimizer
通過param_group
來管理引數組.param_group
中儲存了引數組及其對應的學習率,動量等等.所以我們可以通過更改param_group['lr']
# 有兩個`param_group`即,len(optim.param_groups)==2
optim.SGD([
{'params': model.base.parameters()},
{'params': model.classifier.parameters(), 'lr': 1e-3}
], lr=1e-2, momentum=0.9)
#一個引數組
optim.SGD(model.parameters(), lr=1e-2, momentum=.9)