tf.concat中axis的含義與使用詳解
阿新 • • 發佈:2020-02-09
tensorflow中tf.concat的axis的使用我一直理解的比較模糊,這次做個筆記理下自己的思路。
import tensorflow as tf tf.enable_eager_execution() import numpy as np
先生成兩個矩陣m1,和m2,大小為兩行三列
m1 = np.random.rand(2,3) # m1.shape (2,3) m1 >>array([[0.44529968,0.42451167,0.07463199],[0.35787143,0.22926186,0.34583839]]) m2 = np.random.rand(2,3) # m2.shape (2,3) m2 >>array([[0.92811531,0.6180391,0.71969461],[0.00564108,0.55381637,0.17155987]])
接下來採用tf.concat進行連線,簡單來說,axis=0實際就是按行拼接,axis=1就是按列拼接
# axis = 0 m3 = tf.concat([m1,m2],axis=0) m3 >> array([[0.44529968,0.34583839],[0.92811531,0.17155987]]) m3.shape >> (4,3) # axis = 1 m4 = tf.concat([m1,axis=1) m4 >>array([[0.44529968,0.07463199,0.92811531,0.34583839,0.00564108,0.17155987]]) m4.shape >>(2,6)
但這實際上這隻有在我們的輸入是二維矩陣時才可以這樣理解。axis的實際含義是根據axis指定的維度進行連線,如矩陣m1的維度為(2,3),那麼axis=0就代表了第一個維度‘2',因此,將m1和m2按照第一個維度進行連線,得到的新的矩陣就是將第一維度進行相加,其餘維度不變,即維度變成了(4,3).
同理,axis=1時就是將矩陣的第二維度進行合併,其餘維度不變,即維度變成了(2,6)。
接下來處理三個維度的資料,這也是我們在神經網路資料中經常要用到的,增加的一個維度通常代表了batch_size. 如下面的m5,batch_size=5,可以理解為每個樣本是個2*3的矩陣,一次將5個樣本放在一起。
m5 = np.random.rand(5,2,3) m6 = np.random.rand(5,3) m5 >>array([[[0.04347217,0.03368232,0.36017024],[0.74223151,0.06609717,0.38155531]],[[0.50602728,0.355745,0.93379797],[0.97572621,0.53745311,0.66461841]],[[0.92832972,0.02441683,0.48436203],[0.69651043,0.24194495,0.64623769]],[[0.66667596,0.60053027,0.2970753 ],[0.13281764,0.29326326,0.32393028]],[[0.40892782,0.48516547,0.02298178],[0.51239083,0.40151008,0.29913204]]]) m6 >>array([[[0.58001909,0.56925704,0.09798246],[0.20841893,0.62683633,0.17923217]],[[0.91216164,0.0200782,0.3986682 ],[0.86687006,0.83730576,0.48443545]],[[0.65641654,0.59786311,0.2055584 ],[0.65391822,0.74093133,0.02416627]],[[0.80778861,0.22644312,0.91610686],[0.0789411,0.86955002,0.41437046]],[[0.97821668,0.97118328,0.97714882],[0.21543173,0.06964724,0.35360077]]])
在這種情況下,axis=0代表的第一個維度的含義就不再是之前認為的行的概念了,現在m5的第一維度的值是5,代表的是batch_size。仍然按照之前的理解,如果設定axis=0,axis=0就是將第一維度進行相加,其餘維度不變,因此我們可以得到新的維度為(10,3)。
m7 = tf.concat([m5,m6],axis=0) m7 >> array([[[0.04347217,0.29913204]],[[0.58001909,0.35360077]]]) m7.shape >>(10,3)
同理,也可以進行axis=1,axis=2的concat操作。
此外,axis的值也可以設定為負數,如axis=-1實際上就是指倒數第一個維度,如m5的倒數第一個維度的值就是‘3'。因此,axis=2的操作和axis=-1的操作是等價的。
以上這篇tf.concat中axis的含義與使用詳解就是小編分享給大家的全部內容了,希望能給大家一個參考,也希望大家多多支援我們。