1. 程式人生 > 實用技巧 >pytorch高階操作

pytorch高階操作

pytorch高階操作

where函式

torch.where(condition,x,y)

可能新生成的tensor一部分來自x,一部分來自y,但是是沒有規律的

例子:假設一個tensor表示識別概率,大於0.5表示1,小於0.5表示0

a = torch.rand(2,2)
print(a)

tensor([[0.9872, 0.9270],
        [0.6795, 0.0959]])


aa = torch.zeros(2,2)
bb = torch.ones(2,2)

answer = torch.where(a>0.5,aa,bb)
print(answer)

tensor([[0., 0.],
        [0., 1.]])

gather函式

實際就是一個查表的函式

比如像手寫數字的識別,【4,10】4張圖片,最後識別出每張圖片中10個概率最大的index(一般index為幾這個數字就是幾),但是如果我們的標籤不是1~10,而是另外有一張表來對應,不同的index對應不同的標籤,這時就可以使用gather函式

例子:

prob = torch.rand(4,10)

idx = prob.topk(3,dim=1)
idx1 = idx[1]

print(idx1)

tensor([[1, 3, 4],
        [2, 0, 3],
        [5, 4, 2],
        [9, 4, 5]])

label = torch.arange(10)+100#為了方面隨便初始化的label

print(torch.gather(label.expand(4,10),dim=1,index=idx1.long()))


tensor([[101, 103, 104],
        [102, 100, 103],
        [105, 104, 102],
        [109, 104, 105]])