tensorflow中axis理解非常重要
axis的作用即如何理解
numpy是python進行科學計算必不可少的模組,隨著深度學習越來越火,numpy也越來越流行。瞭解numpy的人知道,在numpy中,有很多的函式都涉及到axis,很多函式根據axis的取值不同,得到的結果也完全不同。可以說,axis讓numpy的多維陣列變的更加靈活,但也讓numpy變得越發難以理解。這裡通過詳細的例子來學習下,axis到底是什麼,它在numpy中的作用到底如何。
為什麼會有axis這個東西,原因很簡單:numpy是針對矩陣或者多為陣列進行運算的,而在多維陣列中,對資料的操作有太多的可能,我們先來看一個例子。比如我們有一個二維陣列:
>>> import numpy as np
>>> data = np.array([
... [1,2,1],
... [0,3,1],
... [2,1,4],
... [1,3,1]])
這個陣列代表了樣本資料的特徵,其中每一行代表一個樣本的三個特徵,每一列是不同樣本的特徵。如果在分析樣本的過程中需要對每個樣本的三個特徵求和,該如何處理?簡單:
>>> np.sum(data, axis=1)
array([4, 4, 7, 5])
那如果想求每種特徵的最小值,該如何處理?也簡單:
>>> np.min(data, axis=0)
array([0, 1, 1])
又如果想得知所有樣本所有特徵的平均值呢?還是很簡單:
>>> np.average(data )
1.6666666666666667
由此可以看出,通過不同的axis,numpy會沿著不同的方向進行操作:如果不設定,那麼對所有的元素操作;如果axis=0,則沿著縱軸進行操作;axis=1,則沿著橫軸進行操作。但這只是簡單的二位陣列,如果是多維的呢?可以總結為一句話:設axis=i,則numpy沿著第i個下標變化的放下進行操作。例如剛剛的例子,可以將表示為:data =[[a00, a01],[a10,a11]],所以axis=0時,沿著第0個下標變化的方向進行操作,也就是a00->a10, a01->a11,也就是縱座標的方向,axis=1時也類似。下面我們舉一個四維的求sum的例子來驗證一下:
>>> data = np.random.randint(0, 5, [4,3,2,3])
>>> data
array([[[[4, 1, 0],
[4, 3, 0]],
[[1, 2, 4],
[2, 2, 3]],
[[4, 3, 3],
[4, 2, 3]]],
[[[4, 0, 1],
[1, 1, 1]],
[[0, 1, 0],
[0, 4, 1]],
[[1, 3, 0],
[0, 3, 0]]],
[[[3, 3, 4],
[0, 1, 0]],
[[1, 2, 3],
[4, 0, 4]],
[[1, 4, 1],
[1, 3, 2]]],
[[[0, 1, 1],
[2, 4, 3]],
[[4, 1, 4],
[1, 4, 1]],
[[0, 1, 0],
[2, 4, 3]]]])
當axis=0時,numpy驗證第0維的方向來求和,也就是第一個元素值=a0000+a1000+a2000+a3000=11,第二個元素=a0001+a1001+a2001+a3001=5,同理可得最後的結果如下:
>>> data.sum(axis=0)
array([[[11, 5, 6],
[ 7, 9, 4]],
[[ 6, 6, 11],
[ 7, 10, 9]],
[[ 6, 11, 4],
[ 7, 12, 8]]])
當axis=3時,numpy驗證第3維的方向來求和,也就是第一個元素值=a0000+a0001+a0002=5,第二個元素=a0010+a0011+a0012=7,同理可得最後的結果如下:
>>> data.sum(axis=3)
array([[[ 5, 7],
[ 7, 7],
[10, 9]],
[[ 5, 3],
[ 1, 5],
[ 4, 3]],
[[10, 1],
[ 6, 8],
[ 6, 6]],
[[ 2, 9],
[ 9, 6],
[ 1, 9]]])
使用axis的相關函式
在numpy中,使用的axis的地方非常多,處理上文已經提到的average、max、min、sum,比較常見的還有sort和prod,下面分別舉幾個例子看一下:
sort
>>> data = np.random.randint(0, 5, [3,2,3])
>>> data
array([[[4, 2, 0],
[0, 0, 4]],
[[2, 1, 1],
[1, 0, 2]],
[[3, 0, 4],
[0, 1, 3]]])
>>> np.sort(data) ## 預設對最大的axis進行排序,這裡即是axis=2
array([[[0, 2, 4],
[0, 0, 4]],
[[1, 1, 2],
[0, 1, 2]],
[[0, 3, 4],
[0, 1, 3]]])
>>> np.sort(data, axis=0) # 沿著第0維進行排序,原先的a000->a100->a200轉變為a100->a200->a000
array([[[2, 0, 0],
[0, 0, 2]],
[[3, 1, 1],
[0, 0, 3]],
[[4, 2, 4],
[1, 1, 4]]])
>>> np.sort(data, axis=1) # 沿著第1維進行排序
array([[[0, 0, 0],
[4, 2, 4]],
[[1, 0, 1],
[2, 1, 2]],
[[0, 0, 3],
[3, 1, 4]]])
>>> np.sort(data, axis=2) # 沿著第2維進行排序
array([[[0, 2, 4],
[0, 0, 4]],
[[1, 1, 2],
[0, 1, 2]],
[[0, 3, 4],
[0, 1, 3]]])
>>> np.sort(data, axis=None) # 對全部資料進行排序
array([0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 3, 3, 4, 4, 4])
prod(即product,乘積)
>>> np.prod([[1.,2.],[3.,4.]])
24.0
>>> np.prod([[1.,2.],[3.,4.]], axis=1)
array([ 2., 12.])
>>> np.prod([[1.,2.],[3.,4.]], axis=0)
array([ 3., 8.])
相信通過上面的講解與例子,你應該對axis有了比較清楚的瞭解。個人認為,如果沒有理解axis的真正含義,很難熟悉的運用numpy進行資料處理