1. 程式人生 > >pytorch中檢視可訓練引數

pytorch中檢視可訓練引數

  pytorch中我們有時候可能需要設定某些變數是參與訓練的,這時候就需要檢視哪些是可訓練引數,以確定這些設定是成功的。

  pytorch中model.parameters()函式定義如下:

    def parameters(self):
        r"""Returns an iterator over module parameters.

        This is typically passed to an optimizer.

        Yields:
            Parameter: module parameter

        Example::

            >>> for param in model.parameters():
            >>>     print(type(param.data), param.size())
            <class 'torch.FloatTensor'> (20L,)
            <class 'torch.FloatTensor'> (20L, 1L, 5L, 5L)

        """
        for name, param in self.named_parameters():
            yield param

所以,我們可以遍歷named_parameters()中的所有的引數,只打印那些param.requires_grad=True的變數。具體實現程式碼如下所示:

for name, param in model.named_parameters():
    if param.requires_grad:
        print(name)

這樣打印出的結果就是模型中所有的可訓練引數列表!