1. 程式人生 > >理解pytorch幾個高階選擇函式(如gather)

理解pytorch幾個高階選擇函式(如gather)

[TOC] ### 1. 引言   最近在刷開源的Pytorch版動手學深度學習,裡面談到幾個高階選擇函式,如index_select,masked_select,gather等。這些函式大多很容易理解,但是對於gather函式,確實有些難理解,官方文件開始也看得一臉懵,感覺不太直觀。下面談談我對這幾個函式的一些理解。 ### 2. 維度的理解   對於numpy和pytorch,其陣列在做維度運算上剛開始可能會給人一種直觀上的誤解,以numpy求矩陣某個維度的最大值為例(pytorch的理解也是一樣的) ``` import numpy as np a = np.arange(1, 13).reshape(3, 4) """ result: a = [[1, 2, 3, 4], [5, 6, 7, 8,], [9, 10, 11, 12]] """ # 對a維度0求最大值 a.max(axis = 0) """ result: [9, 10, 11, 12] """ # 對a維度1求最大值 a.max(axis = 1) """ result: [4, 8, 12] """ ```   如果對a矩陣在維度0上找最大值,根據我們直觀上的經驗應該是[4, 8, 12]。即從[1, 2, 3, 4]找到4,從[5, 6, 7, 8]找到8,從[9, 10, 11, 12]找到12。但是從上面結果來看,numpy運算卻給了我們直觀上認為是列最大值的結果[9, 10, 11, 12]。   實際numpy(pytorch)運算應該理解為**往給定的維度進行移動運算**。還是以維度0為例,維度0上有3個向量,分別為[1, 2, 3, 4],[5, 6, 7, 8]和[9, 10, 11, 12]。往維度0移動,即[1, 2, 3, 4]和[5, 6, 7, 8]逐元素計算最大值,得到[5, 6, 7, 8],再和[9, 10, 11, 12]運算得到結果[9, 10, 11, 12]。   另外,對於維度為3的陣列,在numpy和pytorch中,應該把維度0理解為通道數,維度1和維度2才是對應高和寬。如果是3維陣列對應著用於多輸入通道和單輸出通道的卷積核(維度為U x V x D),那麼4維陣列就對應著用於多輸入通道和多輸出通道的卷積核(維度為U x V x D x P),此時,維度0則為多通道卷積核數量的方向,維度1為通道數,維度2和3才是分別對應高和寬。 ### 3. gather函式 pytorch和numpy中許多函式都涉及維度運算,`gather`也不例外,但是它相對於其他函式更難理解。依然先來看一個例子 ``` import torch a = torch.arange(1, 16).reshape(5, 3) """ result: a = [[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12], [13, 14, 15]] """ # 定義兩個index b = torch.tensor([[0, 1, 2], [2, 3, 4], [0, 2, 4]]) c = torch.tensor([[1, 2, 0, 2, 1], [1, 2, 1, 0, 0]]) # axis=0 output1 = a.gather(0, b) """ result: [[1, 5, 9], [7, 11, 15], [1, 8, 15]] """ # axis=1 output2 = a.gather(1, c) """ result: [[2, 3, 1, 3, 2], [5, 6, 5, 4, 4]] """ ``` 上面的例子看起來可能有點複雜,我們來一步步的分析它,先從gather維度為0開始講起。 1. `a.gather(0, b)`分為3個部分,`a`是需要被提取元素的矩陣,`0`代表的是提取的維度為0,`b`是提取元素的索引 - 其中規定b和a是同維張量,即a是2維張量,b也必須是2維張量 2. `0`除了代表往維度0的方向提取元素外,還有一個特權---提取結果output可以在這個維度上的長度與a不同。打個比方,a現在的shape為(5, 3),那麼提取結果output1的shape可以是(1,3),(2, 3),甚至(n, 3)。具體維度0的長度到底為多少由b來決定。 3. 根據`0`的特權,導致了給定的b張量除了維度0外,其他的維度大小必須和a一樣。其中張量`b`實際上包含以下兩個資訊 - b可以利用除用於gather的維度(此處為維度0)外的維度來定位出唯一一個向量,也就是a[:, ?](三維度也是同理的,有a[:, ?1, ?2]),?的取值範圍為a同維度的index。 - 對於上述定位出的向量,通過b中的元素來定位提取向量中的哪一個元素。 - 上面說得可能有點抽象,實際上b中的每個元素都能在a中提取出一個元素。舉個具體點的例子,按照上面所說的,b[0, 0]可以提取a中的一個元素。對於b[0,0],除了維度0外,可以通過維度1來定位出唯一一個向量a[:, 0]。因為b[0, 0]的元素為0,即提取的是a[:, 0]的第0個元素---1,並將其作為output1[0, 0]的提取結果。 下圖給出了維度0和維度1,gather運算的圖示
對於3維或者更高維度的張量gather的原理也是一樣的 ### 4. index_select函式 其他的高階選擇函式都比較容易理解,這裡簡單的提一下。torch.index_select主要是根據傳入的tensor來往給定的axis方向來選取張量 ``` import torch a = torch.arange(9).reshape(3, 3) torch.index_select(a, 0, torch.tensor([0, 2])) """ result: [[0, 1, 2], [6, 7, 8]] """ ``` ### 5. masked_select函式 實際上就是通過掩碼條件來選擇元素,像torch.masked_select(x, x>0.5),實際上是和x[x>0.5]等價的,最後返回的是一維張量 ``` import torch a = torch.rand(5, 3) # 結果和a[a > 0.5]等價 torch.masked_select(a, a>0.5) ``` ### 6. nonzero函式 找到非零元素的index ``` import torch a = torch.eye(3) torch.nonzero(a) """ result: 對應著非零元素的index [[0, 0], [1, 1], [2, 2]]