1. 程式人生 > >(原)Non-local Neural Networks

(原)Non-local Neural Networks

轉載請註明出處:

https://www.cnblogs.com/darkknightzh/p/12592351.html

論文:

https://arxiv.org/abs/1711.07971

第三方pytorch程式碼:

https://github.com/AlexHex7/Non-local_pytorch

1. non local操作

該論文定義了通用了non local操作:

${{\mathbf{y}}_{i}}=\frac{1}{C(\mathbf{x})}\sum\limits_{\forall j}{f({{\mathbf{x}}_{i}},{{\mathbf{x}}_{j}})g({{\mathbf{x}}_{j}})}$

其中i為需要計算響應的輸出位置的索引,j為所有的位置。x為輸入訊號(影象,序列,視訊等,通常為這些訊號的特徵),y為個x相同尺寸的輸出訊號。f為pairwise的函式,f計算當前i和所有j之間的關係,並得到一個標量。一元函式g計算輸入訊號在位置j的表徵。(這段翻譯起來怪怪的)。C(x)為歸一化係數,用於歸一化f和g的結果。

2. non local和其他操作的區別

① non local考慮到了所有的位置j。卷積操作僅考慮了當前位置的一個鄰域(如核為3的一維卷積僅考慮了i-1<=j<=i+1);迴圈操作通常只考慮當前和上一個時間,j=i或j=i-1.

② non local根據不同位置的關係計算響應,fc使用學習到的權重。換言之,fc中,${{\mathbf{x}}_{i}}$和${{\mathbf{x}}_{j}}$之間不是函式關係,而non local中則是函式關係。

③ non local支援輸入不同尺寸,並且保持輸出和輸入相同的尺寸;fc則需要輸入和輸出均為固定的尺寸,並且丟失了位置關係。

④ non local可以用在網路的早期部分,fc通常用在網路最後。

 

3. f和g的形式

3.1 g的形式

為簡單起見,只考慮g為線性形式,$g({{\mathbf{x}}_{j}})\text{=}{{W}_{g}}{{\mathbf{x}}_{j}}$,${{W}_{g}}$為需要學習的權重向量,在空域可以使用1*1conv實現,在空間時間域(如時間序列的影象)可以通過1*1*1的卷積實現。

3.2 f為gaussian

$f({{\mathbf{x}}_{i}},{{\mathbf{x}}_{j}})\text{=}{{e}^{\mathbf{x}_{i}^{T}{{\mathbf{x}}_{j}}}}$

其中$\mathbf{x}_{i}^{T}{{\mathbf{x}}_{j}}$為點乘,因為點乘在深度學習平臺中更易實現(歐式距離也可以)。此時歸一化係數$C(\mathbf{x})=\sum\nolimits_{\forall j}{f({{\mathbf{x}}_{i}},{{\mathbf{x}}_{j}})}$

3.3 f為embedded Gaussian

$f({{\mathbf{x}}_{i}},{{\mathbf{x}}_{j}})\text{=}{{e}^{\theta {{({{\mathbf{x}}_{i}})}^{T}}\phi ({{\mathbf{x}}_{j}})}}$

其中$\theta ({{\mathbf{x}}_{i}})\text{=}{{W}_{\theta }}{{\mathbf{x}}_{i}}$,$\phi ({{\mathbf{x}}_{j}})\text{=}{{W}_{\phi }}{{\mathbf{x}}_{j}}$,此時$C(\mathbf{x})=\sum\nolimits_{\forall j}{f({{\mathbf{x}}_{i}},{{\mathbf{x}}_{j}})}$

self attention模組和non local的關係:可以認為self attention為embedded Gaussian的特殊形式,如給定i,$\frac{1}{C(\mathbf{x})}f({{\mathbf{x}}_{i}},{{\mathbf{x}}_{j}})$沿著j維度變成了計算softmax。此時$\mathbf{y}=softmax({{\mathbf{x}}^{T}}W_{\theta }^{T}{{W}_{\phi }}\mathbf{x})g(\mathbf{x})$,即為self attention的形式。

3.4 點乘

f可以定義為點乘的相似度(此處使用embedded的形式):

