1. 程式人生 > >Tensorflow - tf.expand_dims 學習

Tensorflow - tf.expand_dims 學習

API: https://tensorflow.google.cn/api_docs/python/tf/expand_dims?hl=zh-cn 

tf.expand_dims(
    input,
    axis=None,
    name=None,
    dim=None
)

在input的axis位置插入一維的張量

這個操作在input的維度中索引為axis的位置插入一維張量。維度索引axis從零開始; 如果指定負數,axis則從末尾向後計數

例子:

# 't' is a tensor of shape [2]
tf.shape(tf.expand_dims(t, 0))  # [1, 2]
tf.shape(tf.expand_dims(t, 1))  # [2, 1]
tf.shape(tf.expand_dims(t, -1))  # [2, 1]

# 't2' is a tensor of shape [2, 3, 5]
tf.shape(tf.expand_dims(t2, 0))  # [1, 2, 3, 5]
tf.shape(tf.expand_dims(t2, 2))  # [2, 3, 1, 5]
tf.shape(tf.expand_dims(t2, 3))  # [2, 3, 5, 1]

 這個操作在一個batch裡面插入一個一維元素是很好用的