
Disentangled Non-Local Neural Networks


1. Paper Overview

A paper that combines theory (parts of which feel questionable to me; I can't quite follow the authors' thinking there) with practice. It feels very good, but it is hard to understand on a first read.

It tackles the limited local receptive field problem and is an extension of the previous paper (Non-local Neural Networks).

The main contribution is the following (some of it may have been proposed by others before):

  1. A block designed to address the local receptive field limitation

2. Module Details

2.1 Overview of the Paper's Idea

All of the improvements build on the content of that earlier paper, which is referred to below as Paper A:

Paper A essentially expresses a function \(f(x_i,x_j)*g(x_j)\): the representation of the current pixel depends on its surrounding pixels. The former factor is the weight given to each surrounding pixel, and the latter is a transform of that pixel (you can also simplify by taking \(g\) as identity, i.e. \(f(x_i,x_j)*x_j\)).
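For reference, the basic non-local operation from Paper A (the arXiv paper cited in the code below) can be written as:

\[
y_i = \frac{1}{C(x)} \sum_{\forall j} f(x_i, x_j)\, g(x_j)
\]

where \(C(x)\) is a normalization factor (a softmax over \(j\) in the embedded-Gaussian variant).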

The weakness of Paper A is that \(f(x_i,x_j)\) (the relation function between the current pixel and its surrounding pixels) degenerates into a unary function when the surrounding pixels are fairly similar, and then no longer fulfils its original intent: being a relation function between the current pixel and its surroundings.

This paper finds that \(f(x_i,x_j)\) cannot be expressed merely as the relation between the two pixels; it also contains another component. In the paper's terms: this pairwise function actually contains a unary function plus a pairwise function, and the two have to be expressed separately.
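Concretely, the disentangled form that the pseudocode and code in 2.2.2 implement (\(\sigma\) denotes a softmax over \(j\)) is:

\[
\omega(x_i, x_j) = \underbrace{\sigma\big((q_i-\mu_q)^{T}(k_j-\mu_k)\big)}_{pairwise} + \underbrace{\sigma\big(\mu_q^{T}k_j\big)}_{unary}
\]

so the whitened pairwise term and the unary term each get their own softmax and are then added.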

A figure in the paper shows how the function expressed by each module differs.


2.2 Implementation

2.2.1 Theory

  • The proposal of Equation (3): the paper only covers how it is derived in passing:

Supplement:

Here \(key\) corresponds to the \(unary\) term and \(query\) to the \(pairwise\) term; the meanings are the same.

The paper operates with whitening (mean subtraction). The aim of the formula is to maximize the distance between the correlations of \(key\) and \(query\), i.e. to make the two values (as far as possible) mutually independent, so that similar surrounding pixels no longer distort the overall judgment.

Here, \(q_i, q_j\) denote the current and surrounding features of the \(query\), and \(k_m, k_n\) the current and surrounding features of the \(key\).
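A minimal sketch of this mean-subtraction whitening in PyTorch (shapes are assumptions for illustration):

import torch

# [N, C', HW] query / key features after the 1x1 convs
q = torch.randn(1, 64, 32 * 32)
k = torch.randn(1, 64, 32 * 32)

# whitening: subtract the per-channel mean over all spatial positions
mu_q = q.mean(dim=2, keepdim=True)  # [N, C', 1]
mu_k = k.mean(dim=2, keepdim=True)  # [N, C', 1]
q_w, k_w = q - mu_q, k - mu_k

# whitened pairwise similarities (q_i - mu_q)^T (k_j - mu_k)
sim = torch.bmm(q_w.transpose(1, 2), k_w)  # [N, HW, HW]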

The paper uses a dot product to express the correlation between the two: writing out the Gaussian function is comparatively complex, so this simplifies the operation (see Paper A).

The following formula then becomes fairly clear; simplified by this author it reads \(q_i^Tk_m - q_i^Tk_n - q_j^Tk_m\): the first term is the correlation between the pair (the larger the better), while the second and third terms are the correlations with the other side's surrounding pixels (the smaller the better). Maximizing this function therefore guarantees that the difference between the two is maximized. You can equally read the first term as distinctiveness and the latter two as associations, which may be easier to understand.

In Equation (3), the numerator is this distinctiveness and the denominator is a normalizing sum.

  • Equation (4) is likewise only touched on in passing by the authors

Supplement:

Earlier, the paper keeps writing the whitened form \((q_i-\mu_q)^T(k_j-\mu_k)\) in place of \(q_i^Tk_j\); why do three extra terms suddenly appear here?

Because the paper has been making a single point all along: \(f(x_i,x_j)\) does not only contain the pairwise product \(q_i^Tk_j\); it also implicitly contains a unary function.

So what exactly is the unary function?

Since it is unknown, simply list everything: \(\mu_q^Tk_j + q_i^T\mu_k - \mu_q^T\mu_k\). These are all the remaining cross terms when the expression above is expanded. What each individual term contributes, the paper does not discuss further.
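For completeness, the full expansion is plain algebra:

\[
q_i^{T}k_j = (q_i-\mu_q)^{T}(k_j-\mu_k) + \mu_q^{T}k_j + q_i^{T}\mu_k - \mu_q^{T}\mu_k
\]

Note that \(q_i^{T}\mu_k\) and \(\mu_q^{T}\mu_k\) do not depend on \(j\), so they cancel inside a softmax over \(j\); the only extra term that survives the normalization is the unary term \(\mu_q^{T}k_j\).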

  • How the formulas show up visually (Section 3.2 of the paper)

This part shows the theory in practice, with a visual analysis of the intersection between the attention maps and the boundaries in the labels.

  • Working backwards from the formulas to their benefit (Section 3.3 of the paper)

The benefit of the formulation is derived in reverse from theory: under backward chain-rule differentiation, add is more separable (disentangled) than multiply.
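A one-line sanity check of this claim (standard calculus, not taken from the paper): with an additive combination the two branches receive independent gradients, while a multiplicative one couples them:

\[
s = p + u \Rightarrow \frac{\partial s}{\partial p} = 1,\ \frac{\partial s}{\partial u} = 1
\qquad
s = p \cdot u \Rightarrow \frac{\partial s}{\partial p} = u,\ \frac{\partial s}{\partial u} = p
\]

In the multiplicative case each branch's gradient is scaled by the other branch's output, so the pairwise and unary terms cannot be learned independently.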

  • Derivation (appendix)

Here the Hessian matrix is negative definite, so the critical point attained is a maximum.
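As a reminder of the second-order condition being used (standard optimization, not specific to this paper):

\[
\nabla f(\theta^*) = 0 \ \text{and} \ H_f(\theta^*) \prec 0 \ \Longrightarrow\ \theta^* \ \text{is a local maximum of } f
\]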

2.2.2 Implementation

The figure in the paper is only an overall flowchart; the concrete implementation has to be read together with the formulas.

There are two main implementation versions below; neither feels complete.

g_k = conv(x); g_q = conv(x); g_v = conv(x); g_m = conv(x)  # 1x1 convs: key, query, value, unary/mask branch
g_k = g_k - k_mean; g_q = g_q - q_mean                      # whitening: subtract the means mu_k, mu_q
g_pnl = soft_max(g_k * g_q)                                 # pairwise term
g_m = soft_max(g_m * q_mean)                                # unary branch; the mu_q^T k_j term from the formula still needs to be added here
g_dnl = g_pnl + g_m                                         # disentangled sum of the two terms
g_dnl = g_v * g_dnl                                         # aggregate the value features
x = x + g_dnl                                               # residual connection

import torch
import torch.nn as nn
from mmcv.cnn import constant_init, normal_init

from ..utils import ConvModule
from mmdet.ops import ContextBlock

from torch.nn.parameter import Parameter

class NonLocal2D(nn.Module):
    """Non-local module.
    See https://arxiv.org/abs/1711.07971 for details.
    Args:
        in_channels (int): Channels of the input feature map.
        reduction (int): Channel reduction ratio.
        use_scale (bool): Whether to scale pairwise_weight by 1/inter_channels.
        conv_cfg (dict): The config dict for convolution layers.
            (only applicable to conv_out)
        norm_cfg (dict): The config dict for normalization layers.
            (only applicable to conv_out)
        mode (str): Options are `embedded_gaussian`, `dot_product` and
            `gaussian`.
        whiten_type (str): None, 'channel' or 'bn-like'; mean subtraction
            applied to theta_x / phi_x before the pairwise product.
        temp (float): Temperature for the pairwise logits.
        learn_t (bool): Whether to learn the temperature as a parameter.
        downsample (bool): Whether to max-pool the key/value branch 2x.
        fixbug (bool): Whether to use the corrected 1/sqrt(C) scaling in
            `embedded_gaussian`.
        gcb (dict): Config for an optional parallel ContextBlock branch.
    """

    def __init__(self,
                 in_channels,
                 reduction=2,
                 use_scale=True,
                 conv_cfg=None,
                 norm_cfg=None,
                 mode='embedded_gaussian',
                 whiten_type=None,
                 temp=1.0,
                 downsample=False,
                 fixbug=False,
                 learn_t=False,
                 gcb=None):
        super(NonLocal2D, self).__init__()
        self.in_channels = in_channels
        self.reduction = reduction
        self.use_scale = use_scale
        self.inter_channels = in_channels // reduction
        self.mode = mode
        assert mode in ['embedded_gaussian', 'dot_product', 'gaussian']
        if mode == 'gaussian':
            self.with_embedded = False
        else:
            self.with_embedded = True
        self.whiten_type = whiten_type
        assert whiten_type in [None, 'channel', 'bn-like']  # TODO: support more
        self.learn_t = learn_t
        if self.learn_t:
            self.temp = Parameter(torch.Tensor(1))
            self.temp.data.fill_(temp)
        else:
            self.temp = temp
        if downsample:
            self.downsample = nn.MaxPool2d(kernel_size=(2, 2), stride=(2, 2))
        else:
            self.downsample = None
        self.fixbug=fixbug

        assert gcb is None or isinstance(gcb, dict)
        self.gcb = gcb
        if gcb is not None:
            self.gc_block = ContextBlock(inplanes=in_channels, **gcb)
        else:
            self.gc_block = None

        # g, theta, phi are actually `nn.Conv2d`. Here we use ConvModule for
        # potential usage.
        self.g = ConvModule(
            self.in_channels,
            self.inter_channels,
            kernel_size=1,
            activation=None)
        if self.with_embedded:
            self.theta = ConvModule(
                self.in_channels,
                self.inter_channels,
                kernel_size=1,
                activation=None)
            self.phi = ConvModule(
                self.in_channels,
                self.inter_channels,
                kernel_size=1,
                activation=None)
        self.conv_out = ConvModule(
            self.inter_channels,
            self.in_channels,
            kernel_size=1,
            conv_cfg=conv_cfg,
            norm_cfg=norm_cfg,
            activation=None)

        self.init_weights()

    def init_weights(self, std=0.01, zeros_init=True):
        transform_list = [self.g]
        if self.with_embedded:
            transform_list.extend([self.theta, self.phi])
        for m in transform_list:
            normal_init(m.conv, std=std)
        if zeros_init:
            constant_init(self.conv_out.conv, 0)
        else:
            normal_init(self.conv_out.conv, std=std)

    def embedded_gaussian(self, theta_x, phi_x):
        # pairwise_weight: [N, HxW, HxW]
        pairwise_weight = torch.matmul(theta_x, phi_x)
        if self.use_scale:
            # theta_x.shape[-1] is `self.inter_channels`
            if self.fixbug:
                # correct scaling: divide the logits by sqrt(inter_channels)
                pairwise_weight /= theta_x.shape[-1]**0.5
            else:
                # original behaviour kept for compatibility: `**-0.5` in the
                # divisor actually multiplies by sqrt(inter_channels)
                pairwise_weight /= theta_x.shape[-1]**-0.5
        if self.learn_t:
            pairwise_weight = pairwise_weight * nn.functional.softplus(self.temp) # stable training
        else:
            pairwise_weight = pairwise_weight / self.temp
        pairwise_weight = pairwise_weight.softmax(dim=-1)
        return pairwise_weight

    def gaussian(self, theta_x, phi_x):
        return self.embedded_gaussian(theta_x, phi_x)

    def dot_product(self, theta_x, phi_x):
        # pairwise_weight: [N, HxW, HxW]
        pairwise_weight = torch.matmul(theta_x, phi_x)
        pairwise_weight /= pairwise_weight.shape[-1]
        return pairwise_weight

    def forward(self, x):
        n, _, h, w = x.shape
        if self.downsample:
            down_x = self.downsample(x)
        else:
            down_x = x

        # g_x: [N, H'xW', C] - the value branch
        g_x = self.g(down_x).view(n, self.inter_channels, -1)
        g_x = g_x.permute(0, 2, 1)

        # theta_x: [N, HxW, C] - the query branch
        if self.with_embedded:
            theta_x = self.theta(x).view(n, self.inter_channels, -1)
            theta_x = theta_x.permute(0, 2, 1)
        else:
            theta_x = x.view(n, self.in_channels, -1)
            theta_x = theta_x.permute(0, 2, 1)

        # phi_x: [N, C, H'xW'] - the key branch
        if self.with_embedded:
            phi_x = self.phi(down_x).view(n, self.inter_channels, -1)
        else:
            phi_x = x.view(n, self.in_channels, -1)

        # whiten
        if self.whiten_type == "channel":
            theta_x_mean = theta_x.mean(2).unsqueeze(2)
            phi_x_mean = phi_x.mean(2).unsqueeze(2)
            theta_x -= theta_x_mean
            phi_x -= phi_x_mean
        elif self.whiten_type == 'bn-like':
            theta_x_mean = theta_x.mean(2).mean(0).unsqueeze(0).unsqueeze(2)
            phi_x_mean = phi_x.mean(2).mean(0).unsqueeze(0).unsqueeze(2)
            theta_x -= theta_x_mean
            phi_x -= phi_x_mean

        pairwise_func = getattr(self, self.mode)
        # pairwise_weight: [N, HxW, H'xW']
        pairwise_weight = pairwise_func(theta_x, phi_x)

        # y: [N, HxW, C]
        y = torch.matmul(pairwise_weight, g_x)
        # y: [N, C, H, W]
        y = y.permute(0, 2, 1).reshape(n, self.inter_channels, h, w)


        # gc block
        if self.gcb:
            output = self.gc_block(x) + self.conv_out(y)
        else:
            output = x + self.conv_out(y)

        return output
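A minimal usage sketch for this first version (shapes are assumptions; it presumes the mmcv/mmdet dependencies imported above are available):

x = torch.randn(2, 256, 32, 32)  # [N, C, H, W] feature map
block = NonLocal2D(in_channels=256, whiten_type='channel')
out = block(x)  # same shape as the input: [2, 256, 32, 32]

The second implementation version follows; it adds the unary branch (with_unary) and more whitening options: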
import torch
import torch.nn.functional as F
#from libs import InPlaceABN, InPlaceABNSync
from torch import nn
from torch.nn import init
import math


class _NonLocalNd_bn(nn.Module):

    def __init__(self, dim, inplanes, planes, downsample, use_gn, lr_mult, use_out, out_bn, whiten_type, temperature,
                 with_gc, with_unary):
        assert dim in [1, 2, 3], "dim {} is not supported yet".format(dim)
        # assert whiten_type in ['channel', 'spatial']
        if dim == 3:
            conv_nd = nn.Conv3d
            if downsample:
                max_pool = nn.MaxPool3d(kernel_size=(1, 2, 2), stride=(1, 2, 2))
            else:
                max_pool = None
            bn_nd = nn.BatchNorm3d
        elif dim == 2:
            conv_nd = nn.Conv2d
            if downsample:
                max_pool = nn.MaxPool2d(kernel_size=(2, 2), stride=(2, 2))
            else:
                max_pool = None
            bn_nd = nn.BatchNorm2d
        else:
            conv_nd = nn.Conv1d
            if downsample:
                max_pool = nn.MaxPool1d(kernel_size=2, stride=2)
            else:
                max_pool = None
            bn_nd = nn.BatchNorm1d

        super(_NonLocalNd_bn, self).__init__()
        self.conv_query = conv_nd(inplanes, planes, kernel_size=1)
        self.conv_key = conv_nd(inplanes, planes, kernel_size=1)
        if use_out:
            self.conv_value = conv_nd(inplanes, planes, kernel_size=1)
            self.conv_out = conv_nd(planes, inplanes, kernel_size=1, bias=False)
        else:
            self.conv_value = conv_nd(inplanes, inplanes, kernel_size=1, bias=False)
            self.conv_out = None
        if out_bn:
            self.out_bn = nn.BatchNorm2d(inplanes)
        else:
            self.out_bn = None
        if with_gc:
            self.conv_mask = conv_nd(inplanes, 1, kernel_size=1)
        if 'bn_affine' in whiten_type:
            self.key_bn_affine = nn.BatchNorm1d(planes)
            self.query_bn_affine = nn.BatchNorm1d(planes)
        if 'bn' in whiten_type:
            self.key_bn = nn.BatchNorm1d(planes, affine=False)
            self.query_bn = nn.BatchNorm1d(planes, affine=False)
        self.softmax = nn.Softmax(dim=2)
        self.downsample = max_pool
        # self.norm = nn.GroupNorm(num_groups=32, num_channels=inplanes) if use_gn else InPlaceABNSync(num_features=inplanes)
        self.gamma = nn.Parameter(torch.zeros(1))
        self.scale = math.sqrt(planes)
        self.whiten_type = whiten_type
        self.temperature = temperature
        self.with_gc = with_gc
        self.with_unary = with_unary

        self.reset_parameters()
        self.reset_lr_mult(lr_mult)

    def reset_parameters(self):

        for m in self.modules():
            if isinstance(m, nn.Conv3d) or isinstance(m, nn.Conv2d) or isinstance(m, nn.Conv1d):
                init.normal_(m.weight, 0, 0.01)
                if m.bias is not None:
                    init.zeros_(m.bias)
                m.inited = True
        # init.constant_(self.norm.weight, 0)
        # init.constant_(self.norm.bias, 0)
        # self.norm.inited = True

    def reset_lr_mult(self, lr_mult):
        if lr_mult is not None:
            for m in self.modules():
                m.lr_mult = lr_mult
        else:
            print('not change lr_mult')

    def forward(self, x):
        # [N, C, T, H, W]
        residual = x
        # [N, C, T, H', W']
        if self.downsample is not None:
            input_x = self.downsample(x)
        else:
            input_x = x

        # [N, C', T, H, W]
        query = self.conv_query(x)
        # [N, C', T, H', W']
        key = self.conv_key(input_x)
        value = self.conv_value(input_x)

        # [N, C', H x W]
        query = query.view(query.size(0), query.size(1), -1)
        # [N, C', H' x W']
        key = key.view(key.size(0), key.size(1), -1)
        value = value.view(value.size(0), value.size(1), -1)

        if 'channel' in self.whiten_type:
            key_mean = key.mean(2).unsqueeze(2)
            query_mean = query.mean(2).unsqueeze(2)
            key -= key_mean
            query -= query_mean
        if 'spatial' in self.whiten_type:
            key_mean = key.mean(1).unsqueeze(1)
            query_mean = query.mean(1).unsqueeze(1)
            key -= key_mean
            query -= query_mean
        if 'bn_affine' in self.whiten_type:
            key = self.key_bn_affine(key)
            query = self.query_bn_affine(query)
        if 'bn' in self.whiten_type:
            key = self.key_bn(key)
            query = self.query_bn(query)
        if 'ln_nostd' in self.whiten_type:
            key_mean = key.mean(1).mean(1).view(key.size(0), 1, 1)
            query_mean = query.mean(1).mean(1).view(query.size(0), 1, 1)
            key -= key_mean
            query -= query_mean

        # [N, T x H x W, T x H' x W']
        sim_map = torch.bmm(query.transpose(1, 2), key)
        sim_map = sim_map / self.scale
        sim_map = sim_map / self.temperature
        sim_map = self.softmax(sim_map)

        # [N, T x H x W, C']
        out_sim = torch.bmm(sim_map, value.transpose(1, 2))
        # [N, C', T x H x W]
        out_sim = out_sim.transpose(1, 2)
        # [N, C', T,  H, W]
        out_sim = out_sim.view(out_sim.size(0), out_sim.size(1), *x.size()[2:])
        # if self.norm is not None:
        #     out = self.norm(out)
        out_sim = self.gamma * out_sim
        
        if self.with_unary:
            # unary term: softmax of the (global) query mean dotted with each
            # key; query_mean comes from the whitening branch above
            if query_mean.shape[1] == 1:
                query_mean = query_mean.expand(-1, key.shape[1], -1)
            unary = torch.bmm(query_mean.transpose(1, 2), key)
            unary = self.softmax(unary)
            out_unary = torch.bmm(value, unary.permute(0, 2, 1)).unsqueeze(-1)
            out_sim = out_sim + out_unary

        # out = residual + out_sim

        if self.with_gc:
            # [N, 1, H', W']
            mask = self.conv_mask(input_x)
            # [N, 1, H'x W']
            mask = mask.view(mask.size(0), mask.size(1), -1)
            mask = self.softmax(mask)
            # [N, C', 1, 1]
            out_gc = torch.bmm(value, mask.permute(0, 2, 1)).unsqueeze(-1)
            out_sim = out_sim + out_gc

        # [N, C, T,  H, W]
        if self.conv_out is not None:
            out_sim = self.conv_out(out_sim)
        if self.out_bn:
            out_sim = self.out_bn(out_sim)

        out = out_sim + residual

        return out


class NonLocal2d_bn(_NonLocalNd_bn):

    def __init__(self, inplanes, planes, downsample=True, use_gn=False, lr_mult=None, use_out=False, out_bn=False,
                 whiten_type=['channel'], temperature=1.0, with_gc=False, with_unary=False):
        super(NonLocal2d_bn, self).__init__(dim=2, inplanes=inplanes, planes=planes, downsample=downsample,
                                            use_gn=use_gn, lr_mult=lr_mult, use_out=use_out, out_bn=out_bn,
                                            whiten_type=whiten_type, temperature=temperature, with_gc=with_gc, with_unary=with_unary)
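A matching usage sketch for this second version, enabling whitening plus the unary branch to obtain the disentangled behaviour described in 2.1 (shapes are again assumptions):

x = torch.randn(2, 256, 32, 32)  # [N, C, H, W] feature map
dnl = NonLocal2d_bn(inplanes=256, planes=128, downsample=False,
                    whiten_type=['channel'], with_unary=True)
out = dnl(x)  # [2, 256, 32, 32]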