$f({{\mathbf{x}}_{i}},{{\mathbf{x}}_{j}})\text{=}\theta {{({{\mathbf{x}}_{i}})}^{T}}\phi ({{\mathbf{x}}_{j}})$

此時,歸一化係數$C(\mathbf{x})=N$,N為x中所有位置的數量,而不是f的sum,這樣可以簡化梯度的計算。

點乘和embedded Gaussian的區別是是否使用了作為啟用函式的softmax。

3.5 Concatenation

$f({{\mathbf{x}}_{i}},{{\mathbf{x}}_{j}})\text{=ReLU(w}_{f}^{T}[\theta ({{\mathbf{x}}_{i}}),\phi ({{\mathbf{x}}_{j}})]\text{)}$

其中$[\cdot \cdot ]$代表concatenation,即拼接。${{w}_{f}}$為權重向量,用於將拼接後的向量對映到一個標量。$C(\mathbf{x})=N$

4. Non local block

將之前公式的non local操作擴充套件成non local block,可以嵌入到目前的網路結構中,如下:

${{\mathbf{z}}_{i}}={{W}_{z}}{{\mathbf{y}}_{i}}+{{\mathbf{x}}_{i}}$

其中${{\mathbf{y}}_{i}}=\frac{1}{C(\mathbf{x})}\sum\limits_{\forall j}{f({{\mathbf{x}}_{i}},{{\mathbf{x}}_{j}})g({{\mathbf{x}}_{j}})}$,$+{{\mathbf{x}}_{i}}$代表殘差連線。殘差連線方便將non local block嵌入到之前與訓練的模型中,避免打亂其初始行為(如將${{W}_{z}}$初始化為0)。

non local block如下圖所示。3.2,3.3,3.4中的pairwise計算對應於下圖中的矩陣乘法。在網路後面的特徵圖上,pairwise計算量比較小。

說明:

1. 若為影象,則使用1*1conv,且圖中無T;若為視訊,則使用1*1*1conv,且圖中有T。

2. 圖中softmax指對該矩陣每行計算softmax。

5. 降低計算量

5.1 降低x的通道數量

將${{W}_{g}}$,${{W}_{\theta }}$,${{W}_{\phi }}$降低為x的通道數量的一半,可以降低計算量。

5.2 對x下采樣。

對x下采樣,可以進一步降低計算量。

此時,1中的共識修改為${{\mathbf{y}}_{i}}=\frac{1}{C(\mathbf{\hat{x}})}\sum\limits_{\forall j}{f({{\mathbf{x}}_{i}},{{{\mathbf{\hat{x}}}}_{j}})g({{{\mathbf{\hat{x}}}}_{j}})}$,其中$\mathbf{\hat{x}}$為對x進行下采樣後的輸入(如pooling)。這種方式可以降低pariwsie計算到原來的1/4,一方面不影響non local的行為,另一方面,使得計算更加稀疏。可以通過在上圖中$\phi $和$g$後面加一個max pooling來實現。

6. 程式碼:

