1. 程式人生 > 其它 >einsum函式介紹-張量常用操作

einsum函式介紹-張量常用操作

pytorch文件說明:\(torch.einsum(\)\(*equation*\)$, $$operands*$$)$ 使用基於愛因斯坦求和約定的符號,將輸入operands的元素沿指定的維數求和。einsum允許計算許多常見的多維線性代數陣列運算,方法是基於愛因斯坦求和約定以簡寫格式表示它們。主要是省略了求和號,總體思路是在箭頭左邊用一些下標標記輸入operands的每個維度,並在箭頭右邊定義哪些下標是輸出的一部分。通過將operands元素與下標不屬於輸出的維度的乘積求和來計算輸出。其方便之處在於可以直接通過求和公式寫出運算程式碼。**

兩個基本概念,自由索引(Free indices

)和求和索引(Summation indices):

  • 自由索引,出現在箭頭右邊的索引
  • 求和索引,只出現在箭頭左邊的索引,表示中間計算結果需要這個維度上求和之後才能得到輸出,

單運算元

獲取對角線元素diagonal

einsum 可以不做求和。舉個例子,獲取二維方陣的對角線元素,結果放入一維向量。

\[A_i = B_{ii} \]

上面,A 是一維向量,B 是二維方陣。使用 einsum 記法,可以寫作 ii->i

torch.einsum('ii->i', torch.randn(4, 4))

# 以下操作互相等價
a = torch.randn(4,4)
c = torch.einsum('ii->i', a)
c = torch.diagonal(a, 0)


跡trace

求解矩陣的跡(trace),即對角線元素的和。

\[t = \Sigma_{i=1}^{n} A_{ii} \]

t 是常量,A 是二維方陣。按照前面的做法,省略 ΣΣ,左右兩邊對調,省去矩陣和 t,剩下的就是ii->或省略箭頭ii

torch.einsum('ii', torch.randn(4, 4))

矩陣轉置

\[A_{ij} = B_{ji} \]

A 和 B 都是二維方陣。einsum 可以表達為 ij->ji

torch.einsum('ij -> ji',a)

pytorch 中,還支援省略前面的維度。比如,只轉置最後兩個維度,可以表達為 ...ij->...ji

。下面展示了一個含有四個二維矩陣的三維矩陣,轉置三維矩陣中的每個二維矩陣。

A = torch.randn(2, 3, 4, 5)
torch.einsum('...ij->...ji', A).shape
# torch.Size([2, 3, 5, 4])

# 等價操作
A.permute(0,1,3,2)
A.transpose(2,3)

求和

\[b=\sum_{i} \sum_{j} A_{i j}=A_{i j} \]
a = torch.arange(6).reshape(2, 3)
torch.einsum('ij->', [a])
tensor(15.)

列求和:

\[b_{j}=\sum_{i} A_{i j}=A_{i j} \]
a = torch.arange(6).reshape(2, 3)
torch.einsum('ij->j', [a])
tensor([ 3.,  5.,  7.])

# 等價操作
torch.sum(a, 0) # (dim引數0) means the dimension or dimensions to reduce.

雙運算元

矩陣乘法

\[A_{ij} = \Sigma_{k=1}^{n} B_{ik} C_{kj} \]

第一個學習的 einsum 表示式是,ik,kj->ij。前面提到過,愛因斯坦求和記法可以理解為懶人求和記法。將上述公式中的 ΣΣ 去掉,並且將左右兩邊對調一下,省去矩陣之後,剩下的就是 ik,kj->ij 了。

torch.einsum('ik,kj->ij', a, b) 

# 可用兩個矩陣測試以下矩陣乘法操作互相等價
a = torch.randn(2,3)
b = torch.randn(3,4)
c = torch.matmul(a,b)
c = torch.einsum('ik,kj->ij', a, b)
c = a.mm(b) 
c = torch.mm(a, b) 
c = a @ b

矩陣-向量相乘

\[c_{i}=\sum_{k} A_{i k} b_{k}=A_{i k} b_{k} \]
a = torch.arange(6).reshape(2, 3)
b = torch.arange(3)
torch.einsum('ik,k->i', [a, b])

tensor([  5.,  14.])

批量矩陣乘 batch matrix multiplication

\[C_{bik}=\sum_{k} A_{bij} B_{bjk}=A_{bij} B_{bjk} \]
>>> As = torch.randn(3,2,5)
>>> Bs = torch.randn(3,5,4)
>>> torch.einsum('bij,bjk->bik', As, Bs)
tensor([[[-1.0564, -1.5904,  3.2023,  3.1271],
        [-1.6706, -0.8097, -0.8025, -2.1183]],

        [[ 4.2239,  0.3107, -0.5756, -0.2354],
        [-1.4558, -0.3460,  1.5087, -0.8530]],

        [[ 2.8153,  1.8787, -4.3839, -1.2112],
        [ 0.3728, -2.1131,  0.0921,  0.8305]]])

# 等價操作
torch.bmm(As, Bs)

向量內積 dot

\[c=\sum_{i} a_{i} b_{i}=a_{i} b_{i} \]
a = torch.arange(3)
b = torch.arange(3,6)  # [3, 4, 5]
torch.einsum('i,i->', [a, b])
# tensor(14.)

# 等價操作
torch.dot(a, b)

矩陣內積 dot

\[c=\sum_{i} \sum_{j} A_{i j} B_{i j}=A_{i j} B_{i j} \]
a = torch.arange(6).reshape(2, 3)
b = torch.arange(6,12).reshape(2, 3)
torch.einsum('ij,ij->', [a, b])
tensor(145.)

哈達瑪積

\[C_{i j}=A_{i j} B_{i j} \]
a = torch.arange(6).reshape(2, 3)
b = torch.arange(6,12).reshape(2, 3)
torch.einsum('ij,ij->ij', [a, b])
tensor([[  0.,   7.,  16.],
        [ 27.,  40.,  55.]])

外積 outer

\[C_{i j}=a_{i} b_{j} \]
a = torch.arange(3)
b = torch.arange(3,7)
torch.einsum('i,j->ij', [a, b])

tensor([[  0.,   0.,   0.,   0.],
        [  3.,   4.,   5.,   6.],
        [  6.,   8.,  10.,  12.]])

einsum規則總結:

  • 表示式由輸入和輸出兩部分組成。例子,ij->ji
  • 輸出可以省略,箭頭也可以省略。輸入中僅出現一次的字元將按照字母序構成輸出。例子,ba 完整的表示式是 ba->ab
  • 輸入中多次出現的字元,將被用作求和。例子,kj,ji 完整的表示式是 kj,ji->ik,矩陣乘法再相乘。
  • 輸出可以指定,但是輸出中的每個字元必須在輸入中出現至少一次,輸出的每個字元在輸出中只能出現最多一次。例子,ab->aa 是非法的,ab->c 是非法的,ab->a 是合法的。
  • 省略符 ... 是用來跳過部分維度。例子,...ij,...jk 表示 batch 矩陣乘法。
  • 在輸出沒有指定的情況下,省略符優先順序高於普通字元。例子,b...a 完整的表示式是 b...a->...ab,可以將一個形狀為 (a,b,c) 的矩陣變為形狀為 (b,c,a) 的矩陣。
  • 允許多個矩陣輸入,表示式中使用逗號分開不同矩陣輸入的下標。例子,i,i,i 表示將三個一維向量按位相乘,並相加。
  • 除了箭頭,其他任何地方都可以加空格。例子,i j , j k -> ik 是合法的,ij,jk - > ik 是非法的。
  • 輸入的表示式,維度需要和輸入的矩陣對上,不能多也不能少。比如一個 shape 為 (4,3,3) 的矩陣,表示式 ab->a 是非法的,abc-> 是合法的。

實際使用

實現multi headed attention

https://nn.labml.ai/transformers/mha.html
如何優雅地實現多頭自注意力

計算注意力score:

\[Q K^{\top} or S_{i j b h}=\sum_{d} Q_{i b h d} K_{j b h d} \]
# q k v均為 [seq_len, batch_size, heads, d_k] 
torch.einsum('ibhd,jbhd->ijbh', query, key) # 理解為ibhd,jbhd->ibhj->ijbh

計算attention輸出:

\[\underset{\text { seq }}{\operatorname{softmax}}\left(\frac{Q K^{\top}}{\sqrt{d_{k}}}\right) V \]
# attn [seq_len, seq_len, batch_size, heads]
# value [seq_len, batch_size, heads, d_k] 

x = torch.einsum("ijbh,jbhd->ibhd", attn, value)
# x [seq_len, batch_size, heads, d_k]