1. 程式人生 > 程式設計 >pytorch 實現列印模型的引數值

pytorch 實現列印模型的引數值

對於簡單的網路

例如全連線層Linear

可以使用以下方法列印linear層:

fc = nn.Linear(3,5)
params = list(fc.named_parameters())
print(params.__len__())
print(params[0])
print(params[1])

輸出如下:

由於Linear預設是偏置bias的,所有引數列表的長度是2。第一個存的是全連線矩陣,第二個存的是偏置。

對於稍微複雜的網路

例如MLP

mlp = nn.Sequential(
      nn.Dropout(p=0.3),nn.Linear(1024,256),nn.Linear(256,64),nn.Linear(64,16),nn.Linear(16,1)
    )
params = list(mlp.named_parameters())
print(params.__len__())

print(params[0])
print(params[1])

print(params[2])
print(params[3])

輸出:

可以發現,堆疊起來的網路,引數是依次放置的。先是全連線的權重,然後偏置。然後是下一層網路的權重+偏置。依次進行下去。

這裡有4層fc,4*2=8.所以一共有8個引數矩陣。

以上這篇pytorch 實現列印模型的引數值就是小編分享給大家的全部內容了,希望能給大家一個參考,也希望大家多多支援我們。