6.1 embedded_gaussian

  1 class _NonLocalBlockND(nn.Module):
  2     def __init__(self, in_channels, inter_channels=None, dimension=3, sub_sample=True, bn_layer=True):
  3         """
  4         :param in_channels:
  5         :param inter_channels:
  6         :param dimension:
  7         :param sub_sample:
  8         :param bn_layer:
  9         """
 10 
 11         super(_NonLocalBlockND, self).__init__()
 12 
 13         assert dimension in [1, 2, 3]
 14 
 15         self.dimension = dimension
 16         self.sub_sample = sub_sample
 17 
 18         self.in_channels = in_channels
 19         self.inter_channels = inter_channels
 20 
 21         if self.inter_channels is None:
 22             self.inter_channels = in_channels // 2
 23             if self.inter_channels == 0:
 24                 self.inter_channels = 1
 25 
 26         if dimension == 3:
 27             conv_nd = nn.Conv3d
 28             max_pool_layer = nn.MaxPool3d(kernel_size=(1, 2, 2))
 29             bn = nn.BatchNorm3d
 30         elif dimension == 2:
 31             conv_nd = nn.Conv2d
 32             max_pool_layer = nn.MaxPool2d(kernel_size=(2, 2))
 33             bn = nn.BatchNorm2d
 34         else:
 35             conv_nd = nn.Conv1d
 36             max_pool_layer = nn.MaxPool1d(kernel_size=(2))
 37             bn = nn.BatchNorm1d
 38 
 39         self.g = conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels,
 40                          kernel_size=1, stride=1, padding=0)    # g函式,1*1conv,用於降維
 41 
 42         if bn_layer:
 43             self.W = nn.Sequential(    # 1*1conv,用於圖2中變換到原始維度
 44                 conv_nd(in_channels=self.inter_channels, out_channels=self.in_channels,
 45                         kernel_size=1, stride=1, padding=0),
 46                 bn(self.in_channels)
 47             )
 48             nn.init.constant_(self.W[1].weight, 0)
 49             nn.init.constant_(self.W[1].bias, 0)
 50         else:
 51             self.W = conv_nd(in_channels=self.inter_channels, out_channels=self.in_channels,
 52                              kernel_size=1, stride=1, padding=0)   # 1*1conv,用於圖2中變換到原始維度
 53             nn.init.constant_(self.W.weight, 0)
 54             nn.init.constant_(self.W.bias, 0)
 55 
 56         self.theta = conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels,
 57                              kernel_size=1, stride=1, padding=0)   # θ函式,1*1conv,用於降維
 58         self.phi = conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels,
 59                            kernel_size=1, stride=1, padding=0)    # φ函式,1*1conv,用於降維
 60 
 61         if sub_sample:
 62             self.g = nn.Sequential(self.g, max_pool_layer)
 63             self.phi = nn.Sequential(self.phi, max_pool_layer)
 64 
 65     def forward(self, x, return_nl_map=False): 
 66         """
 67         :param x: (b, c, t, h, w)
 68         :param return_nl_map: if True return z, nl_map, else only return z.
 69         :return:
 70         """
 71         # 令x維度B*C*(K):一維時,x為B*C*(K1);二維時,x為B*C*(K1*K2);三維時,x為B*C*(K1*K2*K3)
 72         batch_size = x.size(0)   # batchsize
 73 
 74         g_x = self.g(x).view(batch_size, self.inter_channels, -1)   # 通過g函式,並reshape,得到B*inter_channels*(K)矩陣
 75         g_x = g_x.permute(0, 2, 1)   # 得到B*(K)*inter_channels矩陣,和圖2中一致
 76 
 77         theta_x = self.theta(x).view(batch_size, self.inter_channels, -1)   # 通過θ函式,並reshape,得到B*inter_channels*(K)矩陣
 78         theta_x = theta_x.permute(0, 2, 1)   # 得到B*(K)*inter_channels矩陣,和圖2中一致
 79         phi_x = self.phi(x).view(batch_size, self.inter_channels, -1)   # 通過φ函式,並reshape,得到B*inter_channels*(K)矩陣
 80         f = torch.matmul(theta_x, phi_x)    # 得到B*(K)*(K)矩陣,和圖2中一致
 81         f_div_C = F.softmax(f, dim=-1)      # 通過softmax,對最後一維歸一化,得到歸一化的特徵,即概率,B*(K)*(K)
 82 
 83         y = torch.matmul(f_div_C, g_x)      # 得到B*(K)*inter_channels矩陣,和圖2中一致
 84         y = y.permute(0, 2, 1).contiguous() # 得到B*inter_channels*(K)矩陣,和圖2中一致
 85         y = y.view(batch_size, self.inter_channels, *x.size()[2:])  # 得到B*inter_channels*(K1或K1*K2或K1*K2*K3)矩陣,和圖2中一致
 86         W_y = self.W(y)  # 得到B*C*(K)矩陣,和圖2中一致
 87         z = W_y + x   # 特徵圖和non local的圖相加,得到新的特徵圖,B*C*(K)
 88 
 89         if return_nl_map:  
 90             return z, f_div_C   # 返回結果及歸一化的特徵
 91         return z
 92 
 93 
 94 class NONLocalBlock1D(_NonLocalBlockND):
 95     def __init__(self, in_channels, inter_channels=None, sub_sample=True, bn_layer=True):
 96         super(NONLocalBlock1D, self).__init__(in_channels,
 97                                               inter_channels=inter_channels,
 98                                               dimension=1, sub_sample=sub_sample,
 99                                               bn_layer=bn_layer)
