1. 程式人生 > 實用技巧 >Non-local Neural Networks

Non-local Neural Networks

Non-local Neural Networks

一. 論文簡介

影象上(擴大感受野),視訊序列(臨近幾幀畫素不同的問題聯合),從區域性資訊到全域性資訊

主要做的貢獻如下(可能之前有人已提出):

  1. 解決區域性感受野,設計一個Block

二. 模組詳解

2.1 Local和Non-Local

Local和Non-Local都是針對感受野來說的,3*3卷積就代表當前畫素的感受野範圍為9(8也可以,就是那個意思)

插曲:

  • 看到這篇論文,真的有種相見恨晚的感覺,之前看到shuffleNet,通道之間打亂(按一定規則排序)可以增加資訊量,獲得更好的結果。那麼為什麼不能把feature打亂呢?\((B、C、W、H)\)
    ,咱們一一分析:
  • B在取樣的時候已經打亂了,而且多少也可以設定。理論上,製藥模型足夠魯棒,B越大越好。
  • C的操作有很多,直接卷積就是對C的擴充套件,打亂是ShuffleNet的做法,不同權重是Attention的做法,大部分論文都是對C的操作,比如ResNet就是對不同通道相加.....
  • W、H的操作很少,最直接FC操作,這個操作效果很好,但是計算量太大。現在迴歸都不使用FC,使用1*1卷積+Reshape操作進行代替,比如人臉關鍵點(小網路)。
  • 我本來的想法是將feature按block進行重新組合,然後卷積操作就可以獲得不同區域的資訊。

註釋:

  • 使用多個卷積串聯可以增大感受野,但是在計算的過程中會丟失資訊
    ,所以串聯得到的全域性資訊是不足的(做什麼都會丟失,多少而已)。
  • 使用SE模組可以獲得全域性資訊,但是完全沒有FC強大。
  • 有沒有比FC計算量小,而且資訊量獲得和FC差不多的?

下面這幅圖是論文的核心,某一個點的預測,需要獲得不同位置的輔助,同時輔助的強度需要一個W權重控制。


2.2 具體實現

2.2.1 理論部分

看下面公式 \((1)\)\(x\) 表示輸入特徵,\(x_j\) 當前特徵,\(x_i\) 周圍特徵,\(f\) 表示相關函式(變換函式、\(x_i 、 x_j\) 關係函式) \(C\) 表示歸一化值(一般softmax即可), \(g\) 表示當前特徵變換函式。

其實很簡單的一個函式,\(f\)

當做相關性函式(具體實現後面說),\(g\) 直接當做一個卷積,那麼兩者相乘就可以得到全域性資訊的 \(x\)

整片文章都在介紹 \(f\) 這個二元函式的生成方式,有Gaussian、Embedded Gaussian、.....具體不用細看,因為實現比較麻煩,能用卷積的肯定不用其他的。

下面公式\((4)\) 代表高斯函式,公式\((5)\) 代表 \(g\) 函式:

如果還不懂上面的公式,直接看程式碼就恍然大悟

2.2.2 具體實現

程式碼的實現完全是按照論文敘述,整體結構如下圖所示,其中下采樣直接在 \(\phi、g\) 後面加maxpooling即可。

import torch
from torch import nn
from torch.nn import functional as F


