1. 程式人生 > >numpy的廣播(broadcasting)機制

numpy的廣播(broadcasting)機制

什麼是廣播:

Numpy的廣播既是在2個不同的矩陣運算過程中,Numpy將較小的陣列拉伸成較大陣列的形狀(shape),較小維度的陣列會被廣播到另一個數組的相應維度上去,這就使得兩個不同形狀(但也有基本要求,不是任何維度都可以廣播)的陣列也可以進行運算。

如:

import numpy as np

a = np.array([1.0, 2.0, 3.0])
b = 2.0
print(a, a.shape)
print(b)
print(a * b)

執行結果如下:

[1. 2. 3.] (3,)
2.0
[2. 4. 6.]

廣播原則:

如果滿足以下規則,可以進行廣播:

維度較小的陣列會在前面追加一個長度為 1 的維度;

輸出陣列的每個維度的大小是輸入陣列該維度大小的最大值;

如果輸入在每個維度中的大小與輸出大小匹配,或其值正好為 1,則在計算中可它;

如果輸入的某個維度大小為 1,則該維度中的第一個資料元素將用於該維度的所有計算。

按照上述規則修改陣列後,如果還滿足下面的要求,那麼稱為輸入陣列是可廣播的:

輸入陣列和輸出陣列擁有相同形狀;

輸入陣列和輸出陣列的維數相同,且各個維上的數值相同(數值就是我們通常說的m行n列的m和n),或者數值不同但數值為 1;

對上面的規則的通俗說法:

首先檢查兩個變數維數是否相同,若不同,對維數少的補一。維數不是指n行d列中的n和d的值,對於一般的矩陣維數就是2。

如一個兩維的矩陣(n,d)和一個一維的陣列(m,)相乘,補一操作就是將那個一維的陣列變為(1,m),補一總是在shape陣列的開始位置補一;

輸出陣列是輸入陣列各維度(軸)的最大值。如(2,3)和(3,)相乘,首先做第一步的維度調整,修正為(2,3)和(1,3)。那麼第一維最大是在2和1中選2,第二維最大值是在3和3中選3。那麼輸出陣列維度是(2,3);

檢查輸入陣列各維的數和輸出陣列各維的關係,要麼相等,要麼為一(注意是每個維度都要滿足要麼相等要麼為1)。如上面的例子中輸入陣列(2,3)和輸出陣列(2,3)在各維上都是相等的,而(1,3)和(2,3)雖然第一維不相等,但是卻等於1,這也是可以計算的。再舉一個反例(2,4)(3,),先做維度調整,變為(2,4)和(1,3),在計算輸出陣列的維度為(2,4),最後做第三步輸入陣列(2,4)和輸出陣列(2,4)相等,但輸入陣列(1,3)和輸出陣列(2,4)的第二維不相等也不等於1,所以計算失敗。

經過上面的步驟,可以認為輸入的兩個陣列各維的數要麼相等要麼等於1。對於等於1的維度開始複製增補。如(1,3)和(3,1)的輸出是(3,3)。對於(1,3)要對每一行復制,最終變為(3,3),例如[[2,3,4]]變為[[2,3,4],[2,3,4],[2,3,4]]。對於(3,1)要對每一列複製,最終變為(3,3),例如[[2],[3],[4]]]變為[[2,2,2],[3,3,3],[4,4,4]]。總之哪一維為1就對哪一維複製增補,直到輸入陣列的形狀(shape)和輸出陣列的形狀相同。完成了這一步,兩個陣列的shape就完全相同,這時候就可以執行普通運算了。

如:

import numpy as np

x = np.arange(4)
xx = x.reshape(4, 1)
y = np.ones(5)
z = np.ones((3, 4))

