Transformer 中的 attention
阿新 • • 發佈:2022-05-08
Transformer 中的 attention
大火的transformer 本質就是:
*使用attention機制的seq2seq。*
所以它的核心就是attention機制,今天就講attention。直奔程式碼VIT-pytorch:
https://github.com/lucidrains/vit-pytorch/blob/main/vit_pytorch/vit.py
中的
class Attention(nn.Module): def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.): super().__init__() inner_dim = dim_head * heads project_out = not (heads == 1 and dim_head == dim) self.heads = heads self.scale = dim_head ** -0.5 self.attend = nn.Softmax(dim = -1) self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False) self.to_out = nn.Sequential( nn.Linear(inner_dim, dim), nn.Dropout(dropout) ) if project_out else nn.Identity() def forward(self, x): qkv = self.to_qkv(x).chunk(3, dim = -1) q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv) dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale attn = self.attend(dots) out = torch.matmul(attn, v) out = rearrange(out, 'b h n d -> b n (h d)') return self.to_out(out)
看吧!就是這麼簡單。今天就徹底搞懂這個東西。
先記住attention的這麼幾個點:
- attention和CNN、RNN、FC、GCN等都是一個級別的東西,用來提取特徵;既然是特徵提取,一定有權重(W+B)存在。
- attention的優點:可以像CNN一樣並行運算 + 像RNN一樣通過一層就擁有全域性資訊。有一個東西也可以做到,那就是FC,但是FC有個弱點:對輸入尺寸有限制,說白了不好適應可變輸入資料,這對於序列無疑是非常不友好的。
- pooling也可以實現,但是它是無參的過程。例如點雲資料,就可以用pooling來處理,當然也有一些網路是pooling is all your need。
- 可以像CNN一樣並行運算 ,其實CNN運算也是通過im2col或winograd等轉化為矩陣運算的。
- RNN不能並行,所以通常它處理的資料有“時序”這個特點。既然是“時序”,那麼就不是同一個時刻完成的,所以不能並行化。
綜上所述: attention優點 = CNN並行+RNN全域性資訊+對輸入尺寸(時序長度維度上)沒有限制。
如果你能創造一個擁有上面三點優點的東西出來,你也可以引領潮流。
然後回到程式碼,再熟悉這麼幾個設定:
- batch維度:大家利用同樣的權重和操作提取特徵,可以理解為for迴圈式,相互之間沒有資訊互動;
- multi head維度:同batch類似,不過是利用的不同權重和相同操作提取特徵,最後concate一起使用;
- FC層:是作用在每一個特徵上,類似CNN中的1X1,可以叫“pointwise”,和序列長度沒有關係;因為序列中所有的特徵經過的是同一個FC。
下面看這個圖,看完不懂的可以扇自己了:
attention的順序是:
- 你有長度為n(序列)的序列,每個元素都是一個特徵,每個特徵都是一個向量;
- 每個向量都經過FC1,FC2,FC3獲取到q,k,v三個向量(長度自己定),記住,不同特徵用的是同一個FC1,FC2,FC3。可以說對於一個head,就一組FC1,FC2,FC3。
- 特徵1的q1和所有特徵的k 進行點乘,獲取一串值,注意:和自己的k也進行點乘;點乘向量變標量,表示相似性。多個K可不就是一串標量。
- 3中的那一串值進行softmax操作,作為權重 對所有v加權求和,獲得特徵1輸出;
- 其他所有的特徵和特徵1的操作一樣,注意所有特徵是一塊平行計算的;
- 最後獲取的和輸入一樣長度的特徵序列再經過FC進行長度(特徵維度)調整,也可以不要;
對了,softmax之前不要忘記 除以 qkv長度開方進行scaled,其實就是標準化操作(我覺得可以理解為各種N(BN,GN,LN等))。
就是這麼簡單,你學會了嗎?