TensorFlow中張量連線操作tf.concat用法詳解
阿新 • • 發佈:2018-12-13
一、環境
TensorFlow API r1.12
CUDA 9.2 V9.2.148
Python 3.6.3
二、官方說明
按指定軸(axis)進行張量連線操作(Concatenates Tensors)
tf.concat(
values,
axis,
name='concat'
)
輸入:
(1)values:多個張量組成的列表或者單個張量
(2)axis:0維整形張量(整數),定義按照按個數據軸進行張量連線操作,其範圍是[-輸入張量的階,+輸入張量的階]。[0,輸入張量的階]範圍內的正數表示按照指定的axis軸進行連線操作,在[-輸入張量的階,0]之間的負數表示按照指定的(axis + 輸入張量的階)的軸進行連線操作
(3)name:可選引數,定義該張量連線操作的名稱
輸出:
輸入張量按照指定軸連線後的一個結果張量
三、例項
(1)單個張量作為輸入
>>> t1 = [[1,2,3],[4,5,6]] >>> con1 = tf.concat(t1,0) >>> shape1 = tf.shape(con1) >>> with tf.Session() as sess: ... print(sess.run(con1)) ... print(sess.run(shape1)) ... [1 2 3 4 5 6] [6]
(2)多個張量組成的列表作為輸入
按照0軸(行)進行連線:
>>> t1 = [[1,2,3],[4,5,6]] >>> t2 = [[7,8,9],[10,11,12]] >>> con2 = tf.concat([t1,t2],0) >>> shape2 = tf.shape(con2) >>> with tf.Session() as sess: ... print(sess.run(con2)) ... print(sess.run(shape2)) ... [[ 1 2 3] [ 4 5 6] [ 7 8 9] [10 11 12]] [4 3]
按照1軸(列)進行連線:
>>> t1 = [[1,2,3],[4,5,6]]
>>> t2 = [[7,8,9],[10,11,12]]
>>> con3 = tf.concat([t1,t2],1)
>>> shape3 = tf.shape(con3)
>>> with tf.Session() as sess:
... print(sess.run(con3))
... print(sess.run(shape3))
...
[[ 1 2 3 7 8 9]
[ 4 5 6 10 11 12]]
[2 6]
>>>
按照-1軸(列)進行連線:
>>> t1 = [[1,2,3],[4,5,6]]
>>> t2 = [[7,8,9],[10,11,12]]
>>> con4 = tf.concat([t1,t2],-1)
>>> shape4 = tf.shape(con4)
>>> with tf.Session() as sess:
... print(sess.run(con4))
... print(sess.run(shape4))
...
[[ 1 2 3 7 8 9]
[ 4 5 6 10 11 12]]
[2 6]
注意:如果想沿著一個新軸連線張量,則考慮使用stcak
不建議使用:tf.concat([tf.expand_dims(t, axis) for t in tensors],axis)
推薦使用:tf.stack(tensors, axis=axis)