1. 程式人生 > 其它 >pytorch獲得模型的引數資訊,所佔記憶體的大小

pytorch獲得模型的引數資訊,所佔記憶體的大小

一 sum

一個模型所佔的視訊記憶體無非是這兩種:

  • 模型權重引數
  • 模型所儲存的中間變數

其實權重引數一般來說並不會佔用很多的視訊記憶體空間,主要佔用視訊記憶體空間的還是計算時產生的中間變數,當我們定義了一個model之後,我們可以通過以下程式碼簡單計算出這個模型權重引數所佔用的資料量:

import numpy as np

# model是我們在pytorch定義的神經網路層
# model.parameters()取出這個model所有的權重引數
para = sum([np.prod(list(p.size())) for p in model.parameters()])
#
下面的type_size是4,因為我們的引數是float32也就是4B,4個位元組 print('Model {} : params: {:4f}M'.format(model._get_name(), para * type_size / 1000 / 1000))

對上述含義的說明:https://oldpan.me/archives/how-to-use-memory-pytorch

二torchsummary

1.pip install torchsummary安裝

2.

import torch
from torchsummary import summary

# 需要使用device來指定網路在GPU還是CPU執行
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') netG_A2B = Generator(3, 3).to(device) # 這裡需要網路,Generator報錯 summary(netG_A2B, input_size=(3, 256, 256))

Total params這三項比較好理解,因為有可能固定param。
input size也比較好理解:3*256*256/1024/1024*4=0.75(最後一個4表示儲存是需要4位元組,float32型別)
Params size也比較好計算:138,357,544/1024/1024*4=527.79
Forward/backward pass size (MB)的計算:(10*24*24+20*8*8+20*8*8+50+10)/1024/1024*4*2=0.064

(注意最有還有個2)
https://blog.csdn.net/csdnxiekai/article/details/110517751

三pytorch-model-summary

1.pip install pytorch-model-summary

2.

# show input shape
print(summary(Net(), torch.zeros((1, 1, 28, 28)), show_input=True))

# show output shape
print(summary(Net(), torch.zeros((1, 1, 28, 28)), show_input=False))

# show output shape and hierarchical view of net
print(summary(Net(), torch.zeros((1, 1, 28, 28)), show_input=False, show_hierarchical=True))
原文連結:https://blog.csdn.net/csdnxiekai/article/details/110517751