pytorch高階操作
阿新 • • 發佈:2020-09-02
pytorch高階操作
where函式
torch.where(condition,x,y)
可能新生成的tensor一部分來自x,一部分來自y,但是是沒有規律的
例子:假設一個tensor表示識別概率,大於0.5表示1,小於0.5表示0
a = torch.rand(2,2) print(a) tensor([[0.9872, 0.9270], [0.6795, 0.0959]]) aa = torch.zeros(2,2) bb = torch.ones(2,2) answer = torch.where(a>0.5,aa,bb) print(answer) tensor([[0., 0.], [0., 1.]])
gather函式
實際就是一個查表的函式
比如像手寫數字的識別,【4,10】4張圖片,最後識別出每張圖片中10個概率最大的index(一般index為幾這個數字就是幾),但是如果我們的標籤不是1~10,而是另外有一張表來對應,不同的index對應不同的標籤,這時就可以使用gather函式
例子:
prob = torch.rand(4,10) idx = prob.topk(3,dim=1) idx1 = idx[1] print(idx1) tensor([[1, 3, 4], [2, 0, 3], [5, 4, 2], [9, 4, 5]]) label = torch.arange(10)+100#為了方面隨便初始化的label print(torch.gather(label.expand(4,10),dim=1,index=idx1.long())) tensor([[101, 103, 104], [102, 100, 103], [105, 104, 102], [109, 104, 105]])