TensorFlow基礎筆記(4) Tensor Transformation
阿新 • • 發佈:2017-09-30
ces ade col mat none lan flow indices 範圍
https://segmentfault.com/a/1190000008793389
抽取
-
tf.slice(input_, begin, size, name=None)
:按照指定的下標範圍抽取連續區域的子集 -
tf.gather(params, indices, validate_indices=None, name=None)
:按照指定的下標集合從axis=0
中抽取子集,適合抽取不連續區域的子集
begin為下標起始位置,size為獲取個數
input = [[[1, 1, 1], [2, 2, 2]],
[[3, 3, 3], [4, 4, 4]],
[[5, 5, 5], [6, 6, 6]]]
tf.slice(input, [1, 0, 0], [1, 1, 3]) ==> [[[3, 3, 3]]]
tf.slice(input, [1, 0, 0], [1, 2, 3]) ==> [[[3, 3, 3],
[4, 4, 4]]]
tf.slice(input, [1, 0, 0], [2, 1, 3]) ==> [[[3, 3, 3]],
[[5, 5, 5]]]
tf.gather(input, [0, 2]) ==> [[[1, 1, 1], [2, 2, 2]],
[[5, 5, 5], [6, 6, 6]]]
import tensorflow as tf
import numpy as np
input = np.array([[[1, 1, 1], [2, 2, 2]],
[[3, 3, 3], [4, 4, 4]],
[[5, 5, 5], [6, 6, 6]]])
print(input.shape)
print(input)
sess = tf.Session()
out = tf.slice(input, [1, 0, 0], [1, 2, 3])
print (‘out:\n‘,sess.run(out))
輸出:
(3, 2, 3)
[[[1 1 1]
[2 2 2]]
[[3 3 3]
[4 4 4]]
[[5 5 5]
[6 6 6]]]
out:
[[[3 3 3]
[4 4 4]]]
維度為 3 * 2 * 3
第0維度 豎直方向,維度為3
第1維度,豎直方向,維度為1
第2維度,水平方向,維度為2
input[1,1,2] = 3
TensorFlow基礎筆記(4) Tensor Transformation