程式人生 > tensorflow 利用placeholder選擇每個batch裡的sub-tensor 例項

tensorflow 利用placeholder選擇每個batch裡的sub-tensor 例項

做法是先把張量 reshape 成扁平形狀(第一維用 -1),然後在 gather 時為第 i 個 batch 的位置加上偏移量 i × seq_length,再從扁平張量中取值。

import tensorflow as tf

def gather_indexes_2d(sequence_tensor, positions):
  """Select one element per batch row of a rank-2 tensor.

  For every row ``i`` this returns ``sequence_tensor[i, positions[i]]``.
  The tensor is flattened to 1-D and each position is shifted by
  ``i * seq_length`` so that a single ``tf.gather`` selects all rows at once.

  Args:
    sequence_tensor: tensor of shape [batch_size, seq_length]. Either
      dimension may be unknown at graph-construction time.
    positions: int32 tensor of shape [batch_size], one column index per row.

  Returns:
    Tensor of shape [batch_size] holding the selected elements.

  Note:
    Positions are not range-checked: a position >= seq_length is not
    rejected here and may read from the following row of the flattened
    tensor instead of raising an error.
  """
  # shape.as_list() reports None for dimensions unknown at graph-build time;
  # tf.range and the reshape arithmetic below cannot consume None, so fall
  # back to the runtime shape for those dimensions.
  static_shape = sequence_tensor.shape.as_list()
  dynamic_shape = tf.shape(sequence_tensor)
  batch_size = static_shape[0]
  if batch_size is None:
    batch_size = dynamic_shape[0]
  seq_length = static_shape[1]
  if seq_length is None:
    seq_length = dynamic_shape[1]

  # Offset of the first element of row i inside the flattened tensor.
  flat_offsets = tf.reshape(
      tf.range(0, batch_size, dtype=tf.int32) * seq_length, [-1])
  flat_positions = tf.reshape(positions + flat_offsets, [-1])
  flat_sequence_tensor = tf.reshape(sequence_tensor,
                                    [batch_size * seq_length])
  output_tensor = tf.gather(flat_sequence_tensor, flat_positions)
  return output_tensor

value = [[0, 1], [2, 3], [4, 5]]
value_initializer = tf.constant_initializer(value)
v = tf.get_variable('value', shape=[3, 2], initializer=value_initializer,
                    dtype=tf.int32)

p = tf.placeholder(shape=[3], dtype=tf.int32)

v_ = gather_indexes_2d(v, p)

# tf.initialize_all_variables() is deprecated; global_variables_initializer()
# is its replacement. The `with` block guarantees the session is closed.
init = tf.global_variables_initializer()
with tf.Session() as sess:
  sess.run(init)
  print(sess.run(v_, feed_dict={p: [1, 1, 0]}))

列印結果:[1 3 4](第 0 個 batch 取索引 1 得到 1,第 1 個取索引 1 得到 3,第 2 個取索引 0 得到 4)

rank 3 的情況(形狀為 [batch_size, seq_length, width]):

import tensorflow as tf

def gather_indexes_3d(sequence_tensor, positions):
  """Select one width-vector per batch row of a rank-3 tensor.

  For every row ``i`` this returns ``sequence_tensor[i, positions[i], :]``.
  The first two dimensions are merged and each position is shifted by
  ``i * seq_length`` so that a single ``tf.gather`` selects all rows at once.

  Args:
    sequence_tensor: tensor of shape [batch_size, seq_length, width]. Any of
      these dimensions may be unknown at graph-construction time.
    positions: int32 tensor of shape [batch_size], one sequence index per row.

  Returns:
    Tensor of shape [batch_size, width] holding the selected vectors.

  Note:
    Positions are not range-checked: a position >= seq_length is not
    rejected here and may read from the following row of the flattened
    tensor instead of raising an error.
  """
  # shape.as_list() reports None for dimensions unknown at graph-build time;
  # tf.range and the reshape arithmetic below cannot consume None, so fall
  # back to the runtime shape for those dimensions.
  static_shape = sequence_tensor.shape.as_list()
  dynamic_shape = tf.shape(sequence_tensor)
  batch_size = static_shape[0]
  if batch_size is None:
    batch_size = dynamic_shape[0]
  seq_length = static_shape[1]
  if seq_length is None:
    seq_length = dynamic_shape[1]
  width = static_shape[2]
  if width is None:
    width = dynamic_shape[2]

  # Offset of the first vector of row i inside the flattened tensor.
  flat_offsets = tf.reshape(
      tf.range(0, batch_size, dtype=tf.int32) * seq_length, [-1])
  flat_positions = tf.reshape(positions + flat_offsets, [-1])
  flat_sequence_tensor = tf.reshape(sequence_tensor,
                                    [batch_size * seq_length, width])
  output_tensor = tf.gather(flat_sequence_tensor, flat_positions)
  return output_tensor


v = tf.constant([[[1, 1], [2, 2], [3, 3]], [[4, 4], [5, 5], [6, 6]]])  # [2,3,2]
p = tf.placeholder(shape=[2], dtype=tf.int32)

v_ = gather_indexes_3d(v, p)

# This snippet only uses a constant and a placeholder, so no variable
# initializer is needed. The `with` block guarantees the session is closed.
with tf.Session() as sess:
  print(sess.run(v_, feed_dict={p: [1, 0]}))

列印結果:
[[2 2]
 [4 4]]