Pytorch學習筆記之優化器的使用
阿新 • • 發佈:2021-01-30
torch.optim提供了不同演算法實現的優化器,在模型訓練時用於更新模型引數。torch.optim.Optimizer為基類,所有的優化器都是該類的子類。優化器使用比較簡單,以torch.optim.SGD為例,
import torch.nn as nn
import torch.optim as optim
model = nn.Sequential(OrderedDict({'linear_1' : nn.Linear(10,30),
'tanh':nn.Tanh(),
'linear_2' : nn.Linear(30,5),
'sigmod': nn.Sigmoid()}))
optimization = optim.SGD(model.parameters(),lr=0.01)
這樣模型每一個引數的學習率都是0.01. 可以通過params_groups屬性或者state_dict()成員函式訪問。
import torch.nn as nn
import torch.optim as optim
print(optimization.params_groups)
print(optimization. state_dict()['params_groups'])
此外還可以為每一個子模組單獨設定學習率。方法則是構造一個字典,並新增‘params’, 'lr’等鍵。也可以新增自己自定義的鍵名。
import torch.nn as nn
import torch.optim as optim
optim_setting = [{'params': model.linear_1.parameters(),'lr':0.01,'name':'linear_1'},
{'params': model.linear_2.parameters(),'lr':0.001, 'name':'linear_2'} ]
optimization = optim.SGD(optim_setting,lr=0.1)
print(model.param_groups)
裡面’name’是自定義的鍵名,並且我們可以看到每一層子模組有了自己的學習率,其他模組則是預設的0.1的學習率。
優化器的儲存和載入也和模型一樣,通過state_dict()和load_state_dict()函式實現。