1. 程式人生 > >Python中numpy和pandas中axis淺談

Python中numpy和pandas中axis淺談

在Python中軸是比較難懂概念,先從座標軸說起。

n 維空間裡有 n 個座標軸,並且座標軸互相垂直,每一個點相對於一條座標軸都有唯一的一個座標值。對同一條座標軸來說,座標值相同的點在同一個 n-1 維的“平面”上。任意取一個“平面”,我們就能定義“同一個座標軸上的點”,這些點在“平面”上的投影相同,同一個座標軸上的點組成的線是與座標軸平行的。而所謂的延軸計算實際上是降維的過程,同一個座標軸上的點合併成一個點,這樣n維空間就變成了 n-1 維空間。

具體到 numpy 中的多維陣列來說,軸即是元素座標的索引。比如,第0軸即是第1個索引,延0軸計算就是去掉座標中的第一個索引。過程就是

  1. 遍歷其他索引的所有可能組合
  2. 取出一個組合,保持值不變,遍歷第一個索引所有可能值
  3. 根據索引可以獲得了同一個軸上的所有元素
  4. 對他們進行計算得到最後的元素
  5. 所有組合的最後結果組到一起就是最後的 n-1 維陣列

沿軸計算過程,可以當做沿哪一個方向進行投影再進行計算。所以如果一個多維陣列的 shape 是 (a1, a2, a3, a4), 那麼延軸0計算最後的陣列shape 是 (a2, a3, a4), 延軸1計算最後的陣列shape是 (a1, a3, a4)

>>> a = array([[[1,1],[2,1],[3,1]],[[4,1],[5,1],[6,1]],[[7,1],[8,1],[9,1]]])>>>
a.shape(3, 3, 2)>>> aarray([[[1, 1], [2, 1], [3, 1]], [[4, 1], [5, 1], [6, 1]], [[7, 1], [8, 1], [9, 1]]])>>> sum(a, axis=0)array([[12, 3], [15, 3], [18, 3]])>>> sum(a, axis=1)array([[ 6, 3], [15, 3
], [24, 3]])>>> sum(a, axis=2)array([[ 2, 3, 4], [ 5, 6, 7], [ 8, 9, 10]])