1. 程式人生 > >pytorch學習筆記(十):learning rate decay(學習率衰減)

pytorch學習筆記(十):learning rate decay(學習率衰減)

pytorch learning rate decay

本文主要是介紹在pytorch中如何使用learning rate decay.
先上程式碼:


def adjust_learning_rate(optimizer, decay_rate=.9):
    for param_group in optimizer.param_groups:
        param_group['lr'] = param_group['lr'] * decay_rate

什麼是param_groups?
optimizer通過param_group來管理引數組.param_group中儲存了引數組及其對應的學習率,動量等等.所以我們可以通過更改param_group['lr']

的值來更改對應引數組的學習率.

# 有兩個`param_group`即,len(optim.param_groups)==2
optim.SGD([
                {'params': model.base.parameters()},
                {'params': model.classifier.parameters(), 'lr': 1e-3}
            ], lr=1e-2, momentum=0.9)

#一個引數組
optim.SGD(model.parameters(), lr=1e-2, momentum=.9)

參考資料