pytorch 求網路模型引數例項
阿新 • • 發佈:2020-01-09
用pytorch訓練一個神經網路時,我們通常會很關心模型的引數總量。下面分別介紹來兩種方法求模型引數
一 .求得每一層的模型引數,然後自然的可以計算出總的引數。
1.先初始化一個網路模型model
比如我這裡是 model=cliqueNet(裡面是些初始化的引數)
2.呼叫model的Parameters類獲取引數列表
一個典型的操作就是將引數列表傳入優化器裡。如下
optimizer = optim.Adam(model.parameters(),lr=opt.lr)
言歸正傳,繼續回到引數裡面,引數在網路裡面就是variable,下面分別求每層的尺寸大小和個數。
函式get_number_of_param( ) 裡面的引數就是剛才第一步初始化的model
def get_number_of_param(model): """get the number of param for every element""" count = 0 for param in model.parameters(): param_size = param.size() count_of_one_param = 1 for dis in param_size: count_of_one_param *= dis print(param.size(),count_of_one_param) count += count_of_one_param print(count) print('total number of the model is %d'%count)
再來看看結果:
torch.Size([64,1,3,3]) 576 576 torch.Size([64]) 64 640 torch.Size([6,36,64,3]) 124416 125056 torch.Size([30,3]) 349920 474976 torch.Size([12,36]) 432 475408 torch.Size([6,216,3]) 419904 895312 torch.Size([30,3]) 349920 1245232 torch.Size([12,36]) 432 1245664 torch.Size([6,3]) 419904 1665568 torch.Size([30,3]) 349920 2015488 torch.Size([12,36]) 432 2015920 torch.Size([6,3]) 419904 2435824 torch.Size([30,3]) 349920 2785744 torch.Size([12,36]) 432 2786176 torch.Size([216,1]) 46656 2832832 torch.Size([216]) 216 2833048 torch.Size([108,216]) 23328 2856376 torch.Size([108]) 108 2856484 torch.Size([216,108]) 23328 2879812 torch.Size([216]) 216 2880028 torch.Size([216,1]) 46656 2926684 torch.Size([216]) 216 2926900 torch.Size([108,216]) 23328 2950228 torch.Size([108]) 108 2950336 torch.Size([216,108]) 23328 2973664 torch.Size([216]) 216 2973880 torch.Size([216,1]) 46656 3020536 torch.Size([216]) 216 3020752 torch.Size([108,216]) 23328 3044080 torch.Size([108]) 108 3044188 torch.Size([216,108]) 23328 3067516 torch.Size([216]) 216 3067732 torch.Size([140,280,1]) 39200 3106932 torch.Size([140]) 140 3107072 torch.Size([216,432,1]) 93312 3200384 torch.Size([216]) 216 3200600 torch.Size([216,1]) 93312 3293912 torch.Size([216]) 216 3294128 torch.Size([9,572,3]) 46332 3340460 torch.Size([9]) 9 3340469 total number of the model is 3340469
可以通過計算驗證一下,發現引數與網路是一致的。
二:一行程式碼就可以搞定引數總個數問題
2.1 先來看看torch.tensor.numel( )這個函式的功能就是求tensor中的元素個數,在網路裡面每層引數就是多維陣列組成的tensor。
實際上就是求多維陣列的元素個數。看程式碼。
print('cliqueNet parameters:',sum(param.numel() for param in model.parameters()))
當然上面程式碼中的model還是上面初始化的網路模型。
看看兩種的計算結果
torch.Size([64,3]) 46332 3340460 torch.Size([9]) 9 3340469 total number of the model is 3340469 cliqueNet parameters: 3340469
可以看出兩種計算出來的是一模一樣的。
以上這篇pytorch 求網路模型引數例項就是小編分享給大家的全部內容了,希望能給大家一個參考,也希望大家多多支援我們。