pytorch 實現列印模型的引數值
阿新 • • 發佈:2020-01-09
對於簡單的網路
例如全連線層Linear
可以使用以下方法列印linear層:
fc = nn.Linear(3,5) params = list(fc.named_parameters()) print(params.__len__()) print(params[0]) print(params[1])
輸出如下:
由於Linear預設是偏置bias的,所有引數列表的長度是2。第一個存的是全連線矩陣,第二個存的是偏置。
對於稍微複雜的網路
例如MLP
mlp = nn.Sequential( nn.Dropout(p=0.3),nn.Linear(1024,256),nn.Linear(256,64),nn.Linear(64,16),nn.Linear(16,1) ) params = list(mlp.named_parameters()) print(params.__len__()) print(params[0]) print(params[1]) print(params[2]) print(params[3])
輸出:
可以發現,堆疊起來的網路,引數是依次放置的。先是全連線的權重,然後偏置。然後是下一層網路的權重+偏置。依次進行下去。
這裡有4層fc,4*2=8.所以一共有8個引數矩陣。
以上這篇pytorch 實現列印模型的引數值就是小編分享給大家的全部內容了,希望能給大家一個參考,也希望大家多多支援我們。