Python、numpy 與 axis
這次和大家分享的是 numpy 中的 axis 這個東西。當初學的時候也沒太在意,向來都是感覺差不多就直接過去了,沒有去深究背後的一些邏輯。前些天被問起的時候一時懵懂,查了下資料後發現還有點意思,於是就打算寫這麼一篇專欄來分享一下所得
要想學習 axis,首先要知道的就是 axis 的計數方式。我們在使用 numpy 的各種函式——比如說 np.sum——的時候,有一個引數就叫做 axis。那麼這個引數的意思是什麼呢?最直白地來說的話,就是“最外面的括號代表著 axis=0,依次往裡的括號對應的 axis 的計數就依次加 1”
舉個例子,現在我們有一個矩陣: ;在 Python,或說在 numpy 裡面,這個矩陣是這樣被表達出來的:x
不管畫風怎麼變,很醜這一點都無法改變啊……
所以相應的運算就是:
對應的程式碼實現和執行結果如下:
可以看到,貌似出來的結果比我們推導的結果的括號要少一些。這是因為諸如 np.sum 這種函式中有一個引數叫 keepdims,它的預設值是 False,此時它會把多餘的括號給刪掉。假如我們把它設為 True 的話,就可以得到和推導中一致的結果了:
下面來看一個更“高維”一點的例子:
對應的程式碼實現和執行結果如下:
以及
可以看到結果和我們推導的確實一樣
現在我們知道哪個 axis 對應於陣列中的哪些元素了,接下來還需要知道的就是 transpose 這個函式到底在背後幹了什麼。從紙面上來看,如果一個高維陣列 x 的 shape 是 (2, 3, 4),那麼 transpose 的作用就是把這個 shape 中各個數的順序改一改。比如說:
但是 transpose 返回的結果究竟是如何得到的,可能就比較難理解了。幸運的是,這個回答-StackOverflow非常好地闡明瞭這背後的原理。為了方便觀眾老爺們,我在這裡就當一個搬運 and 潤色工
首先是對這個 shape 的理解。直觀地說,shape 中的各個數就是對應 axis 的元素個數
字比畫還醜呢……
如果我們換一種思路的話,以 axis=0 為例,由於我們現在整個數組裡面一共有 24 個數,而 axis=0 只有兩個元素,所以可以理解為在 axis=0 這個 axis 上,每隔 24 / 2 = 12 個數就跳一下。比如說上面這個圖中就可以看出,兩個橙色矩陣對應的數之間差的都是 12
類似的,由於一個橙色矩陣中只有 24 / 2 = 12 個數,所以我們可以理解為在 axis=1 這個 axis 上,每隔 12 / 3 = 4 個數就跳一下。表現在圖中,就是同一個橙色矩陣的兩個相鄰的藍色向量對應的數之間差的都是 4
再次類似的,由於一個藍色向量中只有 12 / 3 = 4 個數,我們可以理解為在 axis=2 這個 axis 上,每隔 4 / 4 = 1 個數就跳一下。表現在圖中……觀眾老爺們想必也知道是怎樣的了 ( σ’ω’)σ
所以我們現在可以定義一個新的東西,比如說叫做 strides 吧,它記錄著每個 axis 上跳過的數。比如說上圖對應的三維陣列,它的 strides 就是 (12, 4, 1)
那麼接下來激動人心的時刻到了:transpose 的本質,其實就是對 strides 中各個數的順序進行調換。舉個例子:
在 transpose(1, 0, 2) 後,相應的 strides 會變成 (4, 12, 1)。而從上圖可以看出,transpose 的結果確實滿足:
- axis=0 的 axis 上,每隔 4 個數跳一下
- axis=1 的 axis 上,每隔 12 個數跳一下
- axis=2 的 axis 上,每隔 1 個數跳一下
至此,transpose 背後的邏輯就理順啦!撒花!★,°:.☆\( ̄▽ ̄)/$:.°★ 。
選自《 Python 與 機器學習》
原文連結:Python、numpy與axis