tensorflow 利用placeholder選擇每個batch裡的sub-tensor 例項
阿新 • • 發佈:2018-11-29
做法是先把 tensor reshape 成一維(-1),再把每個 batch 的 position 加上偏移量(batch 索引 × seq_length),最後用 gather 依照攤平後的索引取值
import tensorflow as tf


def gather_indexes_2d(sequence_tensor, positions):
    """Select one element from every row of a rank-2 tensor.

    The tensor is flattened to one dimension and each row's position is
    shifted by ``row_index * seq_length`` so that a single ``tf.gather`` on
    the flat tensor picks the requested element of every row.

    Args:
        sequence_tensor: Tensor (or variable) of shape
            ``[batch_size, seq_length]``. The rank must be statically known;
            individual dimensions may be unknown (``None``) at graph
            construction time.
        positions: int32 tensor holding one column index per row, i.e. of
            shape ``[batch_size]``.

    Returns:
        A tensor of shape ``[batch_size]`` whose ``i``-th entry is
        ``sequence_tensor[i, positions[i]]``.

    Note:
        Positions are not range-checked here; a value outside
        ``[0, seq_length)`` ends up addressing a different row of the
        flattened tensor (or an invalid index).
    """
    static_shape = sequence_tensor.shape.as_list()
    dynamic_shape = tf.shape(sequence_tensor)
    # Prefer static Python ints; fall back to the runtime shape for any
    # dimension that is unknown when the graph is built.
    batch_size = dynamic_shape[0] if static_shape[0] is None else static_shape[0]
    seq_length = dynamic_shape[1] if static_shape[1] is None else static_shape[1]

    # Index of the first element of each row inside the flattened tensor.
    flat_offsets = tf.reshape(
        tf.range(0, batch_size, dtype=tf.int32) * seq_length, [-1])
    flat_positions = tf.reshape(positions + flat_offsets, [-1])
    flat_sequence_tensor = tf.reshape(sequence_tensor, [-1])
    return tf.gather(flat_sequence_tensor, flat_positions)


value = [[0, 1], [2, 3], [4, 5]]
init = tf.constant_initializer(value)
v = tf.get_variable('value', shape=[3, 2], initializer=init, dtype=tf.int32)
p = tf.placeholder(shape=[3], dtype=tf.int32)
v_ = gather_indexes_2d(v, p)

# tf.initialize_all_variables() is deprecated in favour of
# tf.global_variables_initializer().
init = tf.global_variables_initializer()
# The context manager closes the session and releases its resources.
with tf.Session() as sess:
    sess.run(init)
    print(sess.run(v_, feed_dict={p: [1, 1, 0]}))
列印結果[1 3 4]
rank3的情況:
import tensorflow as tf


def gather_indexes_3d(sequence_tensor, positions):
    """Select one vector from every batch entry of a rank-3 tensor.

    The first two dimensions are merged into one and each batch entry's
    position is shifted by ``batch_index * seq_length`` so that a single
    ``tf.gather`` on the merged tensor picks the requested vector of every
    batch entry.

    Args:
        sequence_tensor: Tensor of shape ``[batch_size, seq_length, width]``.
            The rank must be statically known; individual dimensions may be
            unknown (``None``) at graph construction time.
        positions: int32 tensor holding one sequence index per batch entry,
            i.e. of shape ``[batch_size]``.

    Returns:
        A tensor of shape ``[batch_size, width]`` whose ``i``-th row is
        ``sequence_tensor[i, positions[i], :]``.

    Note:
        Positions are not range-checked here; a value outside
        ``[0, seq_length)`` ends up addressing a different batch entry of
        the merged tensor (or an invalid index).
    """
    static_shape = sequence_tensor.shape.as_list()
    dynamic_shape = tf.shape(sequence_tensor)
    # Prefer static Python ints; fall back to the runtime shape for any
    # dimension that is unknown when the graph is built.
    batch_size = dynamic_shape[0] if static_shape[0] is None else static_shape[0]
    seq_length = dynamic_shape[1] if static_shape[1] is None else static_shape[1]
    width = dynamic_shape[2] if static_shape[2] is None else static_shape[2]

    # Row index of the first vector of each batch entry after merging the
    # batch and sequence dimensions.
    flat_offsets = tf.reshape(
        tf.range(0, batch_size, dtype=tf.int32) * seq_length, [-1])
    flat_positions = tf.reshape(positions + flat_offsets, [-1])
    flat_sequence_tensor = tf.reshape(sequence_tensor, [-1, width])
    return tf.gather(flat_sequence_tensor, flat_positions)


v = tf.constant([[[1, 1], [2, 2], [3, 3]], [[4, 4], [5, 5], [6, 6]]])  # [2,3,2]
p = tf.placeholder(shape=[2], dtype=tf.int32)
v_ = gather_indexes_3d(v, p)

# tf.initialize_all_variables() is deprecated in favour of
# tf.global_variables_initializer().
init = tf.global_variables_initializer()
# The context manager closes the session and releases its resources.
with tf.Session() as sess:
    sess.run(init)
    print(sess.run(v_, feed_dict={p: [1, 0]}))
列印結果
[[2 2]
[4 4]]