Python 統計網路的運算量（FLOPs）和參數量（params）
阿新 • • 發佈:2018-12-08
# -*- coding: utf-8 -*-
"""Report the compute cost and parameter count of a Caffe network.

Usage: python <script> [deploy.prototxt] [weights.caffemodel]
"""
import os
import sys
import shutil
import struct

from google.protobuf import text_format

import caffe
from caffe.proto import caffe_pb2

# Layer types that are counted by CalFlops / CalParams.  For both types the
# weight blob is the first entry of ``net.params[layer_name]``.
LAYER_PARAM = {'Convolution', 'InnerProduct'}


def _product(values):
    """Return the product of an iterable of ints (1 for an empty iterable)."""
    result = 1
    for value in values:
        result *= value
    return result


class CalFlop():
    """Load a Caffe net and report its compute cost and weight count.

    Args:
        model: path to the trained ``.caffemodel`` weights.
        deploy: path to the ``deploy.prototxt`` network definition.
    """

    def __init__(self, model, deploy):
        self.model = model
        self.deploy = deploy
        self.net = caffe.Net(deploy, model, caffe.TEST)
        # The transformer is not used by the counters; it is kept for callers
        # that may rely on the attribute.  Build it only when an input blob
        # named 'data' exists, so nets with other input names still load.
        self.transformer = None
        if 'data' in self.net.blobs:
            self.transformer = caffe.io.Transformer(
                {'data': self.net.blobs['data'].data.shape})
            self.transformer.set_transpose('data', (2, 0, 1))
        # Parse the prototxt as well, to get layer types and top-blob names.
        self.netlist = caffe_pb2.NetParameter()
        with open(deploy) as proto_file:
            text_format.Merge(proto_file.read(), self.netlist)

    def GetLayerList(self):
        """Return the layer names in prototxt order, printing each layer."""
        layer_names = []
        for layer in self.netlist.layer:
            layer_names.append(layer.name)
            print(layer)
        return layer_names

    def CalFlops(self):
        """Print and return the multiply-accumulate count for one input.

        For every Convolution / InnerProduct layer the cost is the size of
        the layer's weight blob times the spatial size of its first top blob,
        i.e. the product of the top shape after the (N, C) axes.  That
        spatial size is 1 for a 2-D InnerProduct output.  Biases and the
        batch dimension are not included, so the figure is MACs per input
        rather than 2 * MACs.

        Returns:
            The total count (also printed as "Net FLOPS is ...").
        """
        total = 0
        for layer in self.netlist.layer:
            if layer.type not in LAYER_PARAM:
                continue
            weights = self.net.params[layer.name][0]
            # Look the output up by the declared top blob: a layer's name does
            # not have to match the name of the blob it produces.
            top_shape = self.net.blobs[layer.top[0]].data.shape
            total += weights.data.size * _product(top_shape[2:])
        print("Net FLOPS is {}".format(total))
        return total

    def CalParams(self, include_bias=False):
        """Print and return the number of learned weights.

        Sizes are read from the loaded weight blobs rather than recomputed
        from the prototxt.  This keeps the count correct for grouped or
        depthwise convolutions and for kernels given as kernel_h/kernel_w.

        Args:
            include_bias: when True, also count every additional parameter
                blob of those layers (normally the bias).  Defaults to False,
                which keeps the weights-only count.

        Returns:
            The total count (also printed as "params is ...").
        """
        params = 0
        for layer in self.netlist.layer:
            if layer.type not in LAYER_PARAM:
                continue
            blobs = self.net.params[layer.name]
            counted = blobs if include_bias else [blobs[0]]
            params += sum(blob.data.size for blob in counted)
        print("params is {}".format(params))
        return params


if __name__ == '__main__':
    # Optional overrides: python <script> [deploy.prototxt] [weights.caffemodel]
    MODEL_FILE = (sys.argv[1] if len(sys.argv) > 1
                  else r'/home/ssd/deploy.prototxt')
    PRETRAINED = (sys.argv[2] if len(sys.argv) > 2
                  else r'/home/ssd/VGG_SSD_300x300.caffemodel')
    trans = CalFlop(PRETRAINED, MODEL_FILE)
    trans.CalFlops()
    trans.CalParams()
寫了個統計模型運算量和參數量的腳本，主要用於在模型優化後更直觀地比較效果。已在 VGG 和 VGG-SSD 上測試可用，其他模型尚未測試，可能有 bug，歡迎指出。注意：腳本中的 FLOPs 實際統計的是單張輸入的乘加次數（MACs），未計入偏置與 batch 大小。
轉載請註明出處。