print(x, x.shape)
print(y, y.shape)
# print(x + y)
# 如果執行print(x + y)會出現下面的錯誤提示,因為x是1維陣列,y也是一維陣列,但x有4個元素,而y有5個元素,不滿足廣播機制的要求
# <type 'exceptions.ValueError'>: shape mismatch: objects cannot be broadcast to a single shape
# xx變成了4行一列的矩陣,每個元素是原來一維陣列的單個元素
print(xx, xx.shape)
print(y, y.shape)
print(xx + y, (xx + y).shape)
# xx現在是4行1列,y是一維陣列有5個元素,即(4,1)和(5,)相乘,根據廣播原則,y要在所有維度前追加一個長度為1的維度,變成(1,5)
# 那麼輸出陣列就應該是兩個輸入的陣列中每個維度上的最大值,即(4,1)和(1,5)同樣維度比較,取最大值,應該是(4,5)
# 檢查輸入陣列各維的數值和輸出陣列各維的數值的關係,要麼相等要麼是1
# 輸入陣列(4,1)和(1,5),輸出陣列是(4,5),那麼對(4,1),其第一個維度數值與(4,5)相等,第二個維度數值等於1,滿足,對(1,5)也這麼檢查,也滿足
# 對輸入陣列中等於1的維度開始複製增補
# 對(4,1)按列增補到(4,5),對(1,5)按行增補到(4,5),然後相加

print(x, x.shape)
print(z, z.shape)
print(x + z, (x + z).shape)
# x是(4,),z是(3,4),那麼在x前追加一個長度為1的維度,變成(1,4)
# 輸出陣列就是(1,4)和(3,4)上每個維度中數值的最大值,應該是(3,4)
# 檢查輸入陣列各維的數值和輸出陣列各維的數值的關係,要麼相等要麼是1
# (1,4)和(3,4)與輸出陣列(3,4)滿足上面的要求
# 對(1,4)按行增補到(3,4),然後相加

上面程式碼中包含兩個例子:

xx+y:

xx是(4,1)矩陣,值為

[[0]  [1]  [2]  [3]] (4, 1)

y是一維陣列(5,),值為

[1. 1. 1. 1. 1.] (5,)

根據廣播原則,y要在所有維度前追加一個長度為1的維度,變成(1,5),輸出陣列就應該是兩個輸入的陣列中每個維度上的最大值:

即(4,1)和(1,5)同樣維度比較,取最大值,應該是(4,5)

檢查輸入陣列各維的數值和輸出陣列各維的數值的關係,要麼相等要麼是1:

輸入陣列(4,1)和(1,5),輸出陣列是(4,5),那麼對(4,1),其第一個維度數值與(4,5)相等,第二個維度數值等於1,滿足,對(1,5)也這麼檢查,也滿足。

對輸入陣列中等於1的維度開始複製增補: # 對(4,1)按列增補到(4,5),對(1,5)按行增補到(4,5),然後相加,就可以得到結果

[[1. 1. 1. 1. 1.]  [2. 2. 2. 2. 2.]  [3. 3. 3. 3. 3.]  [4. 4. 4. 4. 4.]] (4, 5)

x+z:

x是(4,),z是(3,4),那麼在x前追加一個長度為1的維度,變成(1,4);

輸出陣列就是(1,4)和(3,4)上每個維度中數值的最大值,應該是(3,4);

檢查輸入陣列各維的數值和輸出陣列各維的數值的關係,要麼相等要麼是1。(1,4)和(3,4)與輸出陣列(3,4)滿足上面的要求;

對(1,4)按行增補到(3,4),然後相加,結果為

[[1. 2. 3. 4.]  [1. 2. 3. 4.]  [1. 2. 3. 4.]] (3, 4)

執行結果如下:

[0 1 2 3] (4,)
[1. 1. 1. 1. 1.] (5,)
[[0]
 [1]
 [2]
 [3]] (4, 1)
[1. 1. 1. 1. 1.] (5,)
[[1. 1. 1. 1. 1.]
 [2. 2. 2. 2. 2.]
 [3. 3. 3. 3. 3.]
 [4. 4. 4. 4. 4.]] (4, 5)
[0 1 2 3] (4,)
[[1. 1. 1. 1.]
 [1. 1. 1. 1.]
 [1. 1. 1. 1.]] (3, 4)
[[1. 2. 3. 4.]
 [1. 2. 3. 4.]
 [1. 2. 3. 4.]] (3, 4)

Process finished with exit code 0