1. 程式人生 > 其它 >torch.gather的使用及理解

torch.gather的使用及理解

技術標籤:pytorch深度學習神經網路pytorchpython

結論:使用方法

# gather,沿dim指定的軸收集值。
y_hat.gather(1, y.view(-1, 1))# y.view(-1, 1)會變成一列,y_hat的取y作為的索引的值

分步理解:先建立一個2*3的tensor

>>y_hat = torch.tensor([[0.1, 0.3, 0.6], [0.3, 0.2, 0.5]])

tensor([[0.1000, 0.3000, 0.6000],
        [0.3000, 0.2000, 0.5000]])

為了使用gather函式,我們得建立一個tensor作為gather得引數

>>y = torch.LongTensor([0, 2])

tensor([0, 2])

我們需要把y變個形狀

>>y.view(-1, 1)

tensor([[0],
        [2]])

先來看看使用得結果

>>y_hat.gather(1, y.view(-1, 1))
tensor([[0.1000],
        [0.5000]])

圖解: