程式人生 > tf.gather_nd 實例

tf.gather_nd 實例

import tensorflow as tf

# Demo of tf.gather_nd(params, indices).
#
# The last dimension of `indices` has some length k, and each index vector
# addresses the first k dimensions of `params`. The result therefore has
# shape indices.shape[:-1] + params.shape[k:].
#
# NOTE: this uses TensorFlow 1.x graph-mode APIs (tf.get_variable,
# tf.Session, tf.global_variables_initializer). Under TensorFlow 2.x they
# are only available as tf.compat.v1.* with eager execution disabled.

# Tensor to gather from: shape [2, 2, 3], float32 (get_variable's default).
# The Python name is `params` (gather_nd's argument name) instead of `input`,
# which would shadow the builtin. The TF variable name stays 'input'.
params = tf.get_variable(
    'input',
    shape=[2, 2, 3],
    initializer=tf.constant_initializer(
        [[[1, 2, 3], [11, 22, 33]], [[4, 5, 6], [44, 55, 66]]]),
)

# indices shape [2, 1, 2] -> k = 2, result shape [2, 1, 3]:
# picks params[0, 1] and params[1, 0]. The literal's nesting matches the
# declared shape, so no implicit reshape by the initializer is involved.
index = tf.get_variable(
    'index',
    shape=[2, 1, 2],
    initializer=tf.constant_initializer([[[0, 1]], [[1, 0]]]),
    dtype=tf.int32,
)

# indices shape [2, 2] -> k = 2, result shape [2, 3]:
# rows params[0, 1] and params[1, 0].
index2 = tf.get_variable(
    'index2',
    shape=[2, 2],
    initializer=tf.constant_initializer([[0, 1], [1, 0]]),
    dtype=tf.int32,
)

# indices shape [2] -> k = 2, result shape [3]: the single row params[0, 1].
index3 = tf.get_variable(
    'index3',
    shape=[2],
    initializer=tf.constant_initializer([0, 1]),
    dtype=tf.int32,
)

# indices shape [3] -> k = 3, result is a scalar: params[0, 1, 1].
index4 = tf.get_variable(
    'index4',
    shape=[3],
    initializer=tf.constant_initializer([0, 1, 1]),
    dtype=tf.int32,
)

result = tf.gather_nd(params, index)
result2 = tf.gather_nd(params, index2)
result3 = tf.gather_nd(params, index3)
result4 = tf.gather_nd(params, index4)

# The `with` block closes the session once everything has been printed;
# the original bare tf.Session() was never closed.
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    for position, tensor in enumerate((result, result2, result3, result4)):
        if position:
            # Blank line between consecutive results, as before.
            print()
        print(sess.run(tensor))

列印結果（`tf.gather_nd` 中，索引最後一維的長度 k 對應被取值張量的前 k 個維度，輸出形狀為 indices.shape[:-1] + params.shape[k:]）：

[[[11. 22. 33.]]

 [[ 4.  5.  6.]]]

[[11. 22. 33.]
 [ 4.  5.  6.]]

[11. 22. 33.]

22.0