1. 程式人生 > 其它 >Transformer 中的 attention

Transformer 中的 attention

Transformer 中的 attention

轉自Transformer中的attention,看完不懂扇我臉

大火的transformer 本質就是:

*使用attention機制的seq2seq。*

所以它的核心就是attention機制,今天就講attention。直奔程式碼VIT-pytorch:

https://github.com/lucidrains/vit-pytorch/blob/main/vit_pytorch/vit.py

中的

class Attention(nn.Module):
    def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
        super().__init__()
        inner_dim = dim_head *  heads
        project_out = not (heads == 1 and dim_head == dim)

        self.heads = heads
        self.scale = dim_head ** -0.5

        self.attend = nn.Softmax(dim = -1)
        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)

        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, dim),
            nn.Dropout(dropout)
        ) if project_out else nn.Identity()

    def forward(self, x):
        qkv = self.to_qkv(x).chunk(3, dim = -1)
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv)

        dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale

        attn = self.attend(dots)

        out = torch.matmul(attn, v)
        out = rearrange(out, 'b h n d -> b n (h d)')
        return self.to_out(out)

看吧!就是這麼簡單。今天就徹底搞懂這個東西。

先記住attention的這麼幾個點:

  • attention和CNN、RNN、FC、GCN等都是一個級別的東西,用來提取特徵;既然是特徵提取,一定有權重(W+B)存在。
  • attention的優點:可以像CNN一樣並行運算 + 像RNN一樣通過一層就擁有全域性資訊。有一個東西也可以做到,那就是FC,但是FC有個弱點:對輸入尺寸有限制,說白了不好適應可變輸入資料,這對於序列無疑是非常不友好的。
  • pooling也可以實現,但是它是無參的過程。例如點雲資料,就可以用pooling來處理,當然也有一些網路是pooling is all your need。
  • 可以像CNN一樣並行運算 ,其實CNN運算也是通過im2col或winograd等轉化為矩陣運算的。
  • RNN不能並行,所以通常它處理的資料有“時序”這個特點。既然是“時序”,那麼就不是同一個時刻完成的,所以不能並行化。

綜上所述: attention優點 = CNN並行+RNN全域性資訊+對輸入尺寸(時序長度維度上)沒有限制。

如果你能創造一個擁有上面三點優點的東西出來,你也可以引領潮流。

然後回到程式碼,再熟悉這麼幾個設定:

  • batch維度:大家利用同樣的權重和操作提取特徵,可以理解為for迴圈式,相互之間沒有資訊互動;
  • multi head維度:同batch類似,不過是利用的不同權重和相同操作提取特徵,最後concate一起使用;
  • FC層:是作用在每一個特徵上,類似CNN中的1X1,可以叫“pointwise”,和序列長度沒有關係;因為序列中所有的特徵經過的是同一個FC。

下面看這個圖,看完不懂的可以扇自己了:

attention的順序是:

  1. 你有長度為n(序列)的序列,每個元素都是一個特徵,每個特徵都是一個向量;
  2. 每個向量都經過FC1,FC2,FC3獲取到q,k,v三個向量(長度自己定),記住,不同特徵用的是同一個FC1,FC2,FC3。可以說對於一個head,就一組FC1,FC2,FC3。
  3. 特徵1的q1和所有特徵的k 進行點乘,獲取一串值,注意:和自己的k也進行點乘;點乘向量變標量,表示相似性。多個K可不就是一串標量。
  4. 3中的那一串值進行softmax操作,作為權重 對所有v加權求和,獲得特徵1輸出;
  5. 其他所有的特徵和特徵1的操作一樣,注意所有特徵是一塊平行計算的;
  6. 最後獲取的和輸入一樣長度的特徵序列再經過FC進行長度(特徵維度)調整,也可以不要;

對了,softmax之前不要忘記 除以 qkv長度開方進行scaled,其實就是標準化操作(我覺得可以理解為各種N(BN,GN,LN等))。

就是這麼簡單,你學會了嗎?