1. 程式人生 > >tensorflow-tf.concat

tensorflow-tf.concat

tf.concat

 

tf.concat(
    values,
    axis,
    name='concat'
)

按一維連線張量。

沿著維度軸連線張量值的列表。如果values[i].shape=[D0, D1, ... Daxis(i), ...Dn],則連線的結果具有形狀如下:

[D0, D1, ... Raxis, ...Dn]

在此

Raxis = sum(Daxis(i))

也就是說,來自輸入張量的資料沿著軸維度連線。

輸入張量的維數必須匹配,除軸外的所有維度必須相等。

例如:

#!/usr/bin/env python2
# -*- coding: utf-8 -*-
"""
Created on Thu Sep  6 10:16:37 2018

@author: myhaspl
"""
t1 = [[1, 2, 3], [4, 5, 6]]
t2 = [[7, 8, 9], [10, 11, 12]]
a=tf.concat([t1, t2], 0)  
b=tf.concat([t1, t2], 1)  
sess=tf.Session()
with sess:
    print sess.run(a)
    print sess.run(b)

[[ 1 2 3]
[ 4 5 6]
[ 7 8 9]
[10 11 12]]
[[ 1 2 3 7 8 9]
[ 4 5 6 10 11 12]]

在python中,axis可以為負值。負軸(axis)x被解釋為從秩(rank)的末尾開始計數,即axis+rank(values)-Th維。

例如:

#!/usr/bin/env python2
# -*- coding: utf-8 -*-
"""
Created on Thu Sep  6 10:16:37 2018

@author: myhaspl
"""
import tensorflow as tf
t1 = [[[1, 2], [2, 3]], [[4, 4], [5, 3]]]
t2 = [[[7, 4], [8, 4]], [[2, 10], [15, 11]]]
a=tf.concat([t1, t2], -1)
sess=tf.Session()
with sess:
    print sess.run(a)

[[[ 1 2 7 4]
[ 2 3 8 4]]

[[ 4 4 2 10]
[ 5 3 15 11]]]

注意:如果你在一個新的軸上連線,考慮使用堆疊。例如。

tf.concat([tf.expand_dims(t, axis) for t in tensors], axis)

可寫成

tf.stack(tensors, axis=axis)

引數:

values: 一個Tensor物件的列表或單個tensor
axis:  0-D int32 Tensor.沿著該軸連線,必須在 [-rank(values), rank(values))
在Pyhon中, axis的索引是基於0. 正軸在[0, rank(values))中稱為軸維。負軸是指axis + rank(values)-th維。
name: 操作名字(可選)
返回:

從input tensors連線而成的Tensor結果。