class _NonLocalBlockND(nn.Module):
    def __init__(self, in_channels, inter_channels=None, dimension=3, sub_sample=True, bn_layer=True):
        super(_NonLocalBlockND, self).__init__()

        assert dimension in [1, 2, 3]

        self.dimension = dimension
        self.sub_sample = sub_sample

        self.in_channels = in_channels
        self.inter_channels = inter_channels

        if self.inter_channels is None:
            self.inter_channels = in_channels // 2
            if self.inter_channels == 0:
                self.inter_channels = 1

        if dimension == 3:
            conv_nd = nn.Conv3d
            max_pool_layer = nn.MaxPool3d(kernel_size=(1, 2, 2))
            bn = nn.BatchNorm3d
        elif dimension == 2:
            conv_nd = nn.Conv2d
            max_pool_layer = nn.MaxPool2d(kernel_size=(2, 2))
            bn = nn.BatchNorm2d
        else:
            conv_nd = nn.Conv1d
            max_pool_layer = nn.MaxPool1d(kernel_size=(2))
            bn = nn.BatchNorm1d

        self.g = conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels,
                         kernel_size=1, stride=1, padding=0)

        if bn_layer:
            self.W = nn.Sequential(
                conv_nd(in_channels=self.inter_channels, out_channels=self.in_channels,
                        kernel_size=1, stride=1, padding=0),
                bn(self.in_channels)
            )
            nn.init.constant_(self.W[1].weight, 0)
            nn.init.constant_(self.W[1].bias, 0)
        else:
            self.W = conv_nd(in_channels=self.inter_channels, out_channels=self.in_channels,
                             kernel_size=1, stride=1, padding=0)
            nn.init.constant_(self.W.weight, 0)
            nn.init.constant_(self.W.bias, 0)

        self.theta = conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels,
                             kernel_size=1, stride=1, padding=0)
        self.phi = conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels,
                           kernel_size=1, stride=1, padding=0)

        if sub_sample:
            self.g = nn.Sequential(self.g, max_pool_layer)
            self.phi = nn.Sequential(self.phi, max_pool_layer)

    def forward(self, x):
        '''
        :param x: (b, c, t, h, w)
        :return:
        '''

        batch_size = x.size(0)

        g_x = self.g(x).view(batch_size, self.inter_channels, -1)
        g_x = g_x.permute(0, 2, 1)

        theta_x = self.theta(x).view(batch_size, self.inter_channels, -1)
        theta_x = theta_x.permute(0, 2, 1)
        phi_x = self.phi(x).view(batch_size, self.inter_channels, -1)
        f = torch.matmul(theta_x, phi_x)
        f_div_C = F.softmax(f, dim=-1)

        y = torch.matmul(f_div_C, g_x)
        y = y.permute(0, 2, 1).contiguous()
        y = y.view(batch_size, self.inter_channels, *x.size()[2:])
        W_y = self.W(y)
        z = W_y + x

        return z


class NONLocalBlock1D(_NonLocalBlockND):
    def __init__(self, in_channels, inter_channels=None, sub_sample=True, bn_layer=True):
        super(NONLocalBlock1D, self).__init__(in_channels,
                                              inter_channels=inter_channels,
                                              dimension=1, sub_sample=sub_sample,
                                              bn_layer=bn_layer)


class NONLocalBlock2D(_NonLocalBlockND):
    def __init__(self, in_channels, inter_channels=None, sub_sample=True, bn_layer=True):
        super(NONLocalBlock2D, self).__init__(in_channels,
                                              inter_channels=inter_channels,
                                              dimension=2, sub_sample=sub_sample,
                                              bn_layer=bn_layer)


class NONLocalBlock3D(_NonLocalBlockND):
    def __init__(self, in_channels, inter_channels=None, sub_sample=True, bn_layer=True):
        super(NONLocalBlock3D, self).__init__(in_channels,
                                              inter_channels=inter_channels,
                                              dimension=3, sub_sample=sub_sample,
                                              bn_layer=bn_layer)


if __name__ == '__main__':
    import torch

    for (sub_sample, bn_layer) in [(True, True), (False, False), (True, False), (False, True)]:
        '''
        img = torch.zeros(2, 3, 20)
        net = NONLocalBlock1D(3, sub_sample=sub_sample, bn_layer=bn_layer)
        out = net(img)
        print(out.size())
        '''
        img = torch.zeros(2, 3, 20, 20)
        net = NONLocalBlock2D(3, sub_sample=sub_sample, bn_layer=bn_layer)
        out = net(img)
        print(out.size())

        img = torch.randn(2, 3, 8, 20, 20)
        net = NONLocalBlock3D(3, sub_sample=sub_sample, bn_layer=bn_layer)
        out = net(img)
        print(out.size())