pytorch 模型引數量、FLOPs統計方法
阿新 • • 發佈:2021-02-10
技術標籤:pytorch
文章目錄
一、使用第三方工具:
torchstat:
安裝:pip install torchstat
torchstat GitHub 原始碼頁面
例子:
from torchstat import stat
from backbone import EfficientDetBackbone
model = EfficientDetBackbone()
stat(model, (3, 1280, 1280))
輸出:會輸出模型各層網路的資訊,最後進行總結統計。
ptflops:
安裝:pip install ptflops
ptflops GithHub原始碼頁面
例子:
import torchvision.models as models
import torch
from ptflops import get_model_complexity_info
with torch.cuda.device(0):
net = models.densenet161()
macs, params = get_model_complexity_info(net, (3, 224, 224), as_strings=True,
print_per_layer_stat= True, verbose=True)
print('{:<30} {:<8}'.format('Computational complexity: ', macs))
print('{:<30} {:<8}'.format('Number of parameters: ', params))
輸出:同樣會輸出模型各層的資訊,最後總結統計 引數量 和 FLOPs。
注意: 使用第三方工具時, 網路中有些層可能會不支援計算。
其他工具:
- torchsummary
- thop
二、使用函式統計模型引數量:
計算模型引數量 與 可訓練引數量:
def get_parameter_number (model):
total_num = sum(p.numel() for p in model.parameters())
trainable_num = sum(p.numel() for p in model.parameters() if p.requires_grad)
return {'Total': total_num, 'Trainable': trainable_num}
result = get_parameter_number(model)
print(result['Total'],result['Trainable']) #列印引數量