tf.gather_nd 例項
阿新 • • 發佈:2019-01-06
# tf.gather_nd demo (TensorFlow 1.x graph mode).
#
# Gathers slices from one [2, 2, 3] tensor using index tensors of different
# shapes, then prints each result. The last axis of an index tensor holds
# coordinates into the leading axes of the source tensor; the remaining
# source axes are returned whole.
import tensorflow as tf


def _constant_variable(name, values, shape, dtype=None):
    """Create a graph variable ``name`` of ``shape`` initialised from ``values``.

    Args:
        name: Name of the variable inside the TF graph.
        values: (Nested) list of constants used to fill the variable.
        shape: Shape the variable is created with.
        dtype: Element type; ``None`` is what ``tf.get_variable`` receives
            when the argument is omitted, so its default applies.

    Returns:
        The variable returned by ``tf.get_variable``.
    """
    return tf.get_variable(
        name,
        shape=shape,
        dtype=dtype,
        initializer=tf.constant_initializer(values),
    )


# Source tensor. The Python name is ``params`` rather than ``input`` so the
# builtin ``input`` is not shadowed; the graph variable is still 'input'.
params = _constant_variable(
    'input',
    [[[1, 2, 3], [11, 22, 33]],
     [[4, 5, 6], [44, 55, 66]]],
    shape=[2, 2, 3],
)

# Shape [2, 1, 2]: two groups holding one (i, j) pair each. The literal is
# nested to match the declared shape (same elements, same order as before).
index = _constant_variable(
    'index', [[[0, 1]], [[1, 0]]], shape=[2, 1, 2], dtype=tf.int32)

# Shape [2, 2]: two (i, j) pairs.
index2 = _constant_variable(
    'index2', [[0, 1], [1, 0]], shape=[2, 2], dtype=tf.int32)

# Shape [2]: a single (i, j) pair.
index3 = _constant_variable('index3', [0, 1], shape=[2], dtype=tf.int32)

# Shape [3]: a single full (i, j, k) coordinate.
index4 = _constant_variable('index4', [0, 1, 1], shape=[3], dtype=tf.int32)

result = tf.gather_nd(params, index)    # rows params[0, 1] and params[1, 0], each in its own group
result2 = tf.gather_nd(params, index2)  # rows params[0, 1] and params[1, 0]
result3 = tf.gather_nd(params, index3)  # the single row params[0, 1]
result4 = tf.gather_nd(params, index4)  # the single element params[0, 1, 1]

# The context manager closes the session (and frees its resources) on exit.
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    print(sess.run(result))
    print()
    print(sess.run(result2))
    print()
    print(sess.run(result3))
    print()
    print(sess.run(result4))
列印結果:
[[[11. 22. 33.]]

 [[ 4.  5.  6.]]]

[[11. 22. 33.]
 [ 4.  5.  6.]]

[11. 22. 33.]

22.0