Attention Augmented Convolution: Code Walkthrough
阿新 • Published: 2020-12-15
Original paper: Attention Augmented Convolutional Networks
Code source: leaderj1001/Attention-Augmented-Conv2d
Module imports & CUDA setup
import torch
import torch.nn as nn
import torch.nn.functional as F
use_cuda = torch.cuda.is_available()
device = torch.device("cuda" if use_cuda else "cpu")
Initialization & forward
class AugmentedConv(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, dk, dv, Nh, shape=0, relative=False, stride=1):
        super(AugmentedConv, self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.dk = dk
        self.dv = dv
        self.Nh = Nh
        self.shape = shape
        self.relative = relative  # whether to add relative position encodings
        self.stride = stride
        self.padding = (self.kernel_size - 1) // 2

        assert self.Nh != 0, "integer division or modulo by zero, Nh >= 1"
        assert self.dk % self.Nh == 0, "dk should be divided by Nh. (example: out_channels: 20, dk: 40, Nh: 4)"
        assert self.dv % self.Nh == 0, "dv should be divided by Nh. (example: out_channels: 20, dv: 4, Nh: 4)"
        assert stride in [1, 2], str(stride) + " Up to 2 strides are allowed."

        # Subtract dv here because the conv_out output is later concatenated with the attention output
        self.conv_out = nn.Conv2d(self.in_channels, self.out_channels - self.dv, self.kernel_size,
                                  stride=stride, padding=self.padding)

        # A single convolution produces q, k and v; it implicitly covers the X * W_q, X * W_k, X * W_v computations
        self.qkv_conv = nn.Conv2d(self.in_channels, 2 * self.dk + self.dv, kernel_size=self.kernel_size,
                                  stride=stride, padding=self.padding)

        # The attention result is still passed through a (1x1) convolution for further feature extraction
        self.attn_out = nn.Conv2d(self.dv, self.dv, kernel_size=1, stride=1)

        if self.relative:
            # Each of the w and h relative encodings has 2 * [w or h] - 1 learnable rows
            self.key_rel_w = nn.Parameter(torch.randn((2 * self.shape - 1, dk // Nh), requires_grad=True))
            self.key_rel_h = nn.Parameter(torch.randn((2 * self.shape - 1, dk // Nh), requires_grad=True))

    def forward(self, x):
        """
        The "main function" of the attention augmented conv.
        :param x: input of shape (batch_size, in_channels, height, width)
        :return: output of shape (batch, out_channels, height, width),
                 where height and width are the spatial size after the strided convolution
        """
        # conv_out -> (batch_size, out_channels - dv, height, width)
        conv_out = self.conv_out(x)
        batch, _, height, width = conv_out.size()

        # flat_q, flat_k, flat_v -> (batch_size, Nh, dkh or dvh, height * width)
        # dkh = dk / Nh, dvh = dv / Nh
        # q, k, v -> (batch_size, Nh, dkh or dvh, height, width)
        flat_q, flat_k, flat_v, q, k, v = self.compute_flat_qkv(x, self.dk, self.dv, self.Nh)
        logits = torch.matmul(flat_q.transpose(2, 3), flat_k)
        if self.relative:
            h_rel_logits, w_rel_logits = self.relative_logits(q)
            logits += h_rel_logits
            logits += w_rel_logits
        weights = F.softmax(logits, dim=-1)

        # attn_out -> (batch_size, Nh, height * width, dvh)
        attn_out = torch.matmul(weights, flat_v.transpose(2, 3))
        attn_out = torch.reshape(attn_out, (batch, self.Nh, self.dv // self.Nh, height, width))
        # combine_heads_2d -> (batch_size, dv, height, width)
        attn_out = self.combine_heads_2d(attn_out)
        # Pass the attention result through a convolution as a regular feature map
        attn_out = self.attn_out(attn_out)
        return torch.cat((conv_out, attn_out), dim=1)
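To make the shapes used in forward() concrete, here is a small stand-alone sketch of just the attention arithmetic; the numbers are illustrative choices of mine, not taken from the original code:

import torch

# Illustrative shape check for the matmuls used in forward()
batch, Nh, dk, dv, H, W = 2, 4, 40, 4, 8, 8
dkh, dvh = dk // Nh, dv // Nh

flat_q = torch.randn(batch, Nh, dkh, H * W)
flat_k = torch.randn(batch, Nh, dkh, H * W)
flat_v = torch.randn(batch, Nh, dvh, H * W)

logits = torch.matmul(flat_q.transpose(2, 3), flat_k)   # (batch, Nh, H*W, H*W)
weights = torch.softmax(logits, dim=-1)                 # attention over all H*W positions
attn = torch.matmul(weights, flat_v.transpose(2, 3))    # (batch, Nh, H*W, dvh)
print(logits.shape, attn.shape)

Every one of the H * W spatial positions attends to every other position, which is why the logits are (H * W) x (H * W) per head.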
Helper functions
def compute_flat_qkv(self, x, dk, dv, Nh):
    """
    Compute q, k, v and the per-head q, k, v.
    :param x: input of shape (batch_size, in_channels, height, width)
    :param dk: dimensionality of q and k
    :param dv: dimensionality of v
    :param Nh: number of heads
    :return: flat_q, flat_k, flat_v, q, k, v
    """
    # One convolution yields q, k and v; it covers the X * W_q, X * W_k, X * W_v computations
    qkv = self.qkv_conv(x)
    N, _, H, W = qkv.size()
    # Split the conv output along the channel dimension into q, k, v
    q, k, v = torch.split(qkv, [dk, dk, dv], dim=1)

    # Turn single-head into multi-head
    q = self.split_heads_2d(q, Nh)
    k = self.split_heads_2d(k, Nh)
    v = self.split_heads_2d(v, Nh)

    dkh = dk // Nh
    q *= dkh ** -0.5  # scale q by 1 / sqrt(dkh)

    # Per-head q, k, v, flattened over the spatial dimensions
    flat_q = torch.reshape(q, (N, Nh, dk // Nh, H * W))
    flat_k = torch.reshape(k, (N, Nh, dk // Nh, H * W))
    flat_v = torch.reshape(v, (N, Nh, dv // Nh, H * W))
    return flat_q, flat_k, flat_v, q, k, v

def split_heads_2d(self, x, Nh):
    """
    Split into heads.
    :param x: q, k or v
    :param Nh: number of heads; must divide the channel dimension of q, k, v
    :return: the reshaped q, k or v
    """
    batch, channels, height, width = x.size()
    ret_shape = (batch, Nh, channels // Nh, height, width)
    split = torch.reshape(x, ret_shape)
    return split

def combine_heads_2d(self, x):
    """
    Combine the outputs of all heads.
    :param x: the per-head outputs
    :return: the combined output
    """
    batch, Nh, dv, H, W = x.size()
    ret_shape = (batch, Nh * dv, H, W)
    return torch.reshape(x, ret_shape)
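As a quick sanity check (a toy example of my own, not part of the original repo), split_heads_2d followed by combine_heads_2d is a lossless pair of reshapes:

import torch

# Round trip: splitting into heads and combining back is a pure reshape
batch, channels, Nh, H, W = 2, 8, 4, 5, 5
x = torch.randn(batch, channels, H, W)

split = torch.reshape(x, (batch, Nh, channels // Nh, H, W))            # split_heads_2d
combined = torch.reshape(split, (batch, Nh * (channels // Nh), H, W))  # combine_heads_2d
print(torch.equal(x, combined))  # True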
Position encoding
def relative_logits(self, q):
    """
    Compute the relative position logits.
    :param q: q
    :return: the h and w position logits
    """
    B, Nh, dk, H, W = q.size()
    # q -> (B, Nh, H, W, dk)
    q = torch.transpose(q, 2, 4).transpose(2, 3)

    # Compute the 1-D encodings for w and h separately
    rel_logits_w = self.relative_logits_1d(q, self.key_rel_w, H, W, Nh, "w")
    rel_logits_h = self.relative_logits_1d(torch.transpose(q, 2, 3), self.key_rel_h, W, H, Nh, "h")

    return rel_logits_h, rel_logits_w

def relative_logits_1d(self, q, rel_k, H, W, Nh, case):
    """
    Compute a 1-D relative position logit.
    :param q: q, of shape (B, Nh, H, W, dk)
    :param rel_k: learnable position encoding of shape (2 * [w or h] - 1, dk // Nh)
    :param H: height of the input feature map
    :param W: width of the input feature map
    :param Nh: number of heads
    :param case: whether this is the w or the h encoding
    :return: position logits of shape (B, Nh, H * W, H * W)
    """
    # Einstein summation implements the batched matrix multiplication
    rel_logits = torch.einsum('bhxyd,md->bhxym', q, rel_k)
    # This is a 1-D encoding (w or h), so the other spatial dimension is folded into the batch dimension
    rel_logits = torch.reshape(rel_logits, (-1, Nh * H, W, 2 * W - 1))
    # Inject the absolute position information
    rel_logits = self.rel_to_abs(rel_logits)

    # Everything below reshapes the result to (B, Nh, H * W, H * W)
    # so it can be added to the logits in forward()
    rel_logits = torch.reshape(rel_logits, (-1, Nh, H, W, W))
    rel_logits = torch.unsqueeze(rel_logits, dim=3)
    rel_logits = rel_logits.repeat((1, 1, 1, H, 1, 1))

    if case == "w":
        rel_logits = torch.transpose(rel_logits, 3, 4)
    elif case == "h":
        rel_logits = torch.transpose(rel_logits, 2, 4).transpose(4, 5).transpose(3, 5)
    # Reshape so it can be added to the logits
    rel_logits = torch.reshape(rel_logits, (-1, Nh, H * W, H * W))
    return rel_logits

def rel_to_abs(self, x):
    """
    Relative to absolute: embed absolute position information into the relative encoding.
    :param x: raw position logits of shape (B, Nh * H, W, 2 * W - 1)
    :return: position logits of shape (B, Nh * H, W, W)
    """
    B, Nh, L, _ = x.size()
    # The '0' offset carries the absolute position; everything below shifts it so that
    # each point in a row (or column) sees the '0' at a different place.
    # Append a 0 at the end of the last dimension, i.e. one 0 every 2L - 1 entries;
    # this is also why key_rel_[w or h] has 2 * [w or h] - 1 learnable rows.
    col_pad = torch.zeros((B, Nh, L, 1)).to(x)
    x = torch.cat((x, col_pad), dim=3)

    # Append L - 1 zeros per head so that the '0' of each row (or column) is staggered
    flat_x = torch.reshape(x, (B, Nh, L * 2 * L))
    flat_pad = torch.zeros((B, Nh, L - 1)).to(x)
    flat_x_padded = torch.cat((flat_x, flat_pad), dim=2)

    # Reorganize the (L * 2 * L) + (L - 1) entries into shape (L + 1, 2 * L - 1).
    # This staggers the '0', so every row (or column) has it at a different offset,
    # which is equivalent to embedding absolute position information.
    final_x = torch.reshape(flat_x_padded, (B, Nh, L + 1, 2 * L - 1))
    # Slice out the (L, L) block needed downstream
    final_x = final_x[:, :, :L, L - 1:]
    return final_x
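The padding trick in rel_to_abs is easiest to see on a tiny example. The snippet below (my own illustration with L = 3, not from the post) fills the last dimension with the relative offsets -(L-1) ... L-1 and runs the same pad/reshape/slice steps; in the result, row i at column j holds the offset j - i, i.e. the relative logits have been re-indexed by absolute position:

import torch

# Toy illustration of the rel_to_abs trick with L = 3
B, Nh, L = 1, 1, 3
# Each row holds the relative offsets -(L-1), ..., 0, ..., L-1
x = torch.arange(-(L - 1), L, dtype=torch.float32).repeat(B, Nh, L, 1)

col_pad = torch.zeros((B, Nh, L, 1))
x = torch.cat((x, col_pad), dim=3)                    # (B, Nh, L, 2L)
flat_x = torch.reshape(x, (B, Nh, L * 2 * L))
flat_pad = torch.zeros((B, Nh, L - 1))
flat_x_padded = torch.cat((flat_x, flat_pad), dim=2)
final_x = torch.reshape(flat_x_padded, (B, Nh, L + 1, 2 * L - 1))
final_x = final_x[:, :, :L, L - 1:]                   # (B, Nh, L, L)
print(final_x[0, 0])
# tensor([[ 0.,  1.,  2.],
#         [-1.,  0.,  1.],
#         [-2., -1.,  0.]])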
Usage example
tmp = torch.randn((16, 3, 32, 32)).to(device)
# padding is computed inside the module from kernel_size, so it is not passed here
augmented_conv1 = AugmentedConv(in_channels=3, out_channels=20, kernel_size=3, dk=40, dv=4, Nh=4,
                                relative=True, stride=2, shape=16).to(device)
conv_out1 = augmented_conv1(tmp)
print(conv_out1.shape)
for name, param in augmented_conv1.named_parameters():
    print('parameter name: ', name)
augmented_conv2 = AugmentedConv(in_channels=3, out_channels=20, kernel_size=3, dk=40, dv=4, Nh=4,
                                relative=True, stride=1, shape=32).to(device)
conv_out2 = augmented_conv2(tmp)
print(conv_out2.shape)
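If the module is wired up correctly, the two prints should show torch.Size([16, 20, 16, 16]) and torch.Size([16, 20, 32, 32]): the stride-2 variant halves the 32x32 input, and in both cases the 16 channels from conv_out and the 4 channels from the attention branch add up to out_channels = 20.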