Non-local Neural Networks
阿新 • • 發佈:2020-09-14
Non-local Neural Networks
一. 論文簡介
影象上(擴大感受野),視訊序列(臨近幾幀畫素不同的問題聯合),從區域性資訊到全域性資訊
主要做的貢獻如下(可能之前有人已提出):
- 解決區域性感受野,設計一個Block
二. 模組詳解
2.1 Local和Non-Local
Local和Non-Local都是針對感受野來說的,3*3卷積就代表當前畫素的感受野範圍為9(8也可以,就是那個意思)
插曲:
- 看到這篇論文,真的有種相見恨晚的感覺,之前看到shuffleNet,通道之間打亂(按一定規則排序)可以增加資訊量,獲得更好的結果。那麼為什麼不能把feature打亂呢?\((B、C、W、H)\)
,咱們一一分析:- B在取樣的時候已經打亂了,而且多少也可以設定。理論上,製藥模型足夠魯棒,B越大越好。
- C的操作有很多,直接卷積就是對C的擴充套件,打亂是ShuffleNet的做法,不同權重是Attention的做法,大部分論文都是對C的操作,比如ResNet就是對不同通道相加.....
- W、H的操作很少,最直接FC操作,這個操作效果很好,但是計算量太大。現在迴歸都不使用FC,使用1*1卷積+Reshape操作進行代替,比如人臉關鍵點(小網路)。
- 我本來的想法是將feature按block進行重新組合,然後卷積操作就可以獲得不同區域的資訊。
註釋:
- 使用多個卷積串聯可以增大感受野,但是在計算的過程中會丟失資訊
,所以串聯得到的全域性資訊是不足的(做什麼都會丟失,多少而已)。- 使用SE模組可以獲得全域性資訊,但是完全沒有FC強大。
- 有沒有比FC計算量小,而且資訊量獲得和FC差不多的?
下面這幅圖是論文的核心,某一個點的預測,需要獲得不同位置的輔助,同時輔助的強度需要一個W權重控制。
2.2 具體實現
2.2.1 理論部分
看下面公式 \((1)\),\(x\) 表示輸入特徵,\(x_j\) 當前特徵,\(x_i\) 周圍特徵,\(f\) 表示相關函式(變換函式、\(x_i 、 x_j\) 關係函式) \(C\) 表示歸一化值(一般softmax即可), \(g\) 表示當前特徵變換函式。
其實很簡單的一個函式,\(f\)
當做相關性函式(具體實現後面說),\(g\) 直接當做一個卷積,那麼兩者相乘就可以得到全域性資訊的 \(x\)。
整片文章都在介紹 \(f\) 這個二元函式的生成方式,有Gaussian、Embedded Gaussian、.....具體不用細看,因為實現比較麻煩,能用卷積的肯定不用其他的。
下面公式\((4)\) 代表高斯函式,公式\((5)\) 代表 \(g\) 函式:
如果還不懂上面的公式,直接看程式碼就恍然大悟
2.2.2 具體實現
程式碼的實現完全是按照論文敘述,整體結構如下圖所示,其中下采樣直接在 \(\phi、g\) 後面加maxpooling即可。
import torch
from torch import nn
from torch.nn import functional as F
class _NonLocalBlockND(nn.Module):
def __init__(self, in_channels, inter_channels=None, dimension=3, sub_sample=True, bn_layer=True):
super(_NonLocalBlockND, self).__init__()
assert dimension in [1, 2, 3]
self.dimension = dimension
self.sub_sample = sub_sample
self.in_channels = in_channels
self.inter_channels = inter_channels
if self.inter_channels is None:
self.inter_channels = in_channels // 2
if self.inter_channels == 0:
self.inter_channels = 1
if dimension == 3:
conv_nd = nn.Conv3d
max_pool_layer = nn.MaxPool3d(kernel_size=(1, 2, 2))
bn = nn.BatchNorm3d
elif dimension == 2:
conv_nd = nn.Conv2d
max_pool_layer = nn.MaxPool2d(kernel_size=(2, 2))
bn = nn.BatchNorm2d
else:
conv_nd = nn.Conv1d
max_pool_layer = nn.MaxPool1d(kernel_size=(2))
bn = nn.BatchNorm1d
self.g = conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels,
kernel_size=1, stride=1, padding=0)
if bn_layer:
self.W = nn.Sequential(
conv_nd(in_channels=self.inter_channels, out_channels=self.in_channels,
kernel_size=1, stride=1, padding=0),
bn(self.in_channels)
)
nn.init.constant_(self.W[1].weight, 0)
nn.init.constant_(self.W[1].bias, 0)
else:
self.W = conv_nd(in_channels=self.inter_channels, out_channels=self.in_channels,
kernel_size=1, stride=1, padding=0)
nn.init.constant_(self.W.weight, 0)
nn.init.constant_(self.W.bias, 0)
self.theta = conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels,
kernel_size=1, stride=1, padding=0)
self.phi = conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels,
kernel_size=1, stride=1, padding=0)
if sub_sample:
self.g = nn.Sequential(self.g, max_pool_layer)
self.phi = nn.Sequential(self.phi, max_pool_layer)
def forward(self, x):
'''
:param x: (b, c, t, h, w)
:return:
'''
batch_size = x.size(0)
g_x = self.g(x).view(batch_size, self.inter_channels, -1)
g_x = g_x.permute(0, 2, 1)
theta_x = self.theta(x).view(batch_size, self.inter_channels, -1)
theta_x = theta_x.permute(0, 2, 1)
phi_x = self.phi(x).view(batch_size, self.inter_channels, -1)
f = torch.matmul(theta_x, phi_x)
f_div_C = F.softmax(f, dim=-1)
y = torch.matmul(f_div_C, g_x)
y = y.permute(0, 2, 1).contiguous()
y = y.view(batch_size, self.inter_channels, *x.size()[2:])
W_y = self.W(y)
z = W_y + x
return z
class NONLocalBlock1D(_NonLocalBlockND):
def __init__(self, in_channels, inter_channels=None, sub_sample=True, bn_layer=True):
super(NONLocalBlock1D, self).__init__(in_channels,
inter_channels=inter_channels,
dimension=1, sub_sample=sub_sample,
bn_layer=bn_layer)
class NONLocalBlock2D(_NonLocalBlockND):
def __init__(self, in_channels, inter_channels=None, sub_sample=True, bn_layer=True):
super(NONLocalBlock2D, self).__init__(in_channels,
inter_channels=inter_channels,
dimension=2, sub_sample=sub_sample,
bn_layer=bn_layer)
class NONLocalBlock3D(_NonLocalBlockND):
def __init__(self, in_channels, inter_channels=None, sub_sample=True, bn_layer=True):
super(NONLocalBlock3D, self).__init__(in_channels,
inter_channels=inter_channels,
dimension=3, sub_sample=sub_sample,
bn_layer=bn_layer)
if __name__ == '__main__':
import torch
for (sub_sample, bn_layer) in [(True, True), (False, False), (True, False), (False, True)]:
'''
img = torch.zeros(2, 3, 20)
net = NONLocalBlock1D(3, sub_sample=sub_sample, bn_layer=bn_layer)
out = net(img)
print(out.size())
'''
img = torch.zeros(2, 3, 20, 20)
net = NONLocalBlock2D(3, sub_sample=sub_sample, bn_layer=bn_layer)
out = net(img)
print(out.size())
img = torch.randn(2, 3, 8, 20, 20)
net = NONLocalBlock3D(3, sub_sample=sub_sample, bn_layer=bn_layer)
out = net(img)
print(out.size())