
The meaning and usage of axis in tf.concat

I have always found the axis argument of TensorFlow's tf.concat a bit fuzzy, so I am writing this note to sort out my own thinking.

import tensorflow as tf
tf.enable_eager_execution()  # TF 1.x only; eager execution is on by default in TF 2.x, where this call no longer exists
import numpy as np

First, generate two matrices, m1 and m2, each with 2 rows and 3 columns:

m1 = np.random.rand(2,3) # m1.shape (2,3)
m1
>>array([[0.44529968,0.42451167,0.07463199],[0.35787143,0.22926186,0.34583839]])
m2 = np.random.rand(2,3) # m2.shape (2,3)
m2
>>array([[0.92811531,0.6180391,0.71969461],[0.00564108,0.55381637,0.17155987]])

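Next, join them with tf.concat. Put simply, axis=0 concatenates by rows (m2's rows are placed below m1's), and axis=1 concatenates by columns (m2's columns are placed to the right of m1's):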

# axis = 0
m3 = tf.concat([m1,m2],axis=0)
m3
>> array([[0.44529968,0.42451167,0.07463199],[0.35787143,0.22926186,0.34583839],[0.92811531,0.6180391,0.71969461],[0.00564108,0.55381637,0.17155987]])
m3.shape
>> (4,3)

# axis = 1
m4 = tf.concat([m1,m2],axis=1)
m4
>>array([[0.44529968,0.42451167,0.07463199,0.92811531,0.6180391,0.71969461],[0.35787143,0.22926186,0.34583839,0.00564108,0.55381637,0.17155987]])
m4.shape
>>(2,6)
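Random values make the row/column picture hard to check by eye. The minimal sketch below repeats the two concatenations on small integer matrices. It uses NumPy's np.concatenate, which follows the same axis convention as tf.concat; the arrays a and b are illustrative assumptions, not part of the original example.

```python
import numpy as np

# Two small 2x3 matrices with easy-to-read values (illustrative only)
a = np.array([[1, 2, 3],
              [4, 5, 6]])
b = np.array([[7, 8, 9],
              [10, 11, 12]])

# axis=0: b's rows are appended below a's rows -> shape (4, 3)
rows = np.concatenate([a, b], axis=0)
# axis=1: b's columns are appended to the right of a's columns -> shape (2, 6)
cols = np.concatenate([a, b], axis=1)

print(rows)
# [[ 1  2  3]
#  [ 4  5  6]
#  [ 7  8  9]
#  [10 11 12]]
print(cols)
# [[ 1  2  3  7  8  9]
#  [ 4  5  6 10 11 12]]
```

Passing the same arrays to tf.concat instead of np.concatenate should give the same values, returned as a tf.Tensor.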

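Strictly speaking, though, this row/column reading only works when the inputs are 2-D matrices. What axis really means is that the tensors are joined along the dimension it names. m1 has shape (2,3), so axis=0 refers to the first dimension, whose size is 2. Concatenating m1 and m2 along that dimension adds the sizes together (2 + 2 = 4) and leaves every other dimension unchanged, so the result has shape (4,3).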

Likewise, axis=1 merges along the second dimension and leaves the others unchanged, so the shape becomes (2,6).
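The rule above (sum the sizes along axis, require every other dimension to match) can be written down as a small shape calculator. This is a sketch under stated assumptions: concat_shape is a hypothetical helper, not a TensorFlow or NumPy function, and np.concatenate stands in for tf.concat only to confirm the predicted shapes.

```python
import numpy as np

def concat_shape(shapes, axis):
    """Predict the shape produced by concatenating inputs of the given shapes.

    Hypothetical helper: sizes along `axis` are summed, while every other
    dimension must be identical across inputs and is carried over unchanged.
    Only non-negative axes are handled here.
    """
    ndim = len(shapes[0])
    if any(len(s) != ndim for s in shapes):
        raise ValueError("all inputs must have the same number of dimensions")
    if not 0 <= axis < ndim:
        raise ValueError(f"axis {axis} is out of range for {ndim}-D inputs")
    out = list(shapes[0])
    for s in shapes[1:]:
        for d in range(ndim):
            if d == axis:
                out[d] += s[d]          # the concatenated dimension grows
            elif s[d] != out[d]:        # any other dimension must match
                raise ValueError(f"dimension {d} differs: {out[d]} vs {s[d]}")
    return tuple(out)

m1 = np.random.rand(2, 3)
m2 = np.random.rand(2, 3)
for ax in (0, 1):
    predicted = concat_shape([m1.shape, m2.shape], ax)
    actual = np.concatenate([m1, m2], axis=ax).shape
    print(ax, predicted, actual)
# 0 (4, 3) (4, 3)
# 1 (2, 6) (2, 6)
```

With mismatched inputs, such as shapes (2,3) and (2,4) along axis=0, the helper raises ValueError; np.concatenate and tf.concat likewise reject such inputs with an error.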

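Now let's move on to 3-D data, which comes up constantly in neural networks; the extra dimension usually represents batch_size. For m5 below, batch_size=5: think of each sample as a 2×3 matrix, with 5 samples packed together.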

m5 = np.random.rand(5,2,3)
m6 = np.random.rand(5,2,3)
m5
>>array([[[0.04347217,0.03368232,0.36017024],[0.74223151,0.06609717,0.38155531]],[[0.50602728,0.355745,0.93379797],[0.97572621,0.53745311,0.66461841]],[[0.92832972,0.02441683,0.48436203],[0.69651043,0.24194495,0.64623769]],[[0.66667596,0.60053027,0.2970753 ],[0.13281764,0.29326326,0.32393028]],[[0.40892782,0.48516547,0.02298178],[0.51239083,0.40151008,0.29913204]]])
m6
>>array([[[0.58001909,0.56925704,0.09798246],[0.20841893,0.62683633,0.17923217]],[[0.91216164,0.0200782,0.3986682 ],[0.86687006,0.83730576,0.48443545]],[[0.65641654,0.59786311,0.2055584 ],[0.65391822,0.74093133,0.02416627]],[[0.80778861,0.22644312,0.91610686],[0.0789411,0.86955002,0.41437046]],[[0.97821668,0.97118328,0.97714882],[0.21543173,0.06964724,0.35360077]]])

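In this case, the first dimension selected by axis=0 no longer means "rows" as before: the first dimension of m5 has size 5 and represents batch_size. Applying the same rule, axis=0 adds up the sizes of the first dimension and leaves the remaining dimensions unchanged, so the new shape is (10,2,3).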

m7 = tf.concat([m5,m6],axis=0)
m7
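>> array([[[0.04347217,0.03368232,0.36017024],[0.74223151,0.06609717,0.38155531]],[[0.50602728,0.355745,0.93379797],[0.97572621,0.53745311,0.66461841]],[[0.92832972,0.02441683,0.48436203],[0.69651043,0.24194495,0.64623769]],[[0.66667596,0.60053027,0.2970753 ],[0.13281764,0.29326326,0.32393028]],[[0.40892782,0.48516547,0.02298178],[0.51239083,0.40151008,0.29913204]],[[0.58001909,0.56925704,0.09798246],[0.20841893,0.62683633,0.17923217]],[[0.91216164,0.0200782,0.3986682 ],[0.86687006,0.83730576,0.48443545]],[[0.65641654,0.59786311,0.2055584 ],[0.65391822,0.74093133,0.02416627]],[[0.80778861,0.22644312,0.91610686],[0.0789411,0.86955002,0.41437046]],[[0.97821668,0.97118328,0.97714882],[0.21543173,0.06964724,0.35360077]]])
m7.shape
>>(10,2,3)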
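With axis=0 on 3-D data, the join happens along the batch dimension: the result holds 10 samples, each still a 2×3 matrix. A minimal sketch that checks this by slicing the result back apart, using np.concatenate as a stand-in for tf.concat (same axis convention) and fresh random data:

```python
import numpy as np

# Two batches of 5 samples, each sample a 2x3 matrix (random, as in the text)
m5 = np.random.rand(5, 2, 3)
m6 = np.random.rand(5, 2, 3)

m7 = np.concatenate([m5, m6], axis=0)
print(m7.shape)  # (10, 2, 3)

# The first 5 samples come from m5 and the last 5 from m6, each left intact
print(np.array_equal(m7[:5], m5), np.array_equal(m7[5:], m6))  # True True
```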

Concatenation with axis=1 or axis=2 works the same way: axis=1 gives shape (5,4,3), and axis=2 gives shape (5,2,6).
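For a batch, axis=1 and axis=2 act inside each sample, which is exactly the 2-D axis=0 and axis=1 case applied sample by sample. The sketch below checks that, again with np.concatenate standing in for tf.concat; the names by_rows and by_cols are illustrative.

```python
import numpy as np

m5 = np.random.rand(5, 2, 3)
m6 = np.random.rand(5, 2, 3)

# axis=1: inside each sample, m6's 2 rows go below m5's 2 rows
by_rows = np.concatenate([m5, m6], axis=1)
# axis=2: inside each sample, m6's 3 columns go to the right of m5's 3 columns
by_cols = np.concatenate([m5, m6], axis=2)
print(by_rows.shape, by_cols.shape)  # (5, 4, 3) (5, 2, 6)

# Sample by sample, this matches the 2-D case with axis=0 and axis=1
i = 0
print(np.array_equal(by_rows[i], np.concatenate([m5[i], m6[i]], axis=0)),
      np.array_equal(by_cols[i], np.concatenate([m5[i], m6[i]], axis=1)))  # True True
```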

axis can also be negative: axis=-1 means the last dimension, which for m5 is the dimension of size 3. So for these 3-D inputs, axis=2 and axis=-1 are equivalent.
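A negative axis counts from the end, so for inputs with ndim dimensions it maps to axis % ndim. The sketch below confirms the mapping for all three negative axes of a 3-D input, with np.concatenate standing in for tf.concat; the results list exists only to collect the comparisons.

```python
import numpy as np

m5 = np.random.rand(5, 2, 3)
m6 = np.random.rand(5, 2, 3)

# A negative axis counts from the end; for 3-D inputs it maps to axis % 3
results = []
for neg in (-1, -2, -3):
    pos = neg % m5.ndim   # -1 -> 2, -2 -> 1, -3 -> 0
    same = np.array_equal(np.concatenate([m5, m6], axis=neg),
                          np.concatenate([m5, m6], axis=pos))
    results.append((neg, pos, same))
print(results)  # [(-1, 2, True), (-2, 1, True), (-3, 0, True)]
```

Because axis=-1 always refers to the last dimension regardless of rank, it also works for the 2-D m1 and m2, where it is equivalent to axis=1.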
