探祕TensorFlow 和 NumPy 的 Broadcasting 機制
在使用Tensorflow的過程中,我們經常遇到陣列形狀不同的情況,但有時候發現二者還能進行加減乘除的運算,在這背後,其實是Tensorflow的broadcast即廣播機制幫了大忙。而Tensorflow中的廣播機制其實是效仿的numpy中的廣播機制。本篇,我們就來一同研究下numpy和Tensorflow中的廣播機制。
1、numpy廣播原理
1.1 陣列和標量計算時的廣播
標量和數組合並時就會發生簡單的廣播,標量會和陣列中的每一個元素進行計算。
舉個例子:
arr = np.arange(5) arr * 4
得到的輸出為:
array([ 0,4,8,12,16])
這個是很好理解的,我們重點來研究陣列之間的廣播
1.2 陣列之間計算時的廣播
用書中的話來介紹廣播的規則:兩個陣列之間廣播的規則:如果兩個陣列的後緣維度(即從末尾開始算起的維度)的軸長度相等或其中一方的長度為1,則認為他們是廣播相容的,廣播會在缺失和(或)長度為1的維度上進行。
上面的規則挺拗口的,我們舉幾個例子吧:
二維的情況
假設有一個二維陣列,我們想要減去它在0軸和1軸的均值,這時的廣播是什麼樣的呢。
我們先來看減去0軸均值的情況:
arr = np.arange(12).reshape(4,3) arr-arr.mean(0)
輸出的結果為:
array([[-4.5,-4.5,-4.5],
[-1.5,-1.5,-1.5],[ 1.5,1.5,1.5],
[ 4.5,4.5,4.5]])
0軸的平均值為[4.5,5.5,6.5],形狀為(3,),而原陣列形狀為(4,3),在進行廣播時,從後往前比較兩個陣列的形狀,首先是3=3,滿足條件而繼續比較,這時候發現其中一個數組的形狀陣列遍歷完成,因此會在缺失軸即0軸上進行廣播。
可以理解成將均值陣列在0軸上覆制4份,變成形狀(4,3)的陣列,再與原陣列進行計算。
書中的圖形象的表示了這個過程(資料不一樣請忽略):
我們再來看一下減去1軸平均值的情況,即每行都減去該行的平均值:
arr - arr.mean(1)
此時報錯了:
我們再來念叨一遍我們的廣播規則,均值陣列的形狀為(4,3),按照比較規則,4 != 3,因此不符合廣播的條件,因此報錯。
正確的做法是什麼呢,因為原陣列在0軸上的形狀為4,我們的均值陣列必須要先有一個值能夠跟3比較同時滿足我們的廣播規則,這個值不用多想,就是1。因此我們需要先將均值陣列變成(4,1)的形狀,再去進行運算:
arr-arr.mean(1).reshape((4,1))
得到正確的結果:
array([[-1.,0.,1.],[-1.,1.]])
三維的情況
理解了二維的情況,我們也就能很快的理解三維陣列的情況。
首先看下圖:
根據廣播原則分析:arr1的shape為(3,2),arr2的shape為(4,2),它們的後緣軸長度都為(4,2),所以可以在0軸進行廣播。因此,arr2在0軸上覆制三份,shape變為(3,2),再進行計算。
不只是0軸,1軸和2軸也都可以進行廣播。但形狀必須滿足一定的條件。舉個例子來說,我們arr1的shape為(8,5,3),想要在0軸上廣播的話,arr2的shape是(1,3)或者(5,3),想要在1軸上進行廣播的話,arr2的shape是(8,1,3),想要在2軸上廣播的話,arr2的shape必須是(8,1)。
我們來寫幾個例子吧:
arr2 = np.arange(24).reshape((2,3,4)) arr3_0 = np.arange(12).reshape((3,4)) print("0軸廣播") print(arr2 - arr3_0) arr3_1 = np.arange(8).reshape((2,4)) print("1軸廣播") print(arr2 - arr3_1) arr3_2 = np.arange(6).reshape((2,1)) print("2軸廣播") print(arr2 - arr3_2)
輸出為:
0軸廣播
[[[ 0 0 0 0]
[ 0 0 0 0]
[ 0 0 0 0]][[12 12 12 12]
[12 12 12 12]
[12 12 12 12]]]
1軸廣播
[[[ 0 0 0 0]
[ 4 4 4 4]
[ 8 8 8 8]][[ 8 8 8 8]
[12 12 12 12]
[16 16 16 16]]]
2軸廣播
[[[ 0 1 2 3]
[ 3 4 5 6]
[ 6 7 8 9]][[ 9 10 11 12]
[12 13 14 15]
[15 16
17 18]]]
如果我們想在兩個軸上進行廣播,那arr2的shape要滿足什麼條件呢?
arr1.shape | 廣播軸 | arr2.shape |
---|---|---|
(8,3) | 0,1 | (3,),(1,3),3) |
(8,2 | (5,1),1) | |
(8,3) | 1,2 | (8,1) |
具體的例子就不給出啦,嘻嘻。
2、Tensorflow 廣播舉例
Tensorflow中的廣播機制和numpy是一樣的,因此我們給出一些簡單的舉例:
二維的情況
sess = tf.Session() a = tf.Variable(tf.random_normal((2,0.1)) b = tf.Variable(tf.random_normal((2,0.1)) c = a - b sess.run(tf.global_variables_initializer()) sess.run(c)
輸出為:
array([[-0.1419442,0.14135399,0.22752595],
[ 0.1382471,0.28228047,0.13102233]],dtype=float32)
三維的情況
sess = tf.Session() a = tf.Variable(tf.random_normal((2,4),0.1)) c = a - b sess.run(tf.global_variables_initializer()) sess.run(c)
輸出為:
array([[[-0.0154749,-0.02047186,-0.01022427,-0.08932371],
[-0.12693939,-0.08069084,-0.15459496,0.09405404],
[ 0.09730847,0.06936138,0.04050628,0.15374713]],[[-0.02691782,-0.26384184,0.05825682,-0.07617196],
[-0.02653179,-0.01997554,-0.06522765,0.03028341],
[-0.07577246,0.03199019,0.0321,-0.12571403]]],dtype=float32)
錯誤示例
sess = tf.Session() a = tf.Variable(tf.random_normal((2,0.1)) c = a - b sess.run(tf.global_variables_initializer()) sess.run(c)
輸出為:
ValueError: Dimensions must be equal,but are 3 and 2 for 'sub_2' (op: 'Sub') with input shapes: [2,4],[2,4].
到此這篇關於探祕TensorFlow 和 NumPy 的 Broadcasting 機制的文章就介紹到這了,更多相關TensorFlow 和NumPy 的Broadcasting 內容請搜尋我們以前的文章或繼續瀏覽下面的相關文章希望大家以後多多支援我們!