1. 程式人生 > 其它 >pytorch 模型引數量、FLOPs統計方法

pytorch 模型引數量、FLOPs統計方法

技術標籤:pytorch

文章目錄

一、使用第三方工具:

torchstat:

安裝:pip install torchstat
torchstat GitHub 原始碼頁面

例子:

from torchstat import stat
from backbone import EfficientDetBackbone
model = EfficientDetBackbone()
stat(model, (3, 1280, 1280))

輸出:會輸出模型各層網路的資訊,最後進行總結統計。

在這裡插入圖片描述
在這裡插入圖片描述

ptflops:

安裝:pip install ptflops
ptflops GithHub原始碼頁面

例子:

import torchvision.models as models
import torch
from ptflops import get_model_complexity_info

with torch.cuda.device(0):
  net = models.densenet161()
  macs, params = get_model_complexity_info(net, (3, 224, 224), as_strings=True,
                                           print_per_layer_stat=
True, verbose=True) print('{:<30} {:<8}'.format('Computational complexity: ', macs)) print('{:<30} {:<8}'.format('Number of parameters: ', params))

輸出:同樣會輸出模型各層的資訊,最後總結統計 引數量 和 FLOPs。

注意: 使用第三方工具時, 網路中有些層可能會不支援計算。

其他工具:

  • torchsummary
  • thop

二、使用函式統計模型引數量:

計算模型引數量 與 可訓練引數量:

def get_parameter_number
(model): total_num = sum(p.numel() for p in model.parameters()) trainable_num = sum(p.numel() for p in model.parameters() if p.requires_grad) return {'Total': total_num, 'Trainable': trainable_num} result = get_parameter_number(model) print(result['Total'],result['Trainable']) #列印引數量