1. 程式人生 > 其它 >【Pytorch】快速輸出網路模型資訊方法

【Pytorch】快速輸出網路模型資訊方法

技術標籤:PytorchPython深度學習

首先,安裝torchsummary

pip isntall torchsummary

下面是demo程式碼,其中(3, 100, 100)分別為單張圖片的通道數、高、寬

from torchsummary import summary
import torchvision
import torch

def model_info(model):
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    backbone = model.to(device)
summary(backbone, (3, 100, 100)) model = torchvision.models.vgg16(pretrained=False) model_info(model)

輸出結果如下,13個Conv2d(卷積層),3個Linear(全連線層),一共16層,vgg16沒毛病

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
==============
================================================== Conv2d-1 [-1, 64, 100, 100] 1,792 ReLU-2 [-1, 64, 100, 100] 0 Conv2d-3 [-1, 64, 100, 100] 36,928 ReLU-4 [-1, 64, 100, 100] 0 MaxPool2d-
5 [-1, 64, 50, 50] 0 Conv2d-6 [-1, 128, 50, 50] 73,856 ReLU-7 [-1, 128, 50, 50] 0 Conv2d-8 [-1, 128, 50, 50] 147,584 ReLU-9 [-1, 128, 50, 50] 0 MaxPool2d-10 [-1, 128, 25, 25] 0 Conv2d-11 [-1, 256, 25, 25] 295,168 ReLU-12 [-1, 256, 25, 25] 0 Conv2d-13 [-1, 256, 25, 25] 590,080 ReLU-14 [-1, 256, 25, 25] 0 Conv2d-15 [-1, 256, 25, 25] 590,080 ReLU-16 [-1, 256, 25, 25] 0 MaxPool2d-17 [-1, 256, 12, 12] 0 Conv2d-18 [-1, 512, 12, 12] 1,180,160 ReLU-19 [-1, 512, 12, 12] 0 Conv2d-20 [-1, 512, 12, 12] 2,359,808 ReLU-21 [-1, 512, 12, 12] 0 Conv2d-22 [-1, 512, 12, 12] 2,359,808 ReLU-23 [-1, 512, 12, 12] 0 MaxPool2d-24 [-1, 512, 6, 6] 0 Conv2d-25 [-1, 512, 6, 6] 2,359,808 ReLU-26 [-1, 512, 6, 6] 0 Conv2d-27 [-1, 512, 6, 6] 2,359,808 ReLU-28 [-1, 512, 6, 6] 0 Conv2d-29 [-1, 512, 6, 6] 2,359,808 ReLU-30 [-1, 512, 6, 6] 0 MaxPool2d-31 [-1, 512, 3, 3] 0 AdaptiveAvgPool2d-32 [-1, 512, 7, 7] 0 Linear-33 [-1, 4096] 102,764,544 ReLU-34 [-1, 4096] 0 Dropout-35 [-1, 4096] 0 Linear-36 [-1, 4096] 16,781,312 ReLU-37 [-1, 4096] 0 Dropout-38 [-1, 4096] 0 Linear-39 [-1, 1000] 4,097,000 ================================================================ Total params: 138,357,544 Trainable params: 138,357,544 Non-trainable params: 0 ---------------------------------------------------------------- Input size (MB): 0.11 Forward/backward pass size (MB): 43.51 Params size (MB): 527.79 Estimated Total Size (MB): 571.42 ----------------------------------------------------------------