Disentangled Non-Local Neural Networks
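1. Paper Overview
This paper combines theory with practice. I think it is very good, although some parts of the theory do not seem entirely justified to me and I cannot fully follow the authors' reasoning. It is hard to understand on a first read.
It addresses the problem of limited local receptive fields and extends the previous paper.
Main contributions (some may have been proposed by others before):
- A block designed to overcome the limited local receptive field.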
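2. Module Details
2.1 Overview of the Paper's Idea
The whole paper builds on and improves the previous paper, which is referred to below as Paper A: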
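Paper A expresses the operation \(y_i = \sum_j f(x_i,x_j)\,g(x_j)\), meaning that the representation of the current pixel depends on the surrounding pixels. The first factor, \(f(x_i,x_j)\), gives the weight of each surrounding pixel. The second factor, \(g(x_j)\), is the transformation applied to that surrounding pixel. You can also simplify it by taking \(g\) as the identity, which gives \(f(x_i,x_j)\,x_j\).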
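The following is a minimal NumPy sketch of Paper A's operation on a flattened feature map. It assumes the embedded-Gaussian form \(f(x_i,x_j)=\mathrm{softmax}_j(q_i^Tk_j)\) with \(q_i = W_q x_i\) and \(k_j = W_k x_j\). The plain matrices `w_q`, `w_k`, and `w_g` stand in for the 1×1 convolutions, and the function names are illustrative only.

```python
import numpy as np


def softmax(z, axis=-1):
    # Numerically stable softmax along `axis`.
    z = z - z.max(axis=axis, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)


def non_local(x, w_q, w_k, w_g):
    """y_i = sum_j f(x_i, x_j) g(x_j) on flattened features.

    x: [P, C] features at P spatial positions.
    w_q, w_k, w_g: [C, C'] matrices standing in for 1x1 convolutions.
    """
    q = x @ w_q             # queries q_i
    k = x @ w_k             # keys k_j
    g = x @ w_g             # g(x_j)
    f = softmax(q @ k.T)    # f(x_i, x_j) = softmax over j of q_i^T k_j
    return f @ g, f


rng = np.random.default_rng(0)
x = rng.standard_normal((6, 4))
w_q, w_k, w_g = (rng.standard_normal((4, 3)) for _ in range(3))
y, f = non_local(x, w_q, w_k, w_g)
print(y.shape)  # (6, 3)
```

Each row of `f` sums to 1, so \(y_i\) is a weighted average of the transformed features at all positions. This is what lets the block see beyond a local receptive field.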
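Paper A has a weakness. \(f(x_i,x_j)\) is the relation function between the current pixel and a surrounding pixel. When the surrounding pixels are fairly similar, this function degenerates into a unary function. It then no longer serves its original purpose of modeling the relationship between the current pixel and the surrounding pixels.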
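The toy NumPy example below illustrates this degeneration. It assumes the query vectors at all positions share a large common component \(\mu_q\) plus tiny noise. Under that assumption, every row of \(f(x_i,\cdot)\) collapses to the same map \(\mathrm{softmax}_j(\mu_q^Tk_j)\), which depends on \(j\) alone. The sizes and the noise level are arbitrary choices for the demo.

```python
import numpy as np


def softmax(z, axis=-1):
    # Numerically stable softmax along `axis`.
    z = z - z.max(axis=axis, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)


rng = np.random.default_rng(1)
P, C = 8, 4
k = rng.standard_normal((P, C))                   # keys k_j
mu_q = 3.0 * rng.standard_normal(C)               # component shared by all queries
q = mu_q + 1e-3 * rng.standard_normal((P, C))     # nearly identical queries q_i

attn = softmax(q @ k.T)       # f(x_i, x_j) for every pair (i, j), [P, P]
unary = softmax(mu_q @ k.T)   # depends on j only, [P]

# Largest gap between any row of the "pairwise" map and the unary map.
gap = np.abs(attn - unary[None, :]).max()
print(gap)
```

Here `gap` comes out below 0.05. In other words, the attention map hardly depends on which pixel \(i\) is asking.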
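This paper finds that \(f(x_i,x_j)\) cannot be treated purely as a relation between the two pixels, because it also contains another part. In the paper's words, this pairwise function actually contains a unary term plus a pairwise term, and the two should be expressed separately. A figure in the paper shows how the functions expressed by the different modules differ.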
2.2 Details
2.2.1 Theory
- Equation (3): how it is obtained. The paper only glosses over this.
Notes:
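Here \(key = unary\) and \(query = pairwise\), and the two namings mean the same thing.
The paper applies whitening, that is, mean subtraction. The formula aims to maximize the separation in correlation between \(key\) and \(query\). This makes the two terms as independent of each other as possible, so that similar surrounding pixels no longer distort the overall judgment.
Here \(q_i, q_j\) denote the query features at the current position and at a surrounding position. Likewise, \(k_m, k_n\) denote the key features at the current position and at a surrounding position.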
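The paper measures the correlation between the two with a dot product. Writing out a Gaussian function is more complicated, so the paper uses this simplification (see Paper A). With that in place, the formula becomes fairly clear, and I simplify it to \(q_i^T k_m - q_i^T k_n - k_m^T q_j\):
- The first term is the correlation between the two, and it should be as large as possible.
- The second and third terms are the correlations with the other side's surrounding pixels, and they should be as small as possible.

Maximizing this function therefore maximizes the difference between the two. It is easier to understand if you read the first term as the difference and the second and third terms as the correlation.
In the paper's formula, the numerator is this difference term and the denominator is a normalizing sum.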
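The whitening step itself is just mean subtraction over spatial positions. The small NumPy check below uses random data and illustrative names to show which half of the whitening actually changes the attention:
- Subtracting \(\mu_k\) only shifts each row of logits by the constant \(q_i^T\mu_k\), which the softmax over \(j\) ignores.
- Subtracting \(\mu_q\) removes \(\mu_q^Tk_j\), which varies with \(j\).

```python
import numpy as np


def softmax(z, axis=-1):
    # Numerically stable softmax along `axis`.
    z = z - z.max(axis=axis, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)


rng = np.random.default_rng(2)
P, C = 8, 4
q = rng.standard_normal((P, C)) + 2.0   # queries with a nonzero mean
k = rng.standard_normal((P, C)) - 1.0   # keys with a nonzero mean

mu_q = q.mean(axis=0)                   # mean over spatial positions
mu_k = k.mean(axis=0)
q_w, k_w = q - mu_q, k - mu_k           # whitening = mean subtraction

plain = softmax(q @ k.T)
key_only = softmax(q @ k_w.T)           # whiten the keys only
both = softmax(q_w @ k_w.T)             # whiten queries and keys

# Key whitening shifts row i by the constant q_i^T mu_k: no effect after softmax.
print(np.allclose(plain, key_only))     # True
# Query whitening removes mu_q^T k_j, which varies with j: the map changes.
print(np.allclose(plain, both))         # False
```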
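- Equation (4): the authors also only mention it in passing.
Notes:
Up to this point, the paper keeps using the whitened form \((q_i-\mu_q)^T(k_j-\mu_k)\) in place of \(q_i^Tk_j\). So why do three extra terms suddenly appear here?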
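The reason is that the paper's central claim is that \(f(x_i,x_j)\) contains more than the pairwise term \(q_i^Tk_j\). It also implicitly contains a unary function.
So what exactly is this unary function?
Since it is unknown, list every candidate: \(q_i^Tk_j = (q_i-\mu_q)^T(k_j-\mu_k) + \mu_q^Tk_j + q_i^T\mu_k - \mu_q^T\mu_k\). The last three terms are all the remaining combinations from expanding the expression above. The paper does not discuss what role each individual term plays.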
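The NumPy sketch below checks the expansion numerically with random query and key matrices (the variable names are illustrative). It also fills in a step the notes above leave open. Under the softmax over \(j\), the terms \(q_i^T\mu_k\) and \(\mu_q^T\mu_k\) do not depend on \(j\) and cancel, so only \(\mu_q^Tk_j\) survives as the effective unary term. The last lines build the disentangled weights as the sum of two separately normalized maps, mirroring `g_pnl + g_m` in the pseudocode of 2.2.2.

```python
import numpy as np


def softmax(z, axis=-1):
    # Numerically stable softmax along `axis`.
    z = z - z.max(axis=axis, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)


rng = np.random.default_rng(3)
P, C = 8, 4
q = rng.standard_normal((P, C)) + 1.0   # queries q_i (rows)
k = rng.standard_normal((P, C)) + 0.5   # keys k_j (rows)
mu_q, mu_k = q.mean(axis=0), k.mean(axis=0)

pairwise = (q - mu_q) @ (k - mu_k).T                      # (q_i - mu_q)^T (k_j - mu_k)
unary = np.broadcast_to(mu_q @ k.T, (P, P))               # mu_q^T k_j, same in every row i
row_term = np.broadcast_to((q @ mu_k)[:, None], (P, P))   # q_i^T mu_k, same in every column j
const = mu_q @ mu_k                                       # mu_q^T mu_k

# q_i^T k_j = pairwise + mu_q^T k_j + q_i^T mu_k - mu_q^T mu_k
print(np.allclose(q @ k.T, pairwise + unary + row_term - const))   # True

# Softmax over j drops the j-independent terms, so only pairwise + unary matter.
nl = softmax(q @ k.T)
print(np.allclose(nl, softmax(pairwise + unary)))                  # True

# Disentangled form: normalize the two parts separately, then add them.
dnl = softmax(pairwise) + softmax(unary)
print(dnl.sum(axis=1))   # every row sums to 2
```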
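- How the formulas show up visually (Section 3.2 of the paper)
This part demonstrates the theory in practice. It visually analyzes how well the boundaries in the labels overlap with the regions each operator attends to.
- Benefits seen from the backward pass (Section 3.3 of the paper)
The paper derives the gradients to show the advantage theoretically. Under the chain rule, combining the two terms by addition keeps them more separable than combining them by multiplication.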
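The NumPy sketch below makes the "addition versus multiplication" point concrete for a single query position. It uses random stand-ins: `p` for the whitened pairwise logits and `u` for the unary logits (names are illustrative).
- In the original non-local block, both terms sit inside one softmax. This equals the renormalized product of the two separate softmaxes.
- DNL instead adds the two normalized maps.

A finite-difference Jacobian then shows that, with addition, the gradient with respect to `u` no longer depends on `p`. This is a toy check of the intuition, not the paper's derivation.

```python
import numpy as np


def softmax(z):
    # Numerically stable softmax of a 1-D vector.
    e = np.exp(z - z.max())
    return e / e.sum()


def nl_weights(p, u):
    # Original non-local: one softmax over the summed logits.
    return softmax(p + u)


def dnl_weights(p, u):
    # Disentangled: normalize each term separately, then add.
    return softmax(p) + softmax(u)


def jac_wrt_u(fn, p, u, eps=1e-6):
    # Forward-difference Jacobian of fn(p, u) with respect to u.
    base = fn(p, u)
    cols = []
    for t in range(len(u)):
        du = np.zeros_like(u)
        du[t] = eps
        cols.append((fn(p, u + du) - base) / eps)
    return np.stack(cols, axis=1)


rng = np.random.default_rng(4)
p, p2, u = rng.standard_normal((3, 6))   # two pairwise logit vectors, one unary

# softmax(p + u) is the renormalized product softmax(p) * softmax(u).
prod = softmax(p) * softmax(u)
print(np.allclose(nl_weights(p, u), prod / prod.sum()))   # True

# DNL: the gradient w.r.t. the unary logits does not depend on p ...
print(np.allclose(jac_wrt_u(dnl_weights, p, u), jac_wrt_u(dnl_weights, p2, u), atol=1e-5))  # True
# ... while in the non-local form it does, so the two terms are entangled.
print(np.allclose(jac_wrt_u(nl_weights, p, u), jac_wrt_u(nl_weights, p2, u), atol=1e-5))    # False
```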
- Derivation (Appendix)
There the Hessian matrix is negative definite, so the stationary point is a maximum.
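2.2.2 Implementation
There are two main implementations, and neither feels complete to me. The paper's diagram is only an overall flowchart, so to get the actual implementation you have to combine it with the formulas. In pseudocode: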
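g_k = conv(x), g_q = conv(x), g_v = conv(x)      # key, query, value (1x1 convs)
k_mean = mean(g_k), q_mean = mean(g_q)          # means over spatial positions
g_k_w = g_k - k_mean, g_q_w = g_q - q_mean      # whitening
g_pnl = soft_max(g_q_w^T * g_k_w)               # whitened pairwise term
g_m = soft_max(q_mean^T * g_k)                  # unary term mu_q^T k_j from the formula
g_dnl = g_pnl + g_m
g_dnl = g_v * g_dnl                             # apply the attention to the values
x = x + g_dnl                                   # residual connection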
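The NumPy sketch below is a runnable counterpart to the pseudocode. It runs a DNL forward pass on flattened features, but it is not either of the implementations below. The function `dnl_block` and the matrices `w_q`, `w_k`, and `w_v` are illustrative, with the matrices standing in for 1×1 convolutions. The sketch also omits the \(1/\sqrt{C'}\) scaling, temperature, downsampling, output convolution, and normalization that the PyTorch code applies.

```python
import numpy as np


def softmax(z, axis=-1):
    # Numerically stable softmax along `axis`.
    z = z - z.max(axis=axis, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)


def dnl_block(x, w_q, w_k, w_v):
    """Hypothetical DNL forward pass.

    x: [P, C] features at P positions; w_q, w_k: [C, C']; w_v: [C, C].
    """
    g_q, g_k, g_v = x @ w_q, x @ w_k, x @ w_v
    q_mean = g_q.mean(axis=0)                   # mu_q, mean over positions
    k_mean = g_k.mean(axis=0)                   # mu_k
    g_pnl = softmax((g_q - q_mean) @ (g_k - k_mean).T)   # whitened pairwise, [P, P]
    # Unary map mu_q^T k_j, shared by every query row. Using raw keys is fine:
    # the constant mu_q^T mu_k would cancel inside the softmax anyway.
    g_m = softmax(q_mean @ g_k.T)               # [P]
    g_dnl = g_pnl + g_m[None, :]                # broadcast the unary map to all rows
    return x + g_dnl @ g_v, g_dnl               # residual connection


rng = np.random.default_rng(5)
x = rng.standard_normal((10, 8))
w_q, w_k = rng.standard_normal((8, 4)), rng.standard_normal((8, 4))
w_v = rng.standard_normal((8, 8))
y, attn = dnl_block(x, w_q, w_k, w_v)
print(y.shape)   # (10, 8)
```

Each row of `attn` sums to 2, because the pairwise and unary maps are each normalized before being added.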
import torch
import torch.nn as nn
from mmcv.cnn import constant_init, normal_init
from ..utils import ConvModule
from mmdet.ops import ContextBlock
from torch.nn.parameter import Parameter


class NonLocal2D(nn.Module):
    """Non-local module.

    See https://arxiv.org/abs/1711.07971 for details.

    Args:
        in_channels (int): Channels of the input feature map.
        reduction (int): Channel reduction ratio.
        use_scale (bool): Whether to scale pairwise_weight by the embedding
            dim (by 1/sqrt(C') only when fixbug=True).
        conv_cfg (dict): The config dict for convolution layers.
            (only applicable to conv_out)
        norm_cfg (dict): The config dict for normalization layers.
            (only applicable to conv_out)
        mode (str): Options are `embedded_gaussian`, `dot_product`
            and `gaussian`.
        whiten_type (str | None): None, 'channel' or 'bn-like'; subtract the
            mean of query/key before the dot product.
        temp (float): Softmax temperature (initial value if learn_t=True).
        downsample (bool): Max-pool the key/value inputs by 2x2.
        fixbug (bool): Use the correct 1/sqrt(C') scaling.
        learn_t (bool): Learn the temperature; logits are then multiplied
            by softplus(temp).
        gcb (dict | None): Config of an extra ContextBlock (GC block).
    """

    def __init__(self,
                 in_channels,
                 reduction=2,
                 use_scale=True,
                 conv_cfg=None,
                 norm_cfg=None,
                 mode='embedded_gaussian',
                 whiten_type=None,
                 temp=1.0,
                 downsample=False,
                 fixbug=False,
                 learn_t=False,
                 gcb=None):
        super(NonLocal2D, self).__init__()
        self.in_channels = in_channels
        self.reduction = reduction
        self.use_scale = use_scale
        self.inter_channels = in_channels // reduction
        self.mode = mode
        assert mode in ['embedded_gaussian', 'dot_product', 'gaussian']
        # 'gaussian' compares raw features; the other modes embed them first.
        if mode == 'gaussian':
            self.with_embedded = False
        else:
            self.with_embedded = True
        self.whiten_type = whiten_type
        assert whiten_type in [None, 'channel', 'bn-like']  # TODO: support more
        self.learn_t = learn_t
        if self.learn_t:
            self.temp = Parameter(torch.Tensor(1))
            self.temp.data.fill_(temp)
        else:
            self.temp = temp
        if downsample:
            self.downsample = nn.MaxPool2d(kernel_size=(2, 2), stride=(2, 2))
        else:
            self.downsample = None
        self.fixbug = fixbug
        assert gcb is None or isinstance(gcb, dict)
        self.gcb = gcb
        if gcb is not None:
            self.gc_block = ContextBlock(inplanes=in_channels, **gcb)
        else:
            self.gc_block = None

        # g, theta, phi are actually `nn.Conv2d`. Here we use ConvModule for
        # potential usage.
        self.g = ConvModule(
            self.in_channels,
            self.inter_channels,
            kernel_size=1,
            activation=None)
        if self.with_embedded:
            self.theta = ConvModule(
                self.in_channels,
                self.inter_channels,
                kernel_size=1,
                activation=None)
            self.phi = ConvModule(
                self.in_channels,
                self.inter_channels,
                kernel_size=1,
                activation=None)
        self.conv_out = ConvModule(
            self.inter_channels,
            self.in_channels,
            kernel_size=1,
            conv_cfg=conv_cfg,
            norm_cfg=norm_cfg,
            activation=None)

        self.init_weights()

    def init_weights(self, std=0.01, zeros_init=True):
        transform_list = [self.g]
        if self.with_embedded:
            transform_list.extend([self.theta, self.phi])
        for m in transform_list:
            normal_init(m.conv, std=std)
        # Zero-initializing conv_out makes the block start as an identity map.
        if zeros_init:
            constant_init(self.conv_out.conv, 0)
        else:
            normal_init(self.conv_out.conv, std=std)

    def embedded_gaussian(self, theta_x, phi_x):
        # pairwise_weight: [N, HxW, H'xW']
        pairwise_weight = torch.matmul(theta_x, phi_x)
        if self.use_scale:
            # theta_x.shape[-1] is the embedding dim C'
            if self.fixbug:
                # divide by sqrt(C')
                pairwise_weight /= theta_x.shape[-1]**0.5
            else:
                # legacy behaviour: dividing by C'**-0.5 multiplies by sqrt(C')
                pairwise_weight /= theta_x.shape[-1]**-0.5
        if self.learn_t:
            # softplus keeps the learned factor positive (stable training)
            pairwise_weight = pairwise_weight * nn.functional.softplus(self.temp)
        else:
            pairwise_weight = pairwise_weight / self.temp
        pairwise_weight = pairwise_weight.softmax(dim=-1)
        return pairwise_weight

    def gaussian(self, theta_x, phi_x):
        # Same computation, applied to the un-embedded features.
        return self.embedded_gaussian(theta_x, phi_x)

    def dot_product(self, theta_x, phi_x):
        # pairwise_weight: [N, HxW, H'xW'], normalized by the number of positions
        pairwise_weight = torch.matmul(theta_x, phi_x)
        pairwise_weight /= pairwise_weight.shape[-1]
        return pairwise_weight

    def forward(self, x):
        n, _, h, w = x.shape
        if self.downsample is not None:
            down_x = self.downsample(x)
        else:
            down_x = x

        # g_x: [N, H'xW', C'] (value)
        g_x = self.g(down_x).view(n, self.inter_channels, -1)
        g_x = g_x.permute(0, 2, 1)

        # theta_x: [N, HxW, C'] (query)
        if self.with_embedded:
            theta_x = self.theta(x).view(n, self.inter_channels, -1)
        else:
            theta_x = x.view(n, self.in_channels, -1)
        theta_x = theta_x.permute(0, 2, 1)

        # phi_x: [N, C', H'xW'] (key); built from down_x so it matches g_x
        if self.with_embedded:
            phi_x = self.phi(down_x).view(n, self.inter_channels, -1)
        else:
            phi_x = down_x.view(n, self.in_channels, -1)

        # whiten: subtract the mean over spatial positions. Out-of-place, so
        # the input x is never modified in 'gaussian' mode.
        if self.whiten_type == 'channel':
            theta_x = theta_x - theta_x.mean(dim=1, keepdim=True)
            phi_x = phi_x - phi_x.mean(dim=2, keepdim=True)
        elif self.whiten_type == 'bn-like':
            # mean over the batch and spatial positions
            theta_x = theta_x - theta_x.mean(dim=(0, 1), keepdim=True)
            phi_x = phi_x - phi_x.mean(dim=(0, 2), keepdim=True)

        pairwise_func = getattr(self, self.mode)
        # pairwise_weight: [N, HxW, H'xW']
        pairwise_weight = pairwise_func(theta_x, phi_x)

        # y: [N, HxW, C']
        y = torch.matmul(pairwise_weight, g_x)
        # y: [N, C', H, W]
        y = y.permute(0, 2, 1).reshape(n, self.inter_channels, h, w)

        # gc block
        if self.gcb:
            output = self.gc_block(x) + self.conv_out(y)
        else:
            output = x + self.conv_out(y)

        return output
import math

import torch
import torch.nn.functional as F
from torch import nn
from torch.nn import init
# from libs import InPlaceABN, InPlaceABNSync


class _NonLocalNd_bn(nn.Module):
    def __init__(self, dim, inplanes, planes, downsample, use_gn, lr_mult, use_out, out_bn, whiten_type, temperature,
                 with_gc, with_unary):
        assert dim in [1, 2, 3], "dim {} is not supported yet".format(dim)
        # assert whiten_type in ['channel', 'spatial']
        if dim == 3:
            conv_nd = nn.Conv3d
            if downsample:
                max_pool = nn.MaxPool3d(kernel_size=(1, 2, 2), stride=(1, 2, 2))
            else:
                max_pool = None
            bn_nd = nn.BatchNorm3d
        elif dim == 2:
            conv_nd = nn.Conv2d
            if downsample:
                max_pool = nn.MaxPool2d(kernel_size=(2, 2), stride=(2, 2))
            else:
                max_pool = None
            bn_nd = nn.BatchNorm2d
        else:
            conv_nd = nn.Conv1d
            if downsample:
                max_pool = nn.MaxPool1d(kernel_size=2, stride=2)
            else:
                max_pool = None
            bn_nd = nn.BatchNorm1d

        super(_NonLocalNd_bn, self).__init__()
        self.conv_query = conv_nd(inplanes, planes, kernel_size=1)
        self.conv_key = conv_nd(inplanes, planes, kernel_size=1)
        if use_out:
            # value has `planes` channels; conv_out maps back to `inplanes`
            self.conv_value = conv_nd(inplanes, planes, kernel_size=1)
            self.conv_out = conv_nd(planes, inplanes, kernel_size=1, bias=False)
        else:
            self.conv_value = conv_nd(inplanes, inplanes, kernel_size=1, bias=False)
            self.conv_out = None
        if out_bn:
            # BatchNorm matching the input dimensionality
            self.out_bn = bn_nd(inplanes)
        else:
            self.out_bn = None
        if with_gc:
            # GC-block style branch: one attention map shared by all positions
            self.conv_mask = conv_nd(inplanes, 1, kernel_size=1)
        if 'bn_affine' in whiten_type:
            self.key_bn_affine = nn.BatchNorm1d(planes)
            self.query_bn_affine = nn.BatchNorm1d(planes)
        if 'bn' in whiten_type:
            self.key_bn = nn.BatchNorm1d(planes, affine=False)
            self.query_bn = nn.BatchNorm1d(planes, affine=False)
        self.softmax = nn.Softmax(dim=2)
        self.downsample = max_pool
        # self.norm = nn.GroupNorm(num_groups=32, num_channels=inplanes) if use_gn else InPlaceABNSync(num_features=inplanes)
        self.gamma = nn.Parameter(torch.zeros(1))
        self.scale = math.sqrt(planes)
        self.whiten_type = whiten_type
        self.temperature = temperature
        self.with_gc = with_gc
        self.with_unary = with_unary

        self.reset_parameters()
        self.reset_lr_mult(lr_mult)

    def reset_parameters(self):
        for m in self.modules():
            if isinstance(m, (nn.Conv3d, nn.Conv2d, nn.Conv1d)):
                init.normal_(m.weight, 0, 0.01)
                if m.bias is not None:
                    init.zeros_(m.bias)
                m.inited = True
        # init.constant_(self.norm.weight, 0)
        # init.constant_(self.norm.bias, 0)
        # self.norm.inited = True

    def reset_lr_mult(self, lr_mult):
        if lr_mult is not None:
            for m in self.modules():
                m.lr_mult = lr_mult
        else:
            print('not change lr_mult')

    def forward(self, x):
        # [N, C, T, H, W]
        residual = x
        # [N, C, T, H', W']
        if self.downsample is not None:
            input_x = self.downsample(x)
        else:
            input_x = x

        # [N, C', T, H, W]
        query = self.conv_query(x)
        # [N, C', T, H', W']
        key = self.conv_key(input_x)
        value = self.conv_value(input_x)

        # [N, C', H x W]
        query = query.view(query.size(0), query.size(1), -1)
        # [N, C', H' x W']
        key = key.view(key.size(0), key.size(1), -1)
        value = value.view(value.size(0), value.size(1), -1)

        # 'channel' whitening: subtract each channel's mean over positions
        if 'channel' in self.whiten_type:
            key_mean = key.mean(2).unsqueeze(2)
            query_mean = query.mean(2).unsqueeze(2)
            key -= key_mean
            query -= query_mean
        if 'spatial' in self.whiten_type:
            key_mean = key.mean(1).unsqueeze(1)
            query_mean = query.mean(1).unsqueeze(1)
            key -= key_mean
            query -= query_mean
        if 'bn_affine' in self.whiten_type:
            key = self.key_bn_affine(key)
            query = self.query_bn_affine(query)
        if 'bn' in self.whiten_type:
            key = self.key_bn(key)
            query = self.query_bn(query)
        if 'ln_nostd' in self.whiten_type:
            key_mean = key.mean(1).mean(1).view(key.size(0), 1, 1)
            query_mean = query.mean(1).mean(1).view(query.size(0), 1, 1)
            key -= key_mean
            query -= query_mean

        # whitened pairwise term, [N, T x H x W, T x H' x W']
        sim_map = torch.bmm(query.transpose(1, 2), key)
        sim_map = sim_map / self.scale
        sim_map = sim_map / self.temperature
        sim_map = self.softmax(sim_map)

        # [N, T x H x W, C']
        out_sim = torch.bmm(sim_map, value.transpose(1, 2))
        # [N, C', T x H x W]
        out_sim = out_sim.transpose(1, 2)
        # [N, C', T, H, W]
        out_sim = out_sim.view(out_sim.size(0), out_sim.size(1), *x.size()[2:])
        # if self.norm is not None:
        #     out = self.norm(out)
        out_sim = self.gamma * out_sim

        if self.with_unary:
            # unary term softmax_j(mu_q^T k_j); query_mean comes from the
            # whitening step above ([N, C', 1] with 'channel' whitening)
            if query_mean.shape[1] == 1:
                query_mean = query_mean.expand(-1, key.shape[1], -1)
            # [N, 1, H' x W']
            unary = torch.bmm(query_mean.transpose(1, 2), key)
            unary = self.softmax(unary)
            # [N, C', 1, 1], broadcast over all positions
            out_unary = torch.bmm(value, unary.permute(0, 2, 1)).unsqueeze(-1)
            out_sim = out_sim + out_unary

        # out = residual + out_sim

        if self.with_gc:
            # [N, 1, H', W']
            mask = self.conv_mask(input_x)
            # [N, 1, H' x W']
            mask = mask.view(mask.size(0), mask.size(1), -1)
            mask = self.softmax(mask)
            # [N, C', 1, 1]
            out_gc = torch.bmm(value, mask.permute(0, 2, 1)).unsqueeze(-1)
            out_sim = out_sim + out_gc

        # [N, C, T, H, W]
        if self.conv_out is not None:
            out_sim = self.conv_out(out_sim)
        if self.out_bn is not None:
            out_sim = self.out_bn(out_sim)

        out = out_sim + residual

        return out


class NonLocal2d_bn(_NonLocalNd_bn):
    def __init__(self, inplanes, planes, downsample=True, use_gn=False, lr_mult=None, use_out=False, out_bn=False,
                 whiten_type=['channel'], temperature=1.0, with_gc=False, with_unary=False):
        super(NonLocal2d_bn, self).__init__(dim=2, inplanes=inplanes, planes=planes, downsample=downsample,
                                            use_gn=use_gn, lr_mult=lr_mult, use_out=use_out, out_bn=out_bn,
                                            whiten_type=whiten_type, temperature=temperature, with_gc=with_gc,
                                            with_unary=with_unary)