(Original) Non-local Neural Networks
Please credit the source when reposting:
https://www.cnblogs.com/darkknightzh/p/12592351.html
Paper:
https://arxiv.org/abs/1711.07971
Third-party PyTorch implementation:
https://github.com/AlexHex7/Non-local_pytorch
1. The non-local operation
The paper defines a generic non-local operation:
${{\mathbf{y}}_{i}}=\frac{1}{C(\mathbf{x})}\sum\limits_{\forall j}{f({{\mathbf{x}}_{i}},{{\mathbf{x}}_{j}})g({{\mathbf{x}}_{j}})}$
Here i is the index of an output position whose response is to be computed, and j enumerates all positions. x is the input signal (an image, sequence, video, etc., usually their features), and y is the output signal, which has the same size as x. The pairwise function f computes a scalar describing the relationship between position i and every position j, and the unary function g computes a representation of the input signal at position j. C(x) is a normalization factor for the weighted sum of f and g.
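As a concrete reference (a minimal sketch, not the paper's implementation; the function and argument names below are made up), the generic operation can be written for an input whose positions have been flattened into the last dimension:

import torch

def non_local(x, pairwise_f, unary_g):
    """Generic non-local operation (sketch).

    x          : (B, C, N) input with all positions flattened into the last axis
    pairwise_f : callable returning the already-normalized weights f(x_i, x_j)/C(x), shape (B, N, N)
    unary_g    : callable returning the representations g(x_j), shape (B, N, C')
    """
    return torch.matmul(pairwise_f(x), unary_g(x))  # y_i = (1/C(x)) * sum_j f(x_i, x_j) g(x_j)

The different choices of f discussed in section 3 only change how pairwise_f is computed.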
2. Differences between the non-local operation and other operations
① The non-local operation considers all positions j. A convolution only considers a neighborhood of the current position (e.g., a 1-D convolution with kernel size 3 only covers i-1 <= j <= i+1); a recurrent operation usually only considers the current and the previous time step, i.e., j = i or j = i-1.
② The non-local operation computes the response from the relationships between different positions, whereas a fully connected (fc) layer uses learned weights. In other words, in an fc layer the response is not a function of the relationship between ${{\mathbf{x}}_{i}}$ and ${{\mathbf{x}}_{j}}$, while in the non-local operation it is.
③ The non-local operation accepts inputs of different sizes and keeps the output the same size as the input; an fc layer requires fixed-size input and output and loses the positional correspondence.
④ A non-local block can be used in the early part of a network, whereas fc layers are usually used at the end.
3. Forms of f and g
3.1 Form of g
For simplicity, only a linear form of g is considered: $g({{\mathbf{x}}_{j}})\text{=}{{W}_{g}}{{\mathbf{x}}_{j}}$, where ${{W}_{g}}$ is a weight matrix to be learned. It can be implemented with a 1*1 convolution in the spatial domain, or a 1*1*1 convolution in the spatio-temporal domain (e.g., video).
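In PyTorch this is just a 1x1 convolution over the feature map; a minimal sketch (the channel sizes here are arbitrary):

import torch
from torch import nn

g = nn.Conv2d(in_channels=64, out_channels=32, kernel_size=1)  # W_g as a 1x1 conv (spatial case)
x = torch.randn(2, 64, 14, 14)                                 # B*C*H*W feature map
g_x = g(x).view(2, 32, -1).permute(0, 2, 1)                    # g(x_j) for every position j: (B, N, C')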
3.2 f as Gaussian
$f({{\mathbf{x}}_{i}},{{\mathbf{x}}_{j}})\text{=}{{e}^{\mathbf{x}_{i}^{T}{{\mathbf{x}}_{j}}}}$
Here $\mathbf{x}_{i}^{T}{{\mathbf{x}}_{j}}$ is a dot product, chosen because dot products are easy to implement in deep learning frameworks (Euclidean distance would also work). The normalization factor is $C(\mathbf{x})=\sum\nolimits_{\forall j}{f({{\mathbf{x}}_{i}},{{\mathbf{x}}_{j}})}$.
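Since dividing ${{e}^{\mathbf{x}_{i}^{T}{{\mathbf{x}}_{j}}}}$ by its sum over j is exactly a softmax along j, the Gaussian version can be sketched as follows (reusing the flattened (B, C, N) layout from the sketch in section 1):

import torch

def gaussian_f(x):
    # x: (B, C, N). Returns f(x_i, x_j)/C(x) with f = exp(x_i^T x_j) and C(x) = sum_j f.
    scores = torch.matmul(x.transpose(1, 2), x)  # (B, N, N) dot products x_i^T x_j
    return torch.softmax(scores, dim=-1)         # exp followed by normalization over j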
3.3 f as embedded Gaussian
$f({{\mathbf{x}}_{i}},{{\mathbf{x}}_{j}})\text{=}{{e}^{\theta {{({{\mathbf{x}}_{i}})}^{T}}\phi ({{\mathbf{x}}_{j}})}}$
Here $\theta ({{\mathbf{x}}_{i}})\text{=}{{W}_{\theta }}{{\mathbf{x}}_{i}}$ and $\phi ({{\mathbf{x}}_{j}})\text{=}{{W}_{\phi }}{{\mathbf{x}}_{j}}$, and again $C(\mathbf{x})=\sum\nolimits_{\forall j}{f({{\mathbf{x}}_{i}},{{\mathbf{x}}_{j}})}$.
Relation between self-attention and the non-local operation: self-attention can be seen as a special case of the embedded Gaussian version. For a given i, $\frac{1}{C(\mathbf{x})}f({{\mathbf{x}}_{i}},{{\mathbf{x}}_{j}})$ becomes a softmax computed along the j dimension, so $\mathbf{y}=softmax({{\mathbf{x}}^{T}}W_{\theta }^{T}{{W}_{\phi }}\mathbf{x})g(\mathbf{x})$, which is exactly the self-attention form.
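A sketch that makes this explicit, with ${{W}_{\theta }}$, ${{W}_{\phi }}$ and ${{W}_{g}}$ modeled as 1x1 (here 1-D) convolutions as in the code in section 6 (shapes are arbitrary):

import torch
from torch import nn

B, C, Cp, N = 2, 64, 32, 196
theta, phi, g = nn.Conv1d(C, Cp, 1), nn.Conv1d(C, Cp, 1), nn.Conv1d(C, Cp, 1)

x = torch.randn(B, C, N)
scores = torch.matmul(theta(x).transpose(1, 2), phi(x))  # (B, N, N): theta(x_i)^T phi(x_j)
attn = torch.softmax(scores, dim=-1)                     # softmax over j plays the role of exp / C(x)
y = torch.matmul(attn, g(x).transpose(1, 2))             # (B, N, C'): y = softmax(...) g(x)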
3.4 Dot product
f can also be defined as a dot-product similarity (here using the embedded versions of the inputs):
$f({{\mathbf{x}}_{i}},{{\mathbf{x}}_{j}})\text{=}\theta {{({{\mathbf{x}}_{i}})}^{T}}\phi ({{\mathbf{x}}_{j}})$
In this case the normalization factor is $C(\mathbf{x})=N$, where N is the number of positions in x; using N rather than the sum of f simplifies gradient computation.
The only difference between the dot-product and embedded Gaussian versions is whether softmax is applied as the activation.
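A sketch of the dot-product version on already-embedded features:

import torch

def dot_product_f(theta_x, phi_x):
    # theta_x, phi_x: (B, C', N) embedded features. Returns f(x_i, x_j)/C(x) with C(x) = N.
    scores = torch.matmul(theta_x.transpose(1, 2), phi_x)  # (B, N, N): theta(x_i)^T phi(x_j)
    return scores / scores.size(-1)                        # divide by the number of positions N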
3.5 Concatenation
$f({{\mathbf{x}}_{i}},{{\mathbf{x}}_{j}})\text{=ReLU(w}_{f}^{T}[\theta ({{\mathbf{x}}_{i}}),\phi ({{\mathbf{x}}_{j}})]\text{)}$
Here $[\cdot \cdot ]$ denotes concatenation, and ${{w}_{f}}$ is a weight vector that projects the concatenated vector to a scalar. Again $C(\mathbf{x})=N$.
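A sketch of the concatenation version for the 2-D case, with ${{w}_{f}}$ realized as a 1x1 Conv2d over the pairwise concatenated features (mirroring the code in section 6.4; shapes are arbitrary):

import torch
from torch import nn

B, Cp, N = 2, 32, 49
concat_project = nn.Sequential(nn.Conv2d(2 * Cp, 1, kernel_size=1, bias=False), nn.ReLU())  # w_f + ReLU

theta_x = torch.randn(B, Cp, N, 1).repeat(1, 1, 1, N)   # theta(x_i) broadcast over all j
phi_x = torch.randn(B, Cp, 1, N).repeat(1, 1, N, 1)     # phi(x_j) broadcast over all i
f = concat_project(torch.cat([theta_x, phi_x], dim=1))  # (B, 1, N, N): ReLU(w_f^T [theta(x_i), phi(x_j)])
f_div_C = f.view(B, N, N) / N                           # C(x) = N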
4. The non-local block
The non-local operation above is wrapped into a non-local block, which can be inserted into existing network architectures:
${{\mathbf{z}}_{i}}={{W}_{z}}{{\mathbf{y}}_{i}}+{{\mathbf{x}}_{i}}$
Here ${{\mathbf{y}}_{i}}=\frac{1}{C(\mathbf{x})}\sum\limits_{\forall j}{f({{\mathbf{x}}_{i}},{{\mathbf{x}}_{j}})g({{\mathbf{x}}_{j}})}$ as before, and $+{{\mathbf{x}}_{i}}$ denotes a residual connection. The residual connection makes it easy to insert a non-local block into a pretrained model without disturbing its initial behavior (e.g., by initializing ${{W}_{z}}$ to zero).
The non-local block is shown in the figure below. The pairwise computations of 3.2, 3.3, and 3.4 correspond to the matrix multiplications in that diagram. On the feature maps later in the network, the pairwise computation is relatively cheap.
Notes:
1. For images, 1*1 convolutions are used and there is no T dimension in the figure; for video, 1*1*1 convolutions are used and the T dimension is present.
2. The softmax in the figure is computed over each row of the matrix.
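Putting section 3.3 and the block definition together, a compact sketch of a 2-D embedded Gaussian non-local block (the class and variable names are ours; ${{W}_{z}}$ is a zero-initialized 1x1 convolution so the block initially behaves as an identity mapping; the full reference implementation is in section 6):

import torch
from torch import nn

class TinyNonLocal2D(nn.Module):
    # Minimal embedded-Gaussian non-local block (a sketch, not the reference code).
    def __init__(self, channels, inter_channels):
        super().__init__()
        self.theta = nn.Conv2d(channels, inter_channels, kernel_size=1)
        self.phi = nn.Conv2d(channels, inter_channels, kernel_size=1)
        self.g = nn.Conv2d(channels, inter_channels, kernel_size=1)
        self.W_z = nn.Conv2d(inter_channels, channels, kernel_size=1)
        nn.init.zeros_(self.W_z.weight)  # zero init: the block starts out as z = x
        nn.init.zeros_(self.W_z.bias)

    def forward(self, x):
        B, C, H, W = x.shape
        theta_x = self.theta(x).flatten(2).transpose(1, 2)          # (B, N, C')
        phi_x = self.phi(x).flatten(2)                              # (B, C', N)
        g_x = self.g(x).flatten(2).transpose(1, 2)                  # (B, N, C')
        attn = torch.softmax(torch.matmul(theta_x, phi_x), dim=-1)  # softmax over each row
        y = torch.matmul(attn, g_x).transpose(1, 2).reshape(B, -1, H, W)
        return self.W_z(y) + x                                      # residual connection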
5. Reducing the computational cost
5.1 Reduce the number of channels
Setting the output channels of ${{W}_{g}}$, ${{W}_{\theta }}$, and ${{W}_{\phi }}$ to half the number of channels of x reduces the computation.
5.2 Subsample x
Subsampling x reduces the computation further.
The formula in section 1 then becomes ${{\mathbf{y}}_{i}}=\frac{1}{C(\mathbf{\hat{x}})}\sum\limits_{\forall j}{f({{\mathbf{x}}_{i}},{{{\mathbf{\hat{x}}}}_{j}})g({{{\mathbf{\hat{x}}}}_{j}})}$, where $\mathbf{\hat{x}}$ is a subsampled version of x (e.g., by pooling). This reduces the pairwise computation to 1/4 of the original; it does not change the non-local behavior, but it makes the computation sparser. It can be implemented by adding a max pooling layer after $\phi $ and $g$ in the block diagram above.
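A sketch of this option (assuming 64 input channels halved to 32, mirroring the sub_sample flag in the code below): $\phi $ and $g$ are simply followed by max pooling, while $\theta $ and the output y stay at the full resolution of x.

from torch import nn

phi = nn.Sequential(nn.Conv2d(64, 32, kernel_size=1), nn.MaxPool2d(kernel_size=(2, 2)))  # phi on subsampled x_hat
g = nn.Sequential(nn.Conv2d(64, 32, kernel_size=1), nn.MaxPool2d(kernel_size=(2, 2)))    # g on subsampled x_hat
# theta is left untouched, so the pairwise matrix becomes N x (N/4) instead of N x N.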
6. Code
6.1 embedded_gaussian
import torch
from torch import nn
from torch.nn import functional as F


class _NonLocalBlockND(nn.Module):
    def __init__(self, in_channels, inter_channels=None, dimension=3, sub_sample=True, bn_layer=True):
        """
        :param in_channels: number of channels of the input feature map
        :param inter_channels: number of channels used inside the block (defaults to in_channels // 2)
        :param dimension: 1, 2 or 3, for sequences, images or video
        :param sub_sample: if True, subsample phi and g with max pooling (section 5.2)
        :param bn_layer: if True, W is a 1x1 conv followed by BatchNorm (zero-initialized)
        """

        super(_NonLocalBlockND, self).__init__()

        assert dimension in [1, 2, 3]

        self.dimension = dimension
        self.sub_sample = sub_sample

        self.in_channels = in_channels
        self.inter_channels = inter_channels

        if self.inter_channels is None:
            self.inter_channels = in_channels // 2
            if self.inter_channels == 0:
                self.inter_channels = 1

        if dimension == 3:
            conv_nd = nn.Conv3d
            max_pool_layer = nn.MaxPool3d(kernel_size=(1, 2, 2))
            bn = nn.BatchNorm3d
        elif dimension == 2:
            conv_nd = nn.Conv2d
            max_pool_layer = nn.MaxPool2d(kernel_size=(2, 2))
            bn = nn.BatchNorm2d
        else:
            conv_nd = nn.Conv1d
            max_pool_layer = nn.MaxPool1d(kernel_size=(2))
            bn = nn.BatchNorm1d

        self.g = conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels,
                         kernel_size=1, stride=1, padding=0)  # g: 1x1 conv, reduces the channel dimension

        if bn_layer:
            self.W = nn.Sequential(  # 1x1 conv that maps back to the original channel dimension (W_z in the block diagram)
                conv_nd(in_channels=self.inter_channels, out_channels=self.in_channels,
                        kernel_size=1, stride=1, padding=0),
                bn(self.in_channels)
            )
            nn.init.constant_(self.W[1].weight, 0)
            nn.init.constant_(self.W[1].bias, 0)
        else:
            self.W = conv_nd(in_channels=self.inter_channels, out_channels=self.in_channels,
                             kernel_size=1, stride=1, padding=0)  # 1x1 conv that maps back to the original channel dimension
            nn.init.constant_(self.W.weight, 0)
            nn.init.constant_(self.W.bias, 0)

        self.theta = conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels,
                             kernel_size=1, stride=1, padding=0)  # theta: 1x1 conv, reduces the channel dimension
        self.phi = conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels,
                           kernel_size=1, stride=1, padding=0)  # phi: 1x1 conv, reduces the channel dimension

        if sub_sample:
            self.g = nn.Sequential(self.g, max_pool_layer)
            self.phi = nn.Sequential(self.phi, max_pool_layer)

    def forward(self, x, return_nl_map=False):
        """
        :param x: (b, c, t, h, w)
        :param return_nl_map: if True return z, nl_map, else only return z.
        :return:
        """
        # Let x have shape B*C*(K): 1-D -> B*C*(K1), 2-D -> B*C*(K1*K2), 3-D -> B*C*(K1*K2*K3)
        batch_size = x.size(0)

        g_x = self.g(x).view(batch_size, self.inter_channels, -1)  # apply g and reshape -> B*inter_channels*(K)
        g_x = g_x.permute(0, 2, 1)  # -> B*(K)*inter_channels, as in the block diagram

        theta_x = self.theta(x).view(batch_size, self.inter_channels, -1)  # apply theta and reshape -> B*inter_channels*(K)
        theta_x = theta_x.permute(0, 2, 1)  # -> B*(K)*inter_channels
        phi_x = self.phi(x).view(batch_size, self.inter_channels, -1)  # apply phi and reshape -> B*inter_channels*(K)
        f = torch.matmul(theta_x, phi_x)  # pairwise matrix -> B*(K)*(K)
        f_div_C = F.softmax(f, dim=-1)  # softmax over the last dimension: normalized weights (probabilities), B*(K)*(K)

        y = torch.matmul(f_div_C, g_x)  # -> B*(K)*inter_channels
        y = y.permute(0, 2, 1).contiguous()  # -> B*inter_channels*(K)
        y = y.view(batch_size, self.inter_channels, *x.size()[2:])  # -> B*inter_channels*(K1 or K1*K2 or K1*K2*K3)
        W_y = self.W(y)  # map back to the input channels -> B*C*(K)
        z = W_y + x  # residual connection: add the non-local output to the input feature map, B*C*(K)

        if return_nl_map:
            return z, f_div_C  # return the output and the normalized non-local map
        return z


class NONLocalBlock1D(_NonLocalBlockND):
    def __init__(self, in_channels, inter_channels=None, sub_sample=True, bn_layer=True):
        super(NONLocalBlock1D, self).__init__(in_channels,
                                              inter_channels=inter_channels,
                                              dimension=1, sub_sample=sub_sample,
                                              bn_layer=bn_layer)


class NONLocalBlock2D(_NonLocalBlockND):
    def __init__(self, in_channels, inter_channels=None, sub_sample=True, bn_layer=True):
        super(NONLocalBlock2D, self).__init__(in_channels,
                                              inter_channels=inter_channels,
                                              dimension=2, sub_sample=sub_sample,
                                              bn_layer=bn_layer)


class NONLocalBlock3D(_NonLocalBlockND):
    def __init__(self, in_channels, inter_channels=None, sub_sample=True, bn_layer=True):
        super(NONLocalBlock3D, self).__init__(in_channels,
                                              inter_channels=inter_channels,
                                              dimension=3, sub_sample=sub_sample,
                                              bn_layer=bn_layer)


if __name__ == '__main__':
    for (sub_sample_, bn_layer_) in [(True, True), (False, False), (True, False), (False, True)]:
        img = torch.zeros(2, 3, 20)
        net = NONLocalBlock1D(3, sub_sample=sub_sample_, bn_layer=bn_layer_)
        out = net(img)
        print(out.size())

        img = torch.zeros(2, 3, 20, 20)
        net = NONLocalBlock2D(3, sub_sample=sub_sample_, bn_layer=bn_layer_)
        out = net(img)
        print(out.size())

        img = torch.randn(2, 3, 8, 20, 20)
        net = NONLocalBlock3D(3, sub_sample=sub_sample_, bn_layer=bn_layer_)
        out = net(img)
        print(out.size())
6.2 Difference between embedded Gaussian and dot product
Dot-product version:
import torch
from torch import nn


class _NonLocalBlockND(nn.Module):
    def __init__(self, in_channels, inter_channels=None, dimension=3, sub_sample=True, bn_layer=True):
        super(_NonLocalBlockND, self).__init__()

        assert dimension in [1, 2, 3]

        self.dimension = dimension
        self.sub_sample = sub_sample

        self.in_channels = in_channels
        self.inter_channels = inter_channels

        if self.inter_channels is None:
            self.inter_channels = in_channels // 2
            if self.inter_channels == 0:
                self.inter_channels = 1

        if dimension == 3:
            conv_nd = nn.Conv3d
            max_pool_layer = nn.MaxPool3d(kernel_size=(1, 2, 2))
            bn = nn.BatchNorm3d
        elif dimension == 2:
            conv_nd = nn.Conv2d
            max_pool_layer = nn.MaxPool2d(kernel_size=(2, 2))
            bn = nn.BatchNorm2d
        else:
            conv_nd = nn.Conv1d
            max_pool_layer = nn.MaxPool1d(kernel_size=(2))
            bn = nn.BatchNorm1d

        self.g = conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels,
                         kernel_size=1, stride=1, padding=0)

        if bn_layer:
            self.W = nn.Sequential(
                conv_nd(in_channels=self.inter_channels, out_channels=self.in_channels,
                        kernel_size=1, stride=1, padding=0),
                bn(self.in_channels)
            )
            nn.init.constant_(self.W[1].weight, 0)
            nn.init.constant_(self.W[1].bias, 0)
        else:
            self.W = conv_nd(in_channels=self.inter_channels, out_channels=self.in_channels,
                             kernel_size=1, stride=1, padding=0)
            nn.init.constant_(self.W.weight, 0)
            nn.init.constant_(self.W.bias, 0)

        self.theta = conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels,
                             kernel_size=1, stride=1, padding=0)

        self.phi = conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels,
                           kernel_size=1, stride=1, padding=0)

        if sub_sample:
            self.g = nn.Sequential(self.g, max_pool_layer)
            self.phi = nn.Sequential(self.phi, max_pool_layer)

    def forward(self, x, return_nl_map=False):
        """
        :param x: (b, c, t, h, w)
        :param return_nl_map: if True return z, nl_map, else only return z.
        :return:
        """
        # Let x have shape B*C*(K): 1-D -> B*C*(K1), 2-D -> B*C*(K1*K2), 3-D -> B*C*(K1*K2*K3)
        batch_size = x.size(0)

        g_x = self.g(x).view(batch_size, self.inter_channels, -1)  # apply g and reshape -> B*inter_channels*(K)
        g_x = g_x.permute(0, 2, 1)  # -> B*(K)*inter_channels

        theta_x = self.theta(x).view(batch_size, self.inter_channels, -1)  # apply theta and reshape -> B*inter_channels*(K)
        theta_x = theta_x.permute(0, 2, 1)  # -> B*(K)*inter_channels
        phi_x = self.phi(x).view(batch_size, self.inter_channels, -1)  # apply phi and reshape -> B*inter_channels*(K)
        f = torch.matmul(theta_x, phi_x)  # pairwise matrix -> B*(K)*(K)
        N = f.size(-1)  # number of positions (size of the last dimension)
        f_div_C = f / N  # normalize by N instead of applying softmax

        y = torch.matmul(f_div_C, g_x)  # -> B*(K)*inter_channels
        y = y.permute(0, 2, 1).contiguous()  # -> B*inter_channels*(K)
        y = y.view(batch_size, self.inter_channels, *x.size()[2:])  # -> B*inter_channels*(K1 or K1*K2 or K1*K2*K3)
        W_y = self.W(y)  # map back to the input channels -> B*C*(K)
        z = W_y + x  # residual connection, B*C*(K)

        if return_nl_map:
            return z, f_div_C  # return the output and the normalized non-local map
        return z


class NONLocalBlock1D(_NonLocalBlockND):
    def __init__(self, in_channels, inter_channels=None, sub_sample=True, bn_layer=True):
        super(NONLocalBlock1D, self).__init__(in_channels,
                                              inter_channels=inter_channels,
                                              dimension=1, sub_sample=sub_sample,
                                              bn_layer=bn_layer)


class NONLocalBlock2D(_NonLocalBlockND):
    def __init__(self, in_channels, inter_channels=None, sub_sample=True, bn_layer=True):
        super(NONLocalBlock2D, self).__init__(in_channels,
                                              inter_channels=inter_channels,
                                              dimension=2, sub_sample=sub_sample,
                                              bn_layer=bn_layer)


class NONLocalBlock3D(_NonLocalBlockND):
    def __init__(self, in_channels, inter_channels=None, sub_sample=True, bn_layer=True):
        super(NONLocalBlock3D, self).__init__(in_channels,
                                              inter_channels=inter_channels,
                                              dimension=3, sub_sample=sub_sample,
                                              bn_layer=bn_layer)


if __name__ == '__main__':
    for (sub_sample_, bn_layer_) in [(True, True), (False, False), (True, False), (False, True)]:
        img = torch.zeros(2, 3, 20)
        net = NONLocalBlock1D(3, sub_sample=sub_sample_, bn_layer=bn_layer_)
        out = net(img)
        print(out.size())

        img = torch.zeros(2, 3, 20, 20)
        net = NONLocalBlock2D(3, sub_sample=sub_sample_, bn_layer=bn_layer_)
        out = net(img)
        print(out.size())

        img = torch.randn(2, 3, 8, 20, 20)
        net = NONLocalBlock3D(3, sub_sample=sub_sample_, bn_layer=bn_layer_)
        out = net(img)
        print(out.size())
Left: embedded Gaussian; right: dot product (side-by-side comparison screenshot omitted).
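In short, the two versions differ only in how f is normalized in forward: the embedded Gaussian version applies F.softmax(f, dim=-1), while the dot-product version computes f_div_C = f / N, where N is the number of positions.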
6.3 Difference between embedded Gaussian and Gaussian
Left: embedded Gaussian; right: Gaussian (the side-by-side screenshots of the __init__ and forward differences are omitted).
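Since the screenshots are not included, here is a sketch of the difference based on the definition in section 3.2 (an illustration, not a copy of the repository file): in the Gaussian version, the theta and phi convolutions are not needed in __init__, and forward uses the raw x on both sides of the dot product before the softmax:

import torch
from torch.nn import functional as F

def gaussian_pairwise(x):
    # x: B*C*(K) flattened input; the Gaussian version uses the raw x instead of theta(x) and phi(x).
    theta_x = x.permute(0, 2, 1)      # B*(K)*C
    phi_x = x                         # B*C*(K)
    f = torch.matmul(theta_x, phi_x)  # x_i^T x_j -> B*(K)*(K)
    return F.softmax(f, dim=-1)       # C(x) = sum_j f(x_i, x_j)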
6.4 Difference between embedded Gaussian and Concatenation
Concatenation code:
import torch
from torch import nn


class _NonLocalBlockND(nn.Module):
    def __init__(self, in_channels, inter_channels=None, dimension=3, sub_sample=True, bn_layer=True):
        super(_NonLocalBlockND, self).__init__()

        assert dimension in [1, 2, 3]

        self.dimension = dimension
        self.sub_sample = sub_sample

        self.in_channels = in_channels
        self.inter_channels = inter_channels

        if self.inter_channels is None:
            self.inter_channels = in_channels // 2
            if self.inter_channels == 0:
                self.inter_channels = 1

        if dimension == 3:
            conv_nd = nn.Conv3d
            max_pool_layer = nn.MaxPool3d(kernel_size=(1, 2, 2))
            bn = nn.BatchNorm3d
        elif dimension == 2:
            conv_nd = nn.Conv2d
            max_pool_layer = nn.MaxPool2d(kernel_size=(2, 2))
            bn = nn.BatchNorm2d
        else:
            conv_nd = nn.Conv1d
            max_pool_layer = nn.MaxPool1d(kernel_size=(2))
            bn = nn.BatchNorm1d

        self.g = conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels,
                         kernel_size=1, stride=1, padding=0)

        if bn_layer:
            self.W = nn.Sequential(
                conv_nd(in_channels=self.inter_channels, out_channels=self.in_channels,
                        kernel_size=1, stride=1, padding=0),
                bn(self.in_channels)
            )
            nn.init.constant_(self.W[1].weight, 0)
            nn.init.constant_(self.W[1].bias, 0)
        else:
            self.W = conv_nd(in_channels=self.inter_channels, out_channels=self.in_channels,
                             kernel_size=1, stride=1, padding=0)
            nn.init.constant_(self.W.weight, 0)
            nn.init.constant_(self.W.bias, 0)

        self.theta = conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels,
                             kernel_size=1, stride=1, padding=0)

        self.phi = conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels,
                           kernel_size=1, stride=1, padding=0)

        self.concat_project = nn.Sequential(  # projects the concatenated pair features to a single channel (w_f)
            nn.Conv2d(self.inter_channels * 2, 1, 1, 1, 0, bias=False),
            nn.ReLU()
        )

        if sub_sample:
            self.g = nn.Sequential(self.g, max_pool_layer)
            self.phi = nn.Sequential(self.phi, max_pool_layer)

    def forward(self, x, return_nl_map=False):
        '''
        :param x: (b, c, t, h, w)
        :param return_nl_map: if True return z, nl_map, else only return z.
        :return:
        '''
        # Let x have shape B*C*(K): 1-D -> B*C*(K1), 2-D -> B*C*(K1*K2), 3-D -> B*C*(K1*K2*K3)
        batch_size = x.size(0)

        g_x = self.g(x).view(batch_size, self.inter_channels, -1)  # apply g and reshape -> B*inter_channels*(K)
        g_x = g_x.permute(0, 2, 1)  # -> B*(K)*inter_channels

        # (b, c, N, 1)
        theta_x = self.theta(x).view(batch_size, self.inter_channels, -1, 1)  # apply theta and reshape -> B*inter_channels*(K)*1
        # (b, c, 1, N)
        phi_x = self.phi(x).view(batch_size, self.inter_channels, 1, -1)  # apply phi and reshape -> B*inter_channels*1*(K)

        h = theta_x.size(2)  # (K)
        w = phi_x.size(3)  # (K)
        theta_x = theta_x.repeat(1, 1, 1, w)  # broadcast theta over j -> B*inter_channels*(K)*(K)
        phi_x = phi_x.repeat(1, 1, h, 1)  # broadcast phi over i -> B*inter_channels*(K)*(K)

        concat_feature = torch.cat([theta_x, phi_x], dim=1)  # concatenate along channels -> B*(2*inter_channels)*(K)*(K)
        f = self.concat_project(concat_feature)  # w_f projection + ReLU -> B*1*(K)*(K)
        b, _, h, w = f.size()
        f = f.view(b, h, w)  # -> B*(K)*(K)

        N = f.size(-1)  # number of positions
        f_div_C = f / N  # normalize by N -> B*(K)*(K)

        y = torch.matmul(f_div_C, g_x)  # -> B*(K)*inter_channels
        y = y.permute(0, 2, 1).contiguous()  # -> B*inter_channels*(K)
        y = y.view(batch_size, self.inter_channels, *x.size()[2:])  # -> B*inter_channels*(K1 or K1*K2 or K1*K2*K3)
        W_y = self.W(y)  # map back to the input channels -> B*C*(K)
        z = W_y + x  # residual connection, B*C*(K)

        if return_nl_map:
            return z, f_div_C  # return the output and the normalized non-local map
        return z


class NONLocalBlock1D(_NonLocalBlockND):
    def __init__(self, in_channels, inter_channels=None, sub_sample=True, bn_layer=True):
        super(NONLocalBlock1D, self).__init__(in_channels,
                                              inter_channels=inter_channels,
                                              dimension=1, sub_sample=sub_sample,
                                              bn_layer=bn_layer)


class NONLocalBlock2D(_NonLocalBlockND):
    def __init__(self, in_channels, inter_channels=None, sub_sample=True, bn_layer=True):
        super(NONLocalBlock2D, self).__init__(in_channels,
                                              inter_channels=inter_channels,
                                              dimension=2, sub_sample=sub_sample,
                                              bn_layer=bn_layer)


class NONLocalBlock3D(_NonLocalBlockND):
    def __init__(self, in_channels, inter_channels=None, sub_sample=True, bn_layer=True):
        super(NONLocalBlock3D, self).__init__(in_channels,
                                              inter_channels=inter_channels,
                                              dimension=3, sub_sample=sub_sample,
                                              bn_layer=bn_layer)


if __name__ == '__main__':
    for (sub_sample_, bn_layer_) in [(True, True), (False, False), (True, False), (False, True)]:
        img = torch.zeros(2, 3, 20)
        net = NONLocalBlock1D(3, sub_sample=sub_sample_, bn_layer=bn_layer_)
        out = net(img)
        print(out.size())

        img = torch.zeros(2, 3, 20, 20)
        net = NONLocalBlock2D(3, sub_sample=sub_sample_, bn_layer=bn_layer_)
        out = net(img)
        print(out.size())

        img = torch.randn(2, 3, 8, 20, 20)
        net = NONLocalBlock3D(3, sub_sample=sub_sample_, bn_layer=bn_layer_)
        out = net(img)
        print(out.size())
Left: embedded Gaussian; right: Concatenation (the side-by-side screenshots of the __init__ and forward differences are omitted).
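In short, compared with the embedded Gaussian code in 6.1, the Concatenation version adds a concat_project module in __init__ (a 1x1 Conv2d from 2*inter_channels to 1 channel followed by ReLU, implementing $w_f$ and the ReLU); in forward it reshapes theta(x) to B*inter_channels*(K)*1 and phi(x) to B*inter_channels*1*(K), broadcasts and concatenates them along the channel dimension, projects the result to a single channel to obtain f, and divides by N instead of applying softmax.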