1. 程式人生 > 其它 >用於Transformer的6種注意力的數學原理和程式碼實現

用於Transformer的6種注意力的數學原理和程式碼實現

Transformer 的出色表現讓注意力機制出現在深度學習的各處。本文整理了深度學習中最常用的6種注意力機制的數學原理和程式碼實現。

1、Full Attention

2017的《Attention is All You Need》中的編碼器-解碼器結構實現中提出。它結構並不複雜,所以不難理解。

上圖 1.左側顯示了 Scaled Dot-Product Attention 的機制。當我們有多個注意力時,我們稱之為多頭注意力(右),這也是最常見的注意力的形式公式如下:

公式1

這裡Q(Query)、K(Key)和V(values)被認為是它的輸入,dₖ(輸入維度)被用來降低複雜度和計算成本。這個公式可以說是深度學習中注意力機制發展的開端。下面我們看一下它的程式碼:

  1. class FullAttention(nn.Module):
  2. def __init__(self, mask_flag=True, factor=5, scale=None, attention_dropout=0.1, output_attention=False):
  3. super(FullAttention, self).__init__()
  4. self.scale = scale
  5. self.mask_flag = mask_flag
  6. self.output_attention = output_attention
  7. self.dropout = nn.Dropout(attention_dropout)
  8. def forward(self, queries, keys, values, attn_mask):
  9. B, L, H, E = queries.shape
  10. _, S, _, D = values.shape
  11. scale = self.scale or 1. / sqrt(E)
  12. scores = torch.einsum("blhe,bshe->bhls", queries, keys)
  13. if self.mask_flag:
  14. if attn_mask is None:
  15. attn_mask = TriangularCausalMask(B, L, device=queries.device)
  16. scores.masked_fill_(attn_mask.mask, -np.inf)
  17. A = self.dropout(torch.softmax(scale * scores, dim=-1))
  18. V = torch.einsum("bhls,bshd->blhd", A, values)
  19. if self.output_attention:
  20. return (V.contiguous(), A)
  21. else:
  22. return (V.contiguous(), None)

2、ProbSparse Attention

藉助“Transformer Dissection: A Unified Understanding of Transformer's Attention via the lens of Kernel”中的資訊我們可以將公式修改為下面的公式2。第i個query的attention就被定義為一個概率形式的核平滑方法(kernel smoother):

公式2

從公式 2,我們可以定義第 i 個查詢的稀疏度測量如下:

完整文章:

https://www.overfit.cn/post/739299d8be4e4ddc8f5804b37c6c82ad