Tensorflow - tf.expand_dims 學習
阿新 • • 發佈:2018-12-15
API: https://tensorflow.google.cn/api_docs/python/tf/expand_dims?hl=zh-cn
tf.expand_dims(
input,
axis=None,
name=None,
dim=None
)
在input的axis位置插入一維的張量
這個操作在input的維度中索引為axis
的位置插入一維張量。維度索引axis
從零開始; 如果指定負數,axis
則從末尾向後計數
例子:
# 't' is a tensor of shape [2] tf.shape(tf.expand_dims(t, 0)) # [1, 2] tf.shape(tf.expand_dims(t, 1)) # [2, 1] tf.shape(tf.expand_dims(t, -1)) # [2, 1] # 't2' is a tensor of shape [2, 3, 5] tf.shape(tf.expand_dims(t2, 0)) # [1, 2, 3, 5] tf.shape(tf.expand_dims(t2, 2)) # [2, 3, 1, 5] tf.shape(tf.expand_dims(t2, 3)) # [2, 3, 5, 1]
這個操作在一個batch裡面插入一個一維元素是很好用的