
Picking the matching element from each row of a tensor with a 1-D index vector

My requirement: I have a 15×4 tensor and a 15×1 index vector, and I want to pull the corresponding 15×1 vector out of the tensor, one element per row.

I spent a long time looking at gather without success, because gather() can only pick whole slices along an axis.

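gather_nd is not a great solution either, although it does work: the index has to be turned into a 15×2 array so that each row specifies one element of the original tensor as a (row, column) pair.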
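To make the 15×2 index idea concrete, here is a minimal pure-Python sketch of the lookup that gather_nd performs on a rank-2 input. The table contents and the column choices are made-up illustration data, not values from the original problem.

```python
# Hypothetical data: a 15x4 table where entry (r, c) holds 10*r + c.
n_rows, n_cols = 15, 4
table = [[10 * r + c for c in range(n_cols)] for r in range(n_rows)]

# Hypothetical column choice for each row (length 15).
index = [r % n_cols for r in range(n_rows)]

# Expand the 15 column indices into 15 (row, column) pairs: the 15x2 array.
full_indices = [[r, c] for r, c in enumerate(index)]

# Look up one element per pair, which mirrors gather_nd on a rank-2 input.
picked = [table[r][c] for r, c in full_indices]
print(picked[:5])
```

The only extra work compared with the NumPy one-liner is attaching the row number to every column index.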

 

Later I found that one-hot encoding can solve this. Ref: https://stackoverflow.com/questions/39684415/tensorflow-getting-elements-of-every-row-for-specific-columns

 

Ask:

If A is a TensorFlow variable like so

A = tf.Variable([[1, 2], [3, 4]])

and index is another variable

index = tf.Variable([0, 1])

I want to use this index to select columns in each row. In this case, item 0 from first row and item 1 from second row.

If A were a NumPy array, then to get the columns of the corresponding rows given in index we could do

x = A[np.arange(A.shape[0]), index]

and the result would be

[1, 4]
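As a runnable version of the NumPy snippet above, the sketch below also contrasts it with plain column slicing, which returns whole columns rather than one element per row. The variable name `sliced` is introduced here purely for illustration.

```python
import numpy as np

A = np.array([[1, 2], [3, 4]])
index = np.array([0, 1])

# Pair row i with column index[i]: one element per row.
x = A[np.arange(A.shape[0]), index]
print(x)  # [1 4]

# Indexing only the column axis selects whole columns for every row,
# so the result is still 2-D instead of one value per row.
sliced = A[:, index]
print(sliced.shape)  # (2, 2)
```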

What is the TensorFlow equivalent operation/operations for this? I know TensorFlow doesn't support many indexing operations. What would be the work around if it cannot be done directly?

 

Answer1:

You can extend your column indices with row indices and then use gather_nd:

import tensorflow as tf

A = tf.constant([[1, 2], [3, 4]])
indices = tf.constant([1, 0])

# prepare row indices
row_indices = tf.range(tf.shape(indices)[0])

# zip row indices with column indices
full_indices = tf.stack([row_indices, indices], axis=1)

# retrieve values by indices
S = tf.gather_nd(A, full_indices)

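# TF 1.x graph mode: evaluate S in a session.
# (Under TF 2.x eager execution, S.numpy() gives the value directly.)
session = tf.InteractiveSession()
print(session.run(S))  # [2 3]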

 

Answer2:

You can use tf.one_hot to create a one-hot array and use it as a boolean mask to select the elements you'd like.

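import tensorflow as tf

A = tf.Variable([[1, 2], [3, 4]])
index = tf.Variable([0, 1])

# One True per row, placed at the column given by index
one_hot_mask = tf.one_hot(index, A.shape[1], on_value=True, off_value=False, dtype=tf.bool)
# Keep only the masked elements, in row order; evaluates to [1, 4]
output = tf.boolean_mask(A, one_hot_mask)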
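To see why the one-hot mask returns the elements in the right order, here is a small NumPy emulation of the same idea. It is only a sketch of the mechanism, not the TensorFlow implementation; `np.eye` stands in for tf.one_hot and boolean indexing stands in for tf.boolean_mask.

```python
import numpy as np

A = np.array([[1, 2], [3, 4]])
index = np.array([0, 1])

# Row i of the identity matrix is the one-hot vector for column i,
# so indexing it with `index` builds the mask one row at a time.
one_hot_mask = np.eye(A.shape[1], dtype=bool)[index]

# Boolean indexing scans in row-major order; with exactly one True
# per row, it yields one element per row, in row order.
output = A[one_hot_mask]
print(output)  # [1 4]
```

Because the mask holds exactly one True per row, the flattened result has the same length as the index vector.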