Tensorflow - tf.tile() 學習
阿新 • • 發佈:2018-12-15
API:https://tensorflow.google.cn/api_docs/python/tf/tile?hl=zh-cn
tf.tile()
用於張量擴充套件
tf.tile(
input,
multiples,
name=None
)
輸入是一個Tensor
multiples的維度與輸入的維度相一致,並標明在哪一個維度上進行擴充套件,擴充套件的方法就是複製為相同的元素,下面的例子可以說明問題:
import tensorflow as tf raw = tf.Variable(tf.random_normal(shape=(2 ,2, 2))) multi1 = tf.tile(raw, multiples=[2, 1, 1]) multi2 = tf.tile(raw, multiples=[1, 2, 1]) multi3 = tf.tile(raw, multiples=[1, 1, 2]) with tf.Session() as sess: sess.run(tf.global_variables_initializer()) print(raw.eval()) print('-----------------------------') a = sess.run(multi1) b = sess.run(multi2) c = sess.run(multi3) print(a) print(a.shape) print('-----------------------------') print(b) print(b.shape) print('-----------------------------') print(c) print(c.shape) print('-----------------------------')
#原始 [[[ 0.6948325 -0.16302951] [-0.60185844 0.3866387 ]] [[-0.5528875 -0.06845065] [ 0.24240932 0.72961247]]] ----------------------------- # multiples=[2, 1, 1] [[[ 0.6948325 -0.16302951] [-0.60185844 0.3866387 ]] [[-0.5528875 -0.06845065] [ 0.24240932 0.72961247]] [[ 0.6948325 -0.16302951] [-0.60185844 0.3866387 ]] [[-0.5528875 -0.06845065] [ 0.24240932 0.72961247]]] (4, 2, 2) ----------------------------- # multiples=[1, 2, 1] [[[ 0.6948325 -0.16302951] [-0.60185844 0.3866387 ] [ 0.6948325 -0.16302951] [-0.60185844 0.3866387 ]] [[-0.5528875 -0.06845065] [ 0.24240932 0.72961247] [-0.5528875 -0.06845065] [ 0.24240932 0.72961247]]] (2, 4, 2) ----------------------------- # multiples=[1, 1, 2] [[[ 0.6948325 -0.16302951 0.6948325 -0.16302951] [-0.60185844 0.3866387 -0.60185844 0.3866387 ]] [[-0.5528875 -0.06845065 -0.5528875 -0.06845065] [ 0.24240932 0.72961247 0.24240932 0.72961247]]] (2, 2, 4) -----------------------------