einsum函式介紹-張量常用操作
pytorch文件說明:\(torch.einsum(\)\(*equation*\)$, $$operands*$$)$ 使用基於愛因斯坦求和約定的符號,將輸入operands的元素沿指定的維數求和。einsum允許計算許多常見的多維線性代數陣列運算,方法是基於愛因斯坦求和約定以簡寫格式表示它們。主要是省略了求和號,總體思路是在箭頭左邊用一些下標標記輸入operands的每個維度,並在箭頭右邊定義哪些下標是輸出的一部分。通過將operands元素與下標不屬於輸出的維度的乘積求和來計算輸出。其方便之處在於可以直接通過求和公式寫出運算程式碼。**
兩個基本概念,自由索引(Free indices
- 自由索引,出現在箭頭右邊的索引
- 求和索引,只出現在箭頭左邊的索引,表示中間計算結果需要這個維度上求和之後才能得到輸出,
單運算元
獲取對角線元素diagonal
einsum 可以不做求和。舉個例子,獲取二維方陣的對角線元素,結果放入一維向量。
\[A_i = B_{ii} \]上面,A 是一維向量,B 是二維方陣。使用 einsum 記法,可以寫作 ii->i
torch.einsum('ii->i', torch.randn(4, 4)) # 以下操作互相等價 a = torch.randn(4,4) c = torch.einsum('ii->i', a) c = torch.diagonal(a, 0)
跡trace
求解矩陣的跡(trace),即對角線元素的和。
\[t = \Sigma_{i=1}^{n} A_{ii} \]t 是常量,A 是二維方陣。按照前面的做法,省略 ΣΣ,左右兩邊對調,省去矩陣和 t,剩下的就是ii->
或省略箭頭ii
torch.einsum('ii', torch.randn(4, 4))
矩陣轉置
\[A_{ij} = B_{ji} \]A 和 B 都是二維方陣。einsum 可以表達為 ij->ji
。
torch.einsum('ij -> ji',a)
pytorch 中,還支援省略前面的維度。比如,只轉置最後兩個維度,可以表達為 ...ij->...ji
A = torch.randn(2, 3, 4, 5)
torch.einsum('...ij->...ji', A).shape
# torch.Size([2, 3, 5, 4])
# 等價操作
A.permute(0,1,3,2)
A.transpose(2,3)
求和
\[b=\sum_{i} \sum_{j} A_{i j}=A_{i j} \]a = torch.arange(6).reshape(2, 3)
torch.einsum('ij->', [a])
tensor(15.)
列求和:
\[b_{j}=\sum_{i} A_{i j}=A_{i j} \]a = torch.arange(6).reshape(2, 3)
torch.einsum('ij->j', [a])
tensor([ 3., 5., 7.])
# 等價操作
torch.sum(a, 0) # (dim引數0) means the dimension or dimensions to reduce.
雙運算元
矩陣乘法
\[A_{ij} = \Sigma_{k=1}^{n} B_{ik} C_{kj} \]第一個學習的 einsum 表示式是,ik,kj->ij
。前面提到過,愛因斯坦求和記法可以理解為懶人求和記法。將上述公式中的 ΣΣ 去掉,並且將左右兩邊對調一下,省去矩陣之後,剩下的就是 ik,kj->ij
了。
torch.einsum('ik,kj->ij', a, b)
# 可用兩個矩陣測試以下矩陣乘法操作互相等價
a = torch.randn(2,3)
b = torch.randn(3,4)
c = torch.matmul(a,b)
c = torch.einsum('ik,kj->ij', a, b)
c = a.mm(b)
c = torch.mm(a, b)
c = a @ b
矩陣-向量相乘
\[c_{i}=\sum_{k} A_{i k} b_{k}=A_{i k} b_{k} \]a = torch.arange(6).reshape(2, 3)
b = torch.arange(3)
torch.einsum('ik,k->i', [a, b])
tensor([ 5., 14.])
批量矩陣乘 batch matrix multiplication
\[C_{bik}=\sum_{k} A_{bij} B_{bjk}=A_{bij} B_{bjk} \]>>> As = torch.randn(3,2,5)
>>> Bs = torch.randn(3,5,4)
>>> torch.einsum('bij,bjk->bik', As, Bs)
tensor([[[-1.0564, -1.5904, 3.2023, 3.1271],
[-1.6706, -0.8097, -0.8025, -2.1183]],
[[ 4.2239, 0.3107, -0.5756, -0.2354],
[-1.4558, -0.3460, 1.5087, -0.8530]],
[[ 2.8153, 1.8787, -4.3839, -1.2112],
[ 0.3728, -2.1131, 0.0921, 0.8305]]])
# 等價操作
torch.bmm(As, Bs)
向量內積 dot
\[c=\sum_{i} a_{i} b_{i}=a_{i} b_{i} \]a = torch.arange(3)
b = torch.arange(3,6) # [3, 4, 5]
torch.einsum('i,i->', [a, b])
# tensor(14.)
# 等價操作
torch.dot(a, b)
矩陣內積 dot
\[c=\sum_{i} \sum_{j} A_{i j} B_{i j}=A_{i j} B_{i j} \]a = torch.arange(6).reshape(2, 3)
b = torch.arange(6,12).reshape(2, 3)
torch.einsum('ij,ij->', [a, b])
tensor(145.)
哈達瑪積
\[C_{i j}=A_{i j} B_{i j} \]a = torch.arange(6).reshape(2, 3)
b = torch.arange(6,12).reshape(2, 3)
torch.einsum('ij,ij->ij', [a, b])
tensor([[ 0., 7., 16.],
[ 27., 40., 55.]])
外積 outer
\[C_{i j}=a_{i} b_{j} \]a = torch.arange(3)
b = torch.arange(3,7)
torch.einsum('i,j->ij', [a, b])
tensor([[ 0., 0., 0., 0.],
[ 3., 4., 5., 6.],
[ 6., 8., 10., 12.]])
einsum規則總結:
- 表示式由輸入和輸出兩部分組成。例子,
ij->ji
- 輸出可以省略,箭頭也可以省略。輸入中僅出現一次的字元將按照字母序構成輸出。例子,
ba
完整的表示式是ba->ab
- 輸入中多次出現的字元,將被用作求和。例子,
kj,ji
完整的表示式是kj,ji->ik
,矩陣乘法再相乘。 - 輸出可以指定,但是輸出中的每個字元必須在輸入中出現至少一次,輸出的每個字元在輸出中只能出現最多一次。例子,
ab->aa
是非法的,ab->c
是非法的,ab->a
是合法的。 - 省略符
...
是用來跳過部分維度。例子,...ij,...jk
表示 batch 矩陣乘法。 - 在輸出沒有指定的情況下,省略符優先順序高於普通字元。例子,
b...a
完整的表示式是b...a->...ab
,可以將一個形狀為(a,b,c)
的矩陣變為形狀為(b,c,a)
的矩陣。 - 允許多個矩陣輸入,表示式中使用逗號分開不同矩陣輸入的下標。例子,
i,i,i
表示將三個一維向量按位相乘,並相加。 - 除了箭頭,其他任何地方都可以加空格。例子,
i j , j k -> ik
是合法的,ij,jk - > ik
是非法的。 - 輸入的表示式,維度需要和輸入的矩陣對上,不能多也不能少。比如一個 shape 為
(4,3,3)
的矩陣,表示式ab->a
是非法的,abc->
是合法的。
實際使用
實現multi headed attention
https://nn.labml.ai/transformers/mha.html
如何優雅地實現多頭自注意力
計算注意力score:
\[Q K^{\top} or S_{i j b h}=\sum_{d} Q_{i b h d} K_{j b h d} \]# q k v均為 [seq_len, batch_size, heads, d_k]
torch.einsum('ibhd,jbhd->ijbh', query, key) # 理解為ibhd,jbhd->ibhj->ijbh
計算attention輸出:
\[\underset{\text { seq }}{\operatorname{softmax}}\left(\frac{Q K^{\top}}{\sqrt{d_{k}}}\right) V \]# attn [seq_len, seq_len, batch_size, heads]
# value [seq_len, batch_size, heads, d_k]
x = torch.einsum("ijbh,jbhd->ibhd", attn, value)
# x [seq_len, batch_size, heads, d_k]