100 
101 
102 class NONLocalBlock2D(_NonLocalBlockND):
103     def __init__(self, in_channels, inter_channels=None, sub_sample=True, bn_layer=True):
104         super(NONLocalBlock2D, self).__init__(in_channels,
105                                               inter_channels=inter_channels,
106                                               dimension=2, sub_sample=sub_sample,
107                                               bn_layer=bn_layer,)
108 
109 
110 class NONLocalBlock3D(_NonLocalBlockND):
111     def __init__(self, in_channels, inter_channels=None, sub_sample=True, bn_layer=True):
112         super(NONLocalBlock3D, self).__init__(in_channels,
113                                               inter_channels=inter_channels,
114                                               dimension=3, sub_sample=sub_sample,
115                                               bn_layer=bn_layer,)
116 
117 
118 if __name__ == '__main__':
119     import torch
120 
121     for (sub_sample_, bn_layer_) in [(True, True), (False, False), (True, False), (False, True)]:
122         img = torch.zeros(2, 3, 20)
123         net = NONLocalBlock1D(3, sub_sample=sub_sample_, bn_layer=bn_layer_)
124         out = net(img)
125         print(out.size())
126 
127         img = torch.zeros(2, 3, 20, 20)
128         net = NONLocalBlock2D(3, sub_sample=sub_sample_, bn_layer=bn_layer_, store_last_batch_nl_map=True)
129         out = net(img)
130         print(out.size())
131 
132         img = torch.randn(2, 3, 8, 20, 20)
133         net = NONLocalBlock3D(3, sub_sample=sub_sample_, bn_layer=bn_layer_, store_last_batch_nl_map=True)
134         out = net(img)
135         print(out.size())
View Code

6.2 embedded Gaussian和點乘的區別

