1. 程式人生 > 程式設計 >tensorflow中tf.slice和tf.gather切片函式的使用

tensorflow中tf.slice和tf.gather切片函式的使用

tf.slice(input_,begin,size,name=None):按照指定的下標範圍抽取連續區域的子集

tf.gather(params,indices,validate_indices=None,name=None):按照指定的下標集合從axis=0中抽取子集,適合抽取不連續區域的子集

輸出:

input = [[[1,1,1],[2,2,2]],[[3,3,3],[4,4,4]],[[5,5,5],[6,6,6]]]
tf.slice(input,[1,0],3]) ==> [[[3,3]]]
tf.slice(input,4]]]
tf.slice(input,3]],5]]]
           
tf.gather(input,[0,2]) ==> [[[1,6]]]

假設我們要從input中抽取[[[3,3]]],這個輸出在inputaxis=0的下標是1,axis=1的下標是0,axis=2的下標是0-2,所以begin=[1,0],size=[1,3]。

假設我們要從input中抽取[[[3,4]]],這個輸出在inputaxis=0的下標是1,axis=1的下標是0-1,axis=2的下標是0-2,所以begin=[1,[5,5]]],這個輸出在inputaxis=0的下標是1-2,axis=1的下標是0,axis=2的下標是0-2,所以begin=[1,0],size=[2,3]。

假設我們要從input中抽取[[[1,6]]],這個輸出在input的axis=0的下標是[0,2],不連續,可以用tf.gather抽取。input[0]和input[2]

以上這篇tensorflow中tf.slice和tf.gather切片函式的使用就是小編分享給大家的全部內容了,希望能給大家一個參考,也希望大家多多支援我們。