卷積引數量和計算量(FLOPs)的計算公式及程式碼
我們經常用引數量和浮點計算數FLOPs來衡量卷積網路的複雜度。下面推導其公式並在pytorch中實現,以二維卷積Conv2d為例。
1、公式
以下公式適用於各種情況的卷積層,如普通卷積、膨脹卷積、分組卷積、可分離卷積等:
引數量:
無bias時:
k
H
×
k
W
×
C
i
n
/
g
×
C
o
u
t
k_{H }×k_{W }×C_{in}/g×C_{out}
kH×kW×Cin/g×Cout
有bias時:
(
k
H
×
k
W
×
C
i
n
/
g
+
1
)
×
C
o
u
t
(k_{H }×k_{W }×C_{in}/g + 1)×C_{out}
浮點計算數FLOPs(指所有的乘法和加法運算):
無bias時:
(
2
×
k
H
×
k
W
×
C
i
n
/
g
−
1
)
×
C
o
u
t
×
H
o
u
t
×
W
o
u
t
(2×k_{H }×k_{W }×C_{in}/g - 1)×C_{out}×H_{out}×W_{out}
(2×kH×kW×Cin/g
有bias時:
2
×
k
H
×
k
W
×
C
i
n
/
g
×
C
o
u
t
×
H
o
u
t
×
W
o
u
t
2×k_{H }×k_{W }×C_{in}/g×C_{out}×H_{out}×W_{out}
2×kH×kW×Cin/g×Cout×Hout×Wout
式中各量對應的pytorch卷積層的配置引數如下:
對於可分離卷積,網上有的文章也單拎出來對其公式進行了推導,其實可分離卷積就是一個帶groups引數的分組卷積加上一個1*1全卷積,不需要額外再推導,兩部分都用上式計算之後再加起來就可以了。
推導過程也很簡單。根據卷積的原理,
(1)卷積核的大小是
k
H
×
k
W
×
C
i
n
k_{H }×k_{W }×C_{in}
kH×kW×Cin,卷積核的數量就是輸出通道數
C
o
u
t
C_{out }
Cout,無bias時,所有卷積核的元素數量就是引數量,把前面兩個乘起來就可以了。
(2)有bias時每個卷積核對應一個bias,所以在每個卷積核的引數量上再加1
(3)卷積計算時,無bias時,所有卷積核元素和對應位置的輸入特徵圖張量一一相乘再相加,乘法次數就是卷積核元素數,注意相加的時候是逐個往第一個數上加的所以總加法次數是卷積核元素數減1,就是
(
2
×
k
H
×
k
W
×
C
i
n
/
g
−
1
)
(2×k_{H }×k_{W }×C_{in}/g - 1)
(2×kH×kW×Cin/g−1);而卷積核在一個位置的運算只能產生一個輸出資料點,隨著卷積核的平移(膨脹卷積的跳點平移也是一樣)共產生了
C
o
u
t
×
H
o
u
t
×
W
o
u
t
C_{out}×H_{out}×W_{out}
Cout×Hout×Wout個輸出資料點,把這兩部分乘起來就是總計算量。
(4)有bias時,每個輸出資料點還要再加上bias,多一次加法,正好把上式的減1抵消掉了。
(5)分組卷積時,根據分組卷積原理,相當於進行了g個輸入通道為
C
i
n
/
g
C_{in}/g
Cin/g,輸出通道為
C
o
u
t
/
g
C_{out}/g
Cout/g的卷積,所以把上面所有式的
C
i
n
C_{in}
Cin,
C
o
u
t
C_{out}
Cout替換為
C
i
n
/
g
C_{in}/g
Cin/g,
C
o
u
t
/
g
C_{out}/g
Cout/g,再乘以g後抵消掉
C
o
u
t
/
g
C_{out}/g
Cout/g的分母,就得到了最終公式。
2、程式碼
如果只計算所有卷積層的引數量,可以不用上述公式,直接從model中提取即可
total_para_nums = 0
for n,m in model.named_modules():
if isinstance(m,nn.Conv2d):
total_para_nums += m.weight.data.numel()
if m.bias is not None:
total_para_nums += m.bias.data.numel()
print('total parameters:',total_para_nums)
但FLOPs計算涉及到中間層特徵圖的尺寸,不能從model中得到,必須使用鉤子hook操作。程式碼如下,可同時計算引數量,逐層顯示所有Conv2d層的引數量和FLOPs以及總量:
model = XXX_Net()#注:以下程式碼放在模型例項化之後,模型名用model
def my_hook(Module, input, output):
outshapes.append(output.shape)
modules.append(Module)
names,modules,outshapes = [],[],[]
for name,m in model.named_modules():
if isinstance(m,nn.Conv2d):
m.register_forward_hook(my_hook)
names.append(name)
def calc_paras_flops(modules,outshapes):
total_para_nums = 0
total_flops = 0
for i,m in enumerate(modules):
Cin = m.in_channels
Cout = m.out_channels
k = m.kernel_size
#p = m.padding
#s = m.stride
#d = m.dilation
g = m.groups
Hout = outshapes[i][2]
Wout = outshapes[i][3]
if m.bias is None:
para_nums = k[0] * k[1] * Cin / g * Cout
flops = (2 * k[0] * k[1] * Cin/g - 1) * Cout * Hout * Wout
else:
para_nums = (k[0] * k[1] * Cin / g +1) * Cout
flops = 2 * k[0] * k[1] * Cin/g * Cout * Hout * Wout
para_nums = int(para_nums)
flops = int(flops)
print(names[i], 'para:', para_nums, 'flops:',flops)
total_para_nums += para_nums
total_flops += flops
print('total parameters:',total_para_nums, 'total FLOPs:',total_flops)
return total_para_nums, total_flops
input = torch.rand(32,3,224,224)#需要先提供一個輸入張量
y = model(input)
total_para_nums, total_flops = calc_paras_flops(modules,outshapes)