點乘程式碼:

  1 class _NonLocalBlockND(nn.Module):
  2     def __init__(self, in_channels, inter_channels=None, dimension=3, sub_sample=True, bn_layer=True):
  3         super(_NonLocalBlockND, self).__init__()
  4 
  5         assert dimension in [1, 2, 3]
  6 
  7         self.dimension = dimension
  8         self.sub_sample = sub_sample
  9 
 10         self.in_channels = in_channels
 11         self.inter_channels = inter_channels
 12 
 13         if self.inter_channels is None:
 14             self.inter_channels = in_channels // 2
 15             if self.inter_channels == 0:
 16                 self.inter_channels = 1
 17 
 18         if dimension == 3:
 19             conv_nd = nn.Conv3d
 20             max_pool_layer = nn.MaxPool3d(kernel_size=(1, 2, 2))
 21             bn = nn.BatchNorm3d
 22         elif dimension == 2:
 23             conv_nd = nn.Conv2d
 24             max_pool_layer = nn.MaxPool2d(kernel_size=(2, 2))
 25             bn = nn.BatchNorm2d
 26         else:
 27             conv_nd = nn.Conv1d
 28             max_pool_layer = nn.MaxPool1d(kernel_size=(2))
 29             bn = nn.BatchNorm1d
 30 
 31         self.g = conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels,
 32                          kernel_size=1, stride=1, padding=0)
 33 
 34         if bn_layer:
 35             self.W = nn.Sequential(
 36                 conv_nd(in_channels=self.inter_channels, out_channels=self.in_channels,
 37                         kernel_size=1, stride=1, padding=0),
 38                 bn(self.in_channels)
 39             )
 40             nn.init.constant_(self.W[1].weight, 0)
 41             nn.init.constant_(self.W[1].bias, 0)
 42         else:
 43             self.W = conv_nd(in_channels=self.inter_channels, out_channels=self.in_channels,
 44                              kernel_size=1, stride=1, padding=0)
 45             nn.init.constant_(self.W.weight, 0)
 46             nn.init.constant_(self.W.bias, 0)
 47 
 48         self.theta = conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels,
 49                              kernel_size=1, stride=1, padding=0)
 50 
 51         self.phi = conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels,
 52                            kernel_size=1, stride=1, padding=0)
 53 
 54         if sub_sample:
 55             self.g = nn.Sequential(self.g, max_pool_layer)
 56             self.phi = nn.Sequential(self.phi, max_pool_layer)
 57 
 58     def forward(self, x, return_nl_map=False):
 59         """
 60         :param x: (b, c, t, h, w)
 61         :param return_nl_map: if True return z, nl_map, else only return z.
 62         :return:
 63         """
 64         # 令x維度B*C*(K):一維時,x為B*C*(K1);二維時,x為B*C*(K1*K2);三維時,x為B*C*(K1*K2*K3)
 65         batch_size = x.size(0)
 66 
 67         g_x = self.g(x).view(batch_size, self.inter_channels, -1)    # 通過g函式,並reshape,得到B*inter_channels*(K)矩陣
 68         g_x = g_x.permute(0, 2, 1)   # 得到B*(K)*inter_channels矩陣,和圖2中一致
 69 
 70         theta_x = self.theta(x).view(batch_size, self.inter_channels, -1)   # 通過θ函式,並reshape,得到B*inter_channels*(K)矩陣
 71         theta_x = theta_x.permute(0, 2, 1)    # 得到B*(K)*inter_channels矩陣,和圖2中一致
 72         phi_x = self.phi(x).view(batch_size, self.inter_channels, -1)    # 通過φ函式,並reshape,得到B*inter_channels*(K)矩陣
 73         f = torch.matmul(theta_x, phi_x)     # 得到B*(K)*(K)矩陣,和圖2中一致
 74         N = f.size(-1)   # 最後一維的維度
 75         f_div_C = f / N  # 對最後一維歸一化
 76 
 77         y = torch.matmul(f_div_C, g_x)    # 得到B*(K)*inter_channels矩陣,和圖2中一致
 78         y = y.permute(0, 2, 1).contiguous()   # 得到B*inter_channels*(K)矩陣,和圖2中一致
 79         y = y.view(batch_size, self.inter_channels, *x.size()[2:])   # 得到B*inter_channels*(K1或K1*K2或K1*K2*K3)矩陣,和圖2中一致
 80         W_y = self.W(y) # 得到B*C*(K)矩陣,和圖2中一致
 81         z = W_y + x   # 特徵圖和non local的圖相加,得到新的特徵圖,B*C*(K)
 82 
 83         if return_nl_map:
 84             return z, f_div_C  # 返回結果及歸一化的特徵
 85         return z
 86 
 87 
 88 class NONLocalBlock1D(_NonLocalBlockND):
 89     def __init__(self, in_channels, inter_channels=None, sub_sample=True, bn_layer=True):
 90         super(NONLocalBlock1D, self).__init__(in_channels,
 91                                               inter_channels=inter_channels,
 92                                               dimension=1, sub_sample=sub_sample,
 93                                               bn_layer=bn_layer)
 94 
 95 
 96 class NONLocalBlock2D(_NonLocalBlockND):
 97     def __init__(self, in_channels, inter_channels=None, sub_sample=True, bn_layer=True):
 98         super(NONLocalBlock2D, self).__init__(in_channels,
 99                                               inter_channels=inter_channels,
100                                               dimension=2, sub_sample=sub_sample,
101                                               bn_layer=bn_layer)
102 
103 
104 class NONLocalBlock3D(_NonLocalBlockND):
105     def __init__(self, in_channels, inter_channels=None, sub_sample=True, bn_layer=True):
106         super(NONLocalBlock3D, self).__init__(in_channels,
107                                               inter_channels=inter_channels,
108                                               dimension=3, sub_sample=sub_sample,
109                                               bn_layer=bn_layer)
110 
111 
112 if __name__ == '__main__':
113     import torch
114 
115     for (sub_sample_, bn_layer_) in [(True, True), (False, False), (True, False), (False, True)]:
116         img = torch.zeros(2, 3, 20)
117         net = NONLocalBlock1D(3, sub_sample=sub_sample_, bn_layer=bn_layer_)
118         out = net(img)
119         print(out.size())
120 
121         img = torch.zeros(2, 3, 20, 20)
122         net = NONLocalBlock2D(3, sub_sample=sub_sample_, bn_layer=bn_layer_)
123         out = net(img)
124         print(out.size())
125 
126         img = torch.randn(2, 3, 8, 20, 20)
127         net = NONLocalBlock3D(3, sub_sample=sub_sample_, bn_layer=bn_layer_)
128         out = net(img)
129         print(out.size())
View Code

