torch.gather的使用及理解
阿新 • • 發佈:2021-01-20
技術標籤:pytorch深度學習神經網路pytorchpython
結論:使用方法
# gather,沿dim指定的軸收集值。
y_hat.gather(1, y.view(-1, 1))# y.view(-1, 1)會變成一列,y_hat的取y作為的索引的值
分步理解:先建立一個2*3的tensor
>>y_hat = torch.tensor([[0.1, 0.3, 0.6], [0.3, 0.2, 0.5]])
tensor([[0.1000, 0.3000, 0.6000],
[0.3000, 0.2000, 0.5000]])
為了使用gather函式,我們得建立一個tensor作為gather得引數
>>y = torch.LongTensor([0, 2]) tensor([0, 2])
我們需要把y變個形狀
>>y.view(-1, 1)
tensor([[0],
[2]])
先來看看使用得結果
>>y_hat.gather(1, y.view(-1, 1))
tensor([[0.1000],
[0.5000]])
圖解: