1. 程式人生 > 其它 >計算模型FLOPs和引數量

計算模型FLOPs和引數量

pytorch環境下,有兩個計算FLOPs和引數量的包thop和ptflops,結果基本是一致的。

thop

參考https://github.com/Lyken17/pytorch-OpCounter

安裝方法:pip install thop

使用方法:

from torchvision.models import resnet18
from thop import profile
model = resnet18()
input = torch.randn(1, 3, 224, 224) #模型輸入的形狀,batch_size=1
flops, params = profile(model, inputs=(input, ))
print(flops/1e9,params/1e6) #flops單位G,para單位M

用來測試3d resnet18的FLOPs:

model =C3D_Hash_Model(48)
input = torch.randn(1, 3,10, 112, 112) #視訊取10幀
flops, params = profile(model, inputs=(input, ))
print(flops/1e9,params/1e6)

ptflops
參考https://github.com/sovrasov/flops-counter.pytorch

安裝方法:pip install ptflops

或者 pip install git+https://github.com/sovrasov/flops-counter.pytorch.git

import torchvision.models as models
import torch
from ptflops import get_model_complexity_info

with torch.cuda.device(0):
  net = models.resnet18()
  flops, params = get_model_complexity_info(net, (3, 224, 224), as_strings=True, 		                        print_per_layer_stat=True) #不用寫batch_size大小,預設batch_size=1
  print('Flops:  ' + flops)
  print('Params: ' + params)

用來測試3d resnet18的FLOPs:

import torch
from ptflops.flops_counter import get_model_complexity_info
with torch.cuda.device(0):
    net = C3D_Hash_Model(48)
    flops, params = get_model_complexity_info(net, (3,10, 112, 112), as_strings=True, 	                print_per_layer_stat=True)
    print('Flops:  ' + flops)
    print('Params: ' + params)

如果安裝ptflops出問題,可以直接到https://github.com/sovrasov/flops-counter.pytorch.git下載程式碼,然後直接把目錄ptflops複製到專案程式碼中,通過from ptflops.flops_counter import get_model_complexity_info來呼叫函式計算FLOPs

python計時程式執行時間

import time

time_start=time.time()
#在這裡執行模型
time_end=time.time()
print('totally cost',time_end-time_start)