左側為embedded Gaussian,右側為點乘

 

6.3 embedded Gaussian和Gaussian的區別

左側為embedded Gaussian,右側為Gaussian

初始化:

forward:

6.4 embedded Gaussian和Concatenation的區別

Concatenation程式碼:

  1 class _NonLocalBlockND(nn.Module):
  2     def __init__(self, in_channels, inter_channels=None, dimension=3, sub_sample=True, bn_layer=True):
  3         super(_NonLocalBlockND, self).__init__()
  4 
  5         assert dimension in [1, 2, 3]
  6 
  7         self.dimension = dimension
  8         self.sub_sample = sub_sample
  9 
 10         self.in_channels = in_channels
 11         self.inter_channels = inter_channels
 12 
 13         if self.inter_channels is None:
 14             self.inter_channels = in_channels // 2
 15             if self.inter_channels == 0:
 16                 self.inter_channels = 1
 17 
 18         if dimension == 3:
 19             conv_nd = nn.Conv3d
 20             max_pool_layer = nn.MaxPool3d(kernel_size=(1, 2, 2))
 21             bn = nn.BatchNorm3d
 22         elif dimension == 2:
 23             conv_nd = nn.Conv2d
 24             max_pool_layer = nn.MaxPool2d(kernel_size=(2, 2))
 25             bn = nn.BatchNorm2d
 26         else:
 27             conv_nd = nn.Conv1d
 28             max_pool_layer = nn.MaxPool1d(kernel_size=(2))
 29             bn = nn.BatchNorm1d
 30 
 31         self.g = conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels,
 32                          kernel_size=1, stride=1, padding=0)
 33 
 34         if bn_layer:
 35             self.W = nn.Sequential(
 36                 conv_nd(in_channels=self.inter_channels, out_channels=self.in_channels,
 37                         kernel_size=1, stride=1, padding=0),
 38                 bn(self.in_channels)
 39             )
 40             nn.init.constant_(self.W[1].weight, 0)
 41             nn.init.constant_(self.W[1].bias, 0)
 42         else:
 43             self.W = conv_nd(in_channels=self.inter_channels, out_channels=self.in_channels,
 44                              kernel_size=1, stride=1, padding=0)
 45             nn.init.constant_(self.W.weight, 0)
 46             nn.init.constant_(self.W.bias, 0)
 47 
 48         self.theta = conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels,
 49                              kernel_size=1, stride=1, padding=0)
 50 
 51         self.phi = conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels,
 52                            kernel_size=1, stride=1, padding=0)
 53 
 54         self.concat_project = nn.Sequential(   # 將concat後的特徵降維到1維的矩陣
 55             nn.Conv2d(self.inter_channels * 2, 1, 1, 1, 0, bias=False),
 56             nn.ReLU()
 57         )
 58 
 59         if sub_sample:
 60             self.g = nn.Sequential(self.g, max_pool_layer)
 61             self.phi = nn.Sequential(self.phi, max_pool_layer)
 62 
 63     def forward(self, x, return_nl_map=False):
 64         '''
 65         :param x: (b, c, t, h, w)
 66         :param return_nl_map: if True return z, nl_map, else only return z.
 67         :return:
 68         '''
 69         # 令x維度B*C*(K):一維時,x為B*C*(K1);二維時,x為B*C*(K1*K2);三維時,x為B*C*(K1*K2*K3)
 70         batch_size = x.size(0)
 71 
 72         g_x = self.g(x).view(batch_size, self.inter_channels, -1)   # 通過g函式,並reshape,得到B*inter_channels*(K)矩陣
 73         g_x = g_x.permute(0, 2, 1)  # 得到B*(K)*inter_channels矩陣,和圖2中一致
 74 
 75         # (b, c, N, 1)
 76         theta_x = self.theta(x).view(batch_size, self.inter_channels, -1, 1)  # 通過θ函式,並reshape,得到B*inter_channels*(K)*1矩陣
 77         # (b, c, 1, N)
 78         phi_x = self.phi(x).view(batch_size, self.inter_channels, 1, -1) # 通過φ函式,並reshape,得到B*inter_channels*1*(K)矩陣
 79 
 80         h = theta_x.size(2)  # (K)
 81         w = phi_x.size(3)  # (K)
 82         theta_x = theta_x.repeat(1, 1, 1, w)  # B*inter_channels*(K)*(K)
 83         phi_x = phi_x.repeat(1, 1, h, 1)     # B*inter_channels*(K)*(K)
 84 
 85         concat_feature = torch.cat([theta_x, phi_x], dim=1)  # B*(2*inter_channels)*(K)*(K)
 86         f = self.concat_project(concat_feature)    # B*1*(K)*(K)
 87         b, _, h, w = f.size()  # B,_,(K),(K)
 88         f = f.view(b, h, w)   # B*(K)*(K)
 89 
 90         N = f.size(-1)  # (K)
 91         f_div_C = f / N   # 最後一維歸一化,B*(K)*(K)
 92 
 93         y = torch.matmul(f_div_C, g_x)    # 得到B*(K)*inter_channels矩陣,和圖2中一致
 94         y = y.permute(0, 2, 1).contiguous()# 得到B*inter_channels*(K)矩陣,和圖2中一致
 95         y = y.view(batch_size, self.inter_channels, *x.size()[2:])  # 得到B*inter_channels*(K1或K1*K2或K1*K2*K3)矩陣,和圖2中一致
 96         W_y = self.W(y)  # 得到B*C*(K)矩陣,和圖2中一致
 97         z = W_y + x   # 特徵圖和non local的圖相加,得到新的特徵圖,B*C*(K)
 98 
 99         if return_nl_map:
