pytorch中檢視可訓練引數
阿新 • • 發佈:2018-12-15
pytorch中我們有時候可能需要設定某些變數是參與訓練的,這時候就需要檢視哪些是可訓練引數,以確定這些設定是成功的。
pytorch中model.parameters()
函式定義如下:
def parameters(self): r"""Returns an iterator over module parameters. This is typically passed to an optimizer. Yields: Parameter: module parameter Example:: >>> for param in model.parameters(): >>> print(type(param.data), param.size()) <class 'torch.FloatTensor'> (20L,) <class 'torch.FloatTensor'> (20L, 1L, 5L, 5L) """ for name, param in self.named_parameters(): yield param
所以,我們可以遍歷named_parameters()
中的所有的引數,只打印那些param.requires_grad=True
的變數。具體實現程式碼如下所示:
for name, param in model.named_parameters():
if param.requires_grad:
print(name)
這樣打印出的結果就是模型中所有的可訓練引數列表!