1. 程式人生 > >基於numpy.einsum的張量網路計算

基於numpy.einsum的張量網路計算

# 張量與張量網路 張量(Tensor)可以理解為廣義的矩陣,其主要特點在於將數字化的矩陣用圖形化的方式來表示,這就使得我們可以將一個大型的矩陣運算抽象化成一個具有良好性質的張量圖。由一個個張量所共同構成的運算網路圖,就稱為張量網路(Tensor Network)。讓我們用幾個常用的圖來看看張量網路大概長什麼樣子(下圖轉載自*參考連結1*): ![](https://img2020.cnblogs.com/blog/2277440/202101/2277440-20210118182350037-1230304669.png) 上面這個圖從左到右分別表示:一階張量、二階張量以及三階張量,我們可以看出,一個張量的階數在影象化的表示中被抽象稱為了張量的`腿`的數量,而中間的方形或者圓形則表示張量本身。實際上,一階張量代表的一個向量,比如我們平時用python所定義的一個數組變數: ```python x = [1, 0] y = [0, 1, 0] z = [1, 2, 3, 4] ``` 那麼這裡的`x,y,z`都是一階的張量。而二階張量所表示的含義是一個二維的矩陣,如我們常見的python多維陣列: ```python M = [[1, -1], [-1, 1]] N = [[1, 3], [2, 4], [5, 6]] ``` 這裡定義的`M, N`都是二階的張量。通過觀察這些示例中的一階和二階的張量我們可以得到一個規律:能夠用形如`var[i]`的形式讀取和遍歷`var`中的標量元素的就可以稱之為一階張量,能夠用形如`var[i][j]`的形式讀取和遍歷`var`中的標量元素的可以稱之為二階張量。顯然,屬於幾階的張量,跟張量內所包含的元素個數是無關的。那麼根據這個客觀規律,我們可以再推廣到零階張量和更加高階的張量: ```python pi = 3.14 P = [[[1]]] Q = [[[1, 1, 1], [1, 1, 1], [1, 1, 1]]] ``` 在上述的python變數定義中,`pi`就是一個零階的張量,零階張量實際上就等同於一個標量,而`P, Q`都是三階的張量。需要注意的是,雖然張量`P`只有一個元素,但是如果我們需要讀取這個標量元素,我們必須使用如下的python指令來執行: ```python print (P[0][0][0]) ``` 因此`P`也是一個有`三條腿`的張量。在使用張量的形式來表示單個矩陣的同時,我們需要考慮如果有多個矩陣的乘法運算,我們該如何表示?我們先以兩種形式的python矩陣運算來說明張量計算的表示方法: ```python import numpy as np M = np.random.rand(2, 2) v = np.random.rand(2) w = np.dot(M, v) print (M) print (v) print (w) ``` 這一串python程式碼表示的計算過程為:$w_{2\times1}=M_{2\times2}\cdot v_{2\times1}$,為了不失廣泛有效性,這裡使用隨機的張量來進行計算,這裡的`M`表示二階張量,`v,w`表示一階張量。如果從矩陣運算的角度來理解,實際上就是一個$2\times2$的矩陣乘以一個$2\times1$的向量,並且得到了一個新的$2\times1$的向量。計算所得到的結果如下所示: ```bash [[0.09660039 0.55849787] [0.93007524 0.32329559]] [0.74966152 0.59573188] [0.40513259 0.88983912] ``` 同時我們也考慮下另外一種張量運算的場景,一個高階的張量與另外一個高階的張量進行運算: ```python import numpy as np A = np.random.rand(1, 2, 2, 2) B = np.random.rand(2, 2, 2) C = np.einsum('ijkl,klm->ijm', A, B) print ('A:', A) print ('B:', B) print ('C:', C) ``` 這一串python程式碼表示的計算過程為:$C_{1\times2\times2}=A_{1\times2\times2\times2}\cdot B_{2\times2\times2}$,由於這裡的多維張量運算已經不能使用普通的`numpy.dot`來處理,因此我們還是適用了專業的張量計算函式`numpy.einsum`來進行處理。 以上的兩個案例,從張量理論的角度來理解,相當於分別將張量`w`和張量`C`表示成了多個張量組合運算的結果。由多個張量構成的組合運算,我們可以使用張量網路來表示: ![](https://img2020.cnblogs.com/blog/2277440/202101/2277440-20210118182407942-688164475.png) 上圖所示的$(a)$和$(b)$就分別表示張量`w`和張量`C`的張量網路圖。而這個將張量網路的所有張量進行計算,最終得到一個或一系列的新的張量的矩陣乘加過程,我們也稱之為`張量縮並`,英文叫Tensor Contraction,注:上圖轉載自*參考連結1*。 # 張量縮並順序與計算複雜性 不失廣泛有效性的,我們可以以兩個張量的縮併案例來分析張量縮並的複雜性,兩個張量縮並的計算複雜性主要取決於這兩個張量總的`腿`的數量,如果兩個張量之間有共用的`腿`,則計為1。以上圖中的$(a)$為例,一個$2\times2$的矩陣乘以一個$2\times1$的向量,一共需要4次乘法運算,而由`M`和`v`所構成的張量網路一共有`2條腿`,那麼4次的乘法預算符合$O(d^2)$的計算複雜性,這裡的d指的是指定的`腿`的維度,常用的是2。相關的複雜性除了理論推導,用`numpy.einsum`的功能模組也可以實現程式判斷: ```python import numpy as np M = np.random.rand(2, 2) v = np.random.rand(2) path_info = np.einsum_path('ij,j->i', M, v, optimize='greedy') print(path_info[0]) print(path_info[1]) ``` 輸出結果如下: ```bash ['einsum_path', (0, 1)] Complete contraction: ij,j->i Naive scaling: 2 Optimized scaling: 2 Naive FLOP count: 8.000e+00 Optimized FLOP count: 9.000e+00 Theoretical speedup: 0.889 Largest intermediate: 2.000e+00 elements -------------------------------------------------------------------------- scaling current remaining -------------------------------------------------------------------------- 2 j,ij->i i->i ``` 這裡的scaling就是上面提到的複雜性$O(d^2)$中的$2$。同樣的如果以上圖中的$(b)$為例,我們可以通過理論推匯出其計算複雜性為$O(d^5)$,即理論的scaling應該是5,下面也通過程式實現來給出定論: ```python import numpy as np A = np.random.rand(1, 2, 2, 2) B = np.random.rand(2, 2, 2) path_info = np.einsum_path('ijkl,klm->ijm', A, B, optimize='greedy') print(path_info[0]) print(path_info[1]) ``` 以上程式的執行結果如下: ```bash ['einsum_path', (0, 1)] Complete contraction: ijkl,klm->ijm Naive scaling: 5 Optimized scaling: 5 Naive FLOP count: 3.200e+01 Optimized FLOP count: 3.300e+01 Theoretical speedup: 0.970 Largest intermediate: 4.000e+00 elements -------------------------------------------------------------------------- scaling current remaining -------------------------------------------------------------------------- 5 klm,ijkl->ijm ijm->ijm ``` 這裡需要我們注意的一點是,如果有兩條邊同時連線,那麼計算scaling的時候也是作為兩條邊來計算的,而不是合併為一條邊之後再計算scaling。 由於上面所提到的兩個例子,其實都只涉及到兩個張量之間的預算,當多個張量一同進行運算時,就會引入一個新的參量:`縮並順序`,在張量網路的實際應用場景中,縮並順序會極大程度上的影響張量網路計算的速度。首先,讓我們用一個例子來分析,為什麼不同的縮並順序會對張量網路計算的效能產生影響:給定四個張量為: $a_{ijk},b_{jlmn},c_{klo}和d_{mo}$ 。如果先縮並 bc ,則對應的計算複雜度的scaling為`6`。若按照縮並順序:cd->c,bc->b,ab->a,對應的計算複雜度scaling為`5` 。也就是說,從複雜度的角度來說,這裡選出了一條複雜度較低的縮並路線,這一條複雜性scaling較好的縮並順序也是由`numpy.einsum`的貪心演算法找出來的: ```python import numpy as np np.random.seed(123) a = np.random.rand(2, 2, 2) b = np.random.rand(2, 2, 2, 2) c = np.random.rand(2, 2, 2) d = np.random.rand(2, 2) path_info = np.einsum_path('ijk,jlmn,klo,mo->in', a, b, c, d, optimize='greedy') print(path_info[0]) print(path_info[1]) ``` 執行的結果如下所示: ```bash ['einsum_path', (2, 3), (1, 2), (0, 1)] Complete contraction: ijk,jlmn,klo,mo->in Naive scaling: 7 Optimized scaling: 5 Naive FLOP count: 5.120e+02 Optimized FLOP count: 1.290e+02 Theoretical speedup: 3.969 Largest intermediate: 8.000e+00 elements -------------------------------------------------------------------------- scaling current remaining -------------------------------------------------------------------------- 4 mo,klo->klm ijk,jlmn,klm->in 5 klm,jlmn->jkn ijk,jkn->in 4 jkn,ijk->in in->in ``` # 張量分割對張量網路縮並複雜性的影響 在前面的章節中我們討論了將一個張量網路縮併為一個張量的場景下,如何降低其複雜性scaling。其中重點說明了,在特定的縮並順序下,可以極大程度上的優化張量縮並的效能。這裡我們討論一種在量子計算中常用的技巧:`張量的分割`。簡單的來說,就是前面提到的張量縮並的逆向過程,既然可以將兩個張量縮併成一個,那就有可能將一個張量分割成兩個張量。 那麼為什麼需要執行張量分割的操作呢?我們可以直接通過一個案例來說明: ```python import numpy as np np.random.seed(123) a = np.random.rand(2) b = np.random.rand(2) c = np.random.rand(2, 2, 2, 2) d = np.random.rand(2) e = np.random.rand(2) path_info = np.einsum_path('i,j,ijkl,k,l', a, b, c, d, e, optimize='greedy') print(path_info[0]) print(path_info[1]) ``` 這裡給定了5個張量,其中張量`c`有四條腿,那麼在這個場景下不論以什麼樣的順序進行縮並,得到的複雜性的scaling都必然是4,以下是`numpy.einsum`給出的結果: ```bash ['einsum_path', (0, 2), (0, 3), (0, 2), (0, 1)] Complete contraction: i,j,ijkl,k,l-> Naive scaling: 4 Optimized scaling: 4 Naive FLOP count: 8.000e+01 Optimized FLOP count: 6.100e+01 Theoretical speedup: 1.311 Largest intermediate: 8.000e+00 elements -------------------------------------------------------------------------- scaling current remaining -------------------------------------------------------------------------- 4 ijkl,i->jkl j,k,l,jkl-> 3 jkl,j->kl k,l,kl-> 2 kl,k->l l,l-> 1 l,l-> -> ``` 但是,如果我們考慮先將這個腿最多的張量做一個分割,使其變為兩個三條腿的張量,並且這兩個張量之間通過一條邊進行連線,程式碼示例如下: ```python import numpy as np np.random.seed(123) a = np.random.rand(2) b = np.random.rand(2) c = np.random.rand(2, 2, 2) d = np.random.rand(2, 2, 2) e = np.random.rand(2) f = np.random.rand(2) path_info = np.einsum_path('i,j,imk,jml,k,l', a, b, c, d, e, f, optimize='greedy') print(path_info[0]) print(path_info[1]) ``` 讓我們先看看`numpy.einsum`是否會給出一樣的縮並順序呢? ```bash ['einsum_path', (0, 2), (0, 1), (0, 2), (0, 1), (0, 1)] Complete contraction: i,j,imk,jml,k,l-> Naive scaling: 5 Optimized scaling: 3 Naive FLOP count: 1.920e+02 Optimized FLOP count: 5.300e+01 Theoretical speedup: 3.623 Largest intermediate: 4.000e+00 elements -------------------------------------------------------------------------- scaling current remaining -------------------------------------------------------------------------- 3 imk,i->km j,jml,k,l,km-> 3 jml,j->lm k,l,km,lm-> 2 km,k->m l,lm,m-> 2 lm,l->m m,m-> 1 m,m-> -> ``` 我們驚訝的發現,這個給定的scaling較低的縮並順序並沒有一開始就縮並`m`這條邊,如果先縮並了`m`這條邊,那麼得到的結果應該跟上面未作分割的順序和scaling是一樣的。另言之,我們通過這種張量切割的方案,實際上大大降低了這個張量網路的縮並所需時間。這裡的複雜性scaling每降低1,就意味著需要執行的乘加次數有可能減少到優化前的$\frac{1}{d}$. # 補充測試 針對於上述章節提到的張量分割的方案,我們這裡再多一組更加複雜一些的張量網路的測試: ```python import networkx as nx graph = nx.Graph() graph.add_nodes_from([1,2,3,4,5,6,7,8,9]) graph.add_edges_from([(1,4),(2,4),(3,5),(4,5),(4,6),(5,6),(6,7),(5,8),(6,9)]) nx.draw_networkx(graph) ``` ![](https://img2020.cnblogs.com/blog/2277440/202101/2277440-20210118231632649-1417785826.png) 考慮上圖這樣的一個張量網路,我們也先將其中三個四條腿的張量進行分割,分割後的張量網路變為如下所示的拓撲結構: ```python import networkx as nx graph = nx.Graph() graph.add_nodes_from([1,2,3,4,5,6,7,8,9,10,11,12]) graph.add_edges_from([(1,4),(5,4),(2,5),(4,5),(4,8),(5,6),(6,7),(3,7),(7,9),(6,11),(8,10),(8,9),(9,12)]) nx.draw_networkx(graph) ``` ![](https://img2020.cnblogs.com/blog/2277440/202101/2277440-20210118231759665-15211793.png) 然後再次使用`numpy.einsum`來進行驗證。首先是張量分割前的張量網路縮並: ```python import numpy as np np.random.seed(123) a = np.random.rand(2) b = np.random.rand(2) c = np.random.rand(2) d = np.random.rand(2, 2, 2, 2) e = np.random.rand(2, 2, 2, 2) f = np.random.rand(2, 2, 2, 2) g = np.random.rand(2) h = np.random.rand(2) i = np.random.rand(2) path_info = np.einsum_path('i,j,k,ijlm,mnko,lpoq,p,n,q', a, b, c, d, e, f, g, h, i, optimize='greedy') print(path_info[0]) print(path_info[1]) ``` 執行結果如下: ```bash ['einsum_path', (0, 3), (1, 2), (1, 2), (0, 3), (0, 2), (0, 1), (0, 1), (0, 1)] Complete contraction: i,j,k,ijlm,mnko,lpoq,p,n,q-> Naive scaling: 9 Optimized scaling: 4 Naive FLOP count: 4.608e+03 Optimized FLOP count: 1.690e+02 Theoretical speedup: 27.266 Largest intermediate: 8.000e+00 elements -------------------------------------------------------------------------- scaling current remaining -------------------------------------------------------------------------- 4 ijlm,i->jlm j,k,mnko,lpoq,p,n,q,jlm-> 4 mnko,k->mno j,lpoq,p,n,q,jlm,mno-> 4 p,lpoq->loq j,n,q,jlm,mno,loq-> 3 jlm,j->lm n,q,mno,loq,lm-> 3 mno,n->mo q,loq,lm,mo-> 3 loq,q->lo lm,mo,lo-> 3 mo,lm->lo lo,lo-> 2 lo,lo-> -> ``` 我們可以看到未進行張量分割前的複雜性scaling為4,再讓我們看下張量分割之後的張量網路縮並情況: ```python import numpy as np np.random.seed(123) a = np.random.rand(2) b = np.random.rand(2) c = np.random.rand(2) d = np.random.rand(2, 2, 2) e = np.random.rand(2, 2, 2) f = np.random.rand(2, 2, 2) g = np.random.rand(2, 2, 2) h = np.random.rand(2, 2, 2) i = np.random.rand(2, 2, 2) j = np.random.rand(2) k = np.random.rand(2) l = np.random.rand(2) path_info = np.einsum_path('i,j,k,iml,jmn,nop,kpq,lrs,sqt,r,o,t', a, b, c, d, e, f, g, h, i, j, k, l, optimize='greedy') print(path_info[0]) print(path_info[1]) ``` 執行結果如下: ```bash ['einsum_path', (0, 3), (0, 2), (0, 2), (0, 4), (0, 2), (0, 1), (0, 1), (0, 1), (0, 1), (0, 1), (0, 1)] Complete contraction: i,j,k,iml,jmn,nop,kpq,lrs,sqt,r,o,t-> Naive scaling: 12 Optimized scaling: 3 Naive FLOP count: 4.915e+04 Optimized FLOP count: 1.690e+02 Theoretical speedup: 290.840 Largest intermediate: 4.000e+00 elements -------------------------------------------------------------------------- scaling current remaining -------------------------------------------------------------------------- 3 iml,i->lm j,k,jmn,nop,kpq,lrs,sqt,r,o,t,lm-> 3 jmn,j->mn k,nop,kpq,lrs,sqt,r,o,t,lm,mn-> 3 kpq,k->pq nop,lrs,sqt,r,o,t,lm,mn,pq-> 3 o,nop->np lrs,sqt,r,t,lm,mn,pq,np-> 3 r,lrs->ls sqt,t,lm,mn,pq,np,ls-> 3 t,sqt->qs lm,mn,pq,np,ls,qs-> 3 mn,lm->ln pq,np,ls,qs,ln-> 3 np,pq->nq ls,qs,ln,nq-> 3 qs,ls->lq ln,nq,lq-> 3 nq,ln->lq lq,lq-> 2 lq,lq-> -> ``` 我們再次發現,張量縮並的複雜性scaling被優化到了3。假如是我們常見的$d=2$的張量網路,那麼在進行張量分割之後,類似於上面這個案例的,張量縮並的時間可以加速1倍甚至更多。 # 總結概要 本文主要介紹了張量網路的基本定義及其縮並複雜性scaling的含義,其中利用`numpy.einsum`這個高階輪子進行了用例的演示,並且額外的介紹了張量分割在張量網路縮並實際應用場景中的重要地位。通常我們會配合GPU來進行張量網路的縮並,那麼這個時候縮並複雜性的scaling影響的就不僅僅是縮並的速度,因為GPU本身的記憶體是比較侷限的,而不斷的IO會進一步拉長張量網路縮並所需要的時間。而如果能夠有方案將一個給定的張量網路的複雜性scaling降低到GPU自身記憶體可以儲存的水平,那將極大程度上的降低使用張量網路技術求解實際問題的時間。 # 參考連結 1. 什麼是張量網路(tensor network)? - 何史提的回答 - 知乎 https://www.zhihu.com/question/54786880/answer/147099121 2. Michael Streif1, Martin Leib,"*Training the Quantum Approximate Optimization Algorithm without access to a Quantum Processing Unit*", 2019, arXiv:19