
pytorch series --- Extra to Part 9: how Xavier and Kaiming get their fan_in and fan_out, reading _calculate_fan_in_and_fan_out for Conv2d

This post walks through the code to show how Xavier and Kaiming initialization use the _calculate_fan_in_and_fan_out function to compute the current layer's fan_in (number of input neurons) and fan_out (number of output neurons), covering the Linear and Conv2d cases first.


import torch.nn as nn

m_c = nn.Conv2d(16, 33, 3, stride=2)
m_l = nn.Linear(1, 10)
m_c.weight.size()
m_l.weight.size()

out:

torch.Size([33, 16, 3, 3])
torch.Size([10, 1])
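
Before stepping through the implementation, we can call the helper directly on these two weights to see the values the rest of the post derives. A minimal sketch, assuming the private function _calculate_fan_in_and_fan_out is importable from torch.nn.init (it is in current PyTorch versions, but as a private helper its location may change):

from torch.nn.init import _calculate_fan_in_and_fan_out

_calculate_fan_in_and_fan_out(m_c.weight)    # (144, 297)
_calculate_fan_in_and_fan_out(m_l.weight)    # (1, 10)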

Note that the Linear weight has 2 dimensions, while the Conv2d weight has 4.
The function first checks the tensor's dimensionality; if it is 2-D, the layer is a Linear:

if dimensions == 2:  # Linear
    fan_in = tensor.size(1)
    fan_out = tensor.size(0)

In this case (for a Linear layer, tensor.size(1) is in_features and tensor.size(0) is out_features):

$fan\_in = in\_features$
$fan\_out = out\_features$
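
A quick check with the m_l layer defined above confirms this:

w = m_l.weight                      # torch.Size([10, 1])
fan_in, fan_out = w.size(1), w.size(0)
print(fan_in, fan_out)              # 1 10 -> in_features, out_features of nn.Linear(1, 10)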

If the tensor has more than 2 dimensions, it is a Conv2d: the first dimension of Conv2d.weight is out_channels, the second is in_channels, and the third and fourth are the kernel_size.
The code takes the else branch. It first reads the first two dimensions, and tensor[0][0].numel() then gives the number of elements in tensor[0][0], i.e. $m\_c.weight.size(2) \times m\_c.weight.size(3) \times \ldots \times m\_c.weight.size(-1)$, which for m_c is $3 \times 3 = 9$.
Multiplying this value by num_input_fmaps and num_output_fmaps respectively yields fan_in and fan_out:

$fan\_in = in\_channels \times kernel\_size[0] \times kernel\_size[1]$
$fan\_out = out\_channels \times kernel\_size[0] \times kernel\_size[1]$
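
Plugging in the numbers for m_c (nn.Conv2d(16, 33, 3, stride=2)) from above:

w = m_c.weight                               # torch.Size([33, 16, 3, 3])
receptive_field_size = w[0][0].numel()       # 3 * 3 = 9
fan_in = w.size(1) * receptive_field_size    # 16 * 9 = 144
fan_out = w.size(0) * receptive_field_size   # 33 * 9 = 297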

else:
    num_input_fmaps = tensor.size(1)
    num_output_fmaps = tensor.size(0)
    receptive_field_size = 1
    if tensor.dim() > 2:
        receptive_field_size = tensor[0][0].numel()
    fan_in = num_input_fmaps * receptive_field_size
    fan_out = num_output_fmaps * receptive_field_size
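
Putting the two branches together, the helper can be reassembled as a self-contained function. This is a sketch built from the snippets above; the real helper in torch.nn.init also rejects tensors with fewer than 2 dimensions, which the sketch reproduces.

def calculate_fan_in_and_fan_out(tensor):
    dimensions = tensor.dim()
    if dimensions < 2:
        raise ValueError("Fan in and fan out can not be computed "
                         "for tensor with fewer than 2 dimensions")

    if dimensions == 2:  # Linear: weight is [out_features, in_features]
        fan_in = tensor.size(1)
        fan_out = tensor.size(0)
    else:                # Conv: weight is [out_channels, in_channels, *kernel_size]
        num_input_fmaps = tensor.size(1)
        num_output_fmaps = tensor.size(0)
        receptive_field_size = 1
        if tensor.dim() > 2:
            receptive_field_size = tensor[0][0].numel()  # product of the kernel dims
        fan_in = num_input_fmaps * receptive_field_size
        fan_out = num_output_fmaps * receptive_field_size

    return fan_in, fan_out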

Here is the test code:

m_c = nn.Conv2d(16, 33, 3, stride=2)

m_l = nn.Linear(1, 10)

m_c.weight.size()
Out[30]: torch.Size([33, 16, 3, 3])

m_l.weight.size()
Out[31]: torch.Size([10, 1])

m_c.weight[0][0]
Out[32]: 
tensor([[-0.0667,  0.0241,  0.0701],
        [-0.0209,  0.0364,  0.0826],
        [ 0.0803, -0.0535,  0.0316]], grad_fn=<SelectBackward>)

m_c.weight[0][0].numel()
Out[33]: 9
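
Once fan_in and fan_out are known, Xavier and Kaiming turn them into the scale of the weight distribution. The sketch below uses the standard formulas (Xavier uniform bound $a = gain \cdot \sqrt{6/(fan\_in + fan\_out)}$, Kaiming normal std $= gain/\sqrt{fan\_in}$) together with the public torch.nn.init functions; the hard-coded 144 and 297 are the values derived above for m_c.weight.

import math
import torch.nn as nn

m_c = nn.Conv2d(16, 33, 3, stride=2)
fan_in, fan_out = 144, 297                    # values derived above for m_c.weight

# Xavier uniform: samples U(-a, a) with a = gain * sqrt(6 / (fan_in + fan_out))
nn.init.xavier_uniform_(m_c.weight)           # gain defaults to 1
a = math.sqrt(6.0 / (fan_in + fan_out))
print(a, m_c.weight.abs().max().item())       # every weight lies within +/- a

# Kaiming normal with mode='fan_in': samples N(0, std^2) with std = gain / sqrt(fan_in)
nn.init.kaiming_normal_(m_c.weight, mode='fan_in', nonlinearity='relu')
std = nn.init.calculate_gain('relu') / math.sqrt(fan_in)   # sqrt(2) / 12
print(std, m_c.weight.std().item())           # empirical std is close to the formula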