1. 程式人生 > 其它 >Pytorch學習筆記之優化器的使用

Pytorch學習筆記之優化器的使用

技術標籤:PytorchPytorch

torch.optim提供了不同演算法實現的優化器,在模型訓練時用於更新模型引數。torch.optim.Optimizer為基類,所有的優化器都是該類的子類。優化器使用比較簡單,以torch.optim.SGD為例,

import torch.nn as nn
import torch.optim as optim
model = nn.Sequential(OrderedDict({'linear_1' : nn.Linear(10,30),
                            'tanh':nn.Tanh(),
                            'linear_2'
: nn.Linear(30,5), 'sigmod': nn.Sigmoid()})) optimization = optim.SGD(model.parameters(),lr=0.01)

這樣模型每一個引數的學習率都是0.01. 可以通過params_groups屬性或者state_dict()成員函式訪問。

import torch.nn as nn
import torch.optim as optim

print(optimization.params_groups)
print(optimization.
state_dict()['params_groups'])

此外還可以為每一個子模組單獨設定學習率。方法則是構造一個字典,並新增‘params’, 'lr’等鍵。也可以新增自己自定義的鍵名。

import torch.nn as nn
import torch.optim as optim

optim_setting = [{'params': model.linear_1.parameters(),'lr':0.01,'name':'linear_1'},
                    {'params': model.linear_2.parameters(),'lr':0.001,
'name':'linear_2'} ] optimization = optim.SGD(optim_setting,lr=0.1) print(model.param_groups)

裡面’name’是自定義的鍵名,並且我們可以看到每一層子模組有了自己的學習率,其他模組則是預設的0.1的學習率。

優化器的儲存和載入也和模型一樣,通過state_dict()和load_state_dict()函式實現。