100             return z, f_div_C    # 返回結果及歸一化的特徵
101         return z
102 
103 
104 class NONLocalBlock1D(_NonLocalBlockND):
105     def __init__(self, in_channels, inter_channels=None, sub_sample=True, bn_layer=True):
106         super(NONLocalBlock1D, self).__init__(in_channels,
107                                               inter_channels=inter_channels,
108                                               dimension=1, sub_sample=sub_sample,
109                                               bn_layer=bn_layer)
110 
111 
112 class NONLocalBlock2D(_NonLocalBlockND):
113     def __init__(self, in_channels, inter_channels=None, sub_sample=True, bn_layer=True):
114         super(NONLocalBlock2D, self).__init__(in_channels,
115                                               inter_channels=inter_channels,
116                                               dimension=2, sub_sample=sub_sample,
117                                               bn_layer=bn_layer)
118 
119 
120 class NONLocalBlock3D(_NonLocalBlockND):
121     def __init__(self, in_channels, inter_channels=None, sub_sample=True, bn_layer=True,):
122         super(NONLocalBlock3D, self).__init__(in_channels,
123                                               inter_channels=inter_channels,
124                                               dimension=3, sub_sample=sub_sample,
125                                               bn_layer=bn_layer)
126 
127 
128 if __name__ == '__main__':
129     import torch
130 
131     for (sub_sample_, bn_layer_) in [(True, True), (False, False), (True, False), (False, True)]:
132         img = torch.zeros(2, 3, 20)
133         net = NONLocalBlock1D(3, sub_sample=sub_sample_, bn_layer=bn_layer_)
134         out = net(img)
135         print(out.size())
136 
137         img = torch.zeros(2, 3, 20, 20)
138         net = NONLocalBlock2D(3, sub_sample=sub_sample_, bn_layer=bn_layer_)
139         out = net(img)
140         print(out.size())
141 
142         img = torch.randn(2, 3, 8, 20, 20)
143         net = NONLocalBlock3D(3, sub_sample=sub_sample_, bn_layer=bn_layer_)
144         out = net(img)
145         print(out.size())
View Code

左側為embedded Gaussian,右側為Concatenation

初始化:

forward: