tf.strided_slice函式(類似K.shape(feats)[1:3])
阿新 • • 發佈:2018-12-21
在keras_yolo中model函式下grid_shape = K.shape(feats)[1:3] grid_shape: <tf.Tensor ‘strided_slice:0’ shape=(0,) dtype=int32> cifar10的例子中也有。
來把輸入變個型,可以看成3維的tensor,從外向裡為1,2,3維 (維數的判斷順序為z軸–y軸–x軸)
[
[ [1,1,1] [2,2,2] ] [ [3,3,3] [4,4,4] ] [ [5,5,5] [6,6,6] ]
]
以tf.strided_slice(input, [0,0,0], [2,2,2], [1,2,1])呼叫為例,start = [0,0,0] , end = [2,2,2], stride = [1,2,1],求一個[start, end)的一個片段,注意end為開區間
第1維 start = 0 , end = 2, stride = 1, 所以取 0 , 1行,此時的輸出
output1=
[
[
[1,1,1]
[2,2,2]
]
[
[3,3,3]
[4,4,4]
]
] 第2維時, start = 0 , end = 2 , stride = 2, 所以只能取0行,此時的輸出
output2=
[
[
[1,1,1]
]
[
[3,3,3]
]
] 第3維的時候,start = 0, end = 2, stride = 1, 可以取0,1行,此時得到的就是最後的輸出
[
[
[1,1]
]
[
[3,3]
]
] 整理之後最終的輸出為:
[[[1,1],[3,3]]]
更多例子:
t = tf.constant([[[1, 1, 1], [2, 2, 2]],
[[3, 3, 3], [4, 4, 4]],
[[5, 5 , 5], [6, 6, 6]]])
tf.strided_slice(t, [1, 0, 0], [2, 1, 3], [1, 1, 1]) # [[[3, 3, 3]]]
#shape=(1, 1, 3)
tf.strided_slice(t, [1, 0, 0], [2, 2, 3], [1, 1, 1]) # [[[3, 3, 3],
shape=(1, 2, 3) # [4, 4, 4]]]
tf.strided_slice(t, [1, -1, 0], [2, -3, 3], [1, -1, 1]) # [[[4, 4, 4],
shape=(1, 2, 3) # [3, 3, 3]]]