tf.slice()函式用法
阿新 • 發佈:2018-11-22
import tensorflow as tf

# Demonstrates tf.slice (TensorFlow 1.x graph mode, run through tf.Session).

# A 5x4 matrix whose rows are right-padded with zeros.
t = tf.constant([[1, 1, 1, 0],
                 [2, 2, 0, 0],
                 [3, 0, 0, 0],
                 [4, 0, 0, 0],
                 [5, 0, 0, 0]])
print(t.get_shape())  # static shape, known without running the graph: (5, 4)

# Number of leading elements to keep for each row of `t`.
length = tf.constant([3, 2, 1, 1, 1])
# In graph mode this prints the symbolic Tensor object, not its values.
print(length)

# 1-D slice: tf.slice(input, begin, size) takes `size` elements starting at
# `begin`. The size entry may itself be a tensor (here the scalar `length_1`).
# (2-D analogue: tf.slice(t, [0, 0], [4, 3]) would take the top-left 4x3 block.)
t_1 = tf.constant([1, 1, 1, 0])
length_1 = tf.constant(3)
slice_1 = tf.slice(t_1, [0], [length_1])

# Variable-length slices, one per row. The results have different lengths, so
# they are collected in a Python list rather than stacked with tf.stack.
all_slice = []
stacks_t = tf.unstack(t)  # list of the 5 row tensors of `t`
for i, each_row in enumerate(stacks_t):
    # NOTE(review): this slices `t_1`, not `each_row`, so every result below
    # contains only 1s. Slice `each_row` instead to trim the rows of `t`.
    slice_k = tf.slice(t_1, [0], [length[i]])
    # Add a leading axis so each result has shape (1, length[i]).
    all_slice.append(tf.expand_dims(slice_k, 0))

with tf.Session() as sess:
    print(sess.run(slice_1))
    print(sess.run(all_slice))
output:
(5, 4)
Tensor("Const_1:0", shape=(5,), dtype=int32)
[1 1 1]
[array([[1, 1, 1]], dtype=int32), array([[1, 1]], dtype=int32), array([[1]], dtype=int32), array([[1]], dtype=int32), array([[1]], dtype=int32)]