tensorflow中tf.slice和tf.gather切片函式的使用
阿新 • • 發佈:2020-01-20
tf.slice(input_,begin,size,name=None):按照指定的下標範圍抽取連續區域的子集
tf.gather(params,indices,validate_indices=None,name=None):按照指定的下標集合從axis=0中抽取子集,適合抽取不連續區域的子集
輸出:
input = [[[1,1,1],[2,2,2]],[[3,3,3],[4,4,4]],[[5,5,5],[6,6,6]]] tf.slice(input,[1,0],3]) ==> [[[3,3]]] tf.slice(input,4]]] tf.slice(input,3]],5]]] tf.gather(input,[0,2]) ==> [[[1,6]]]
假設我們要從input中抽取[[[3,3]]],這個輸出在inputaxis=0的下標是1,axis=1的下標是0,axis=2的下標是0-2,所以begin=[1,0],size=[1,3]。
假設我們要從input中抽取[[[3,4]]],這個輸出在inputaxis=0的下標是1,axis=1的下標是0-1,axis=2的下標是0-2,所以begin=[1,[5,5]]],這個輸出在inputaxis=0的下標是1-2,axis=1的下標是0,axis=2的下標是0-2,所以begin=[1,0],size=[2,3]。
假設我們要從input中抽取[[[1,6]]],這個輸出在input的axis=0的下標是[0,2],不連續,可以用tf.gather抽取。input[0]和input[2]
以上這篇tensorflow中tf.slice和tf.gather切片函式的使用就是小編分享給大家的全部內容了,希望能給大家一個參考,也希望大家多多支援我們。