pytorch 常見函式理解
阿新 • • 發佈:2018-11-28
gather
>>> a = torch.Tensor([[1,2],[3,4]]) >>> a tensor([[ 1., 2.], [ 3., 4.]]) >>> torch.gather(a,1,torch.LongTensor([ ... [0,0], ... [1,0]])) tensor([[ 1., 1.], [ 4., 3.]]) #1代表按照第1維度進行計算 #第一維也就是按照行,第一行[0,0]代表,新的tensor的第一行的兩個元素,分別是a第一行的的第0個和第0個元素 #第一維也就是按照行,第二行[1,0]代表,新的tensor的第二行的兩個元素,分別是a第二行的第1個和第0個元素>>> torch.gather(a,0,torch.LongTensor([ ... [0,0], ... [1,0]])) tensor([[ 1., 2.], [ 3., 2.]]) #0代表按照第0維度進行計算 #第0維也就是按照列,第二列[0,0]代表,新的tensor的第二列的兩個元素,分別是a第二列的第0個和第0個元素
squeeze
將維度為1的壓縮掉。如size為(3,1,1,2),壓縮之後為(3,2)
import torch a=torch.randn(2,1,1,3) print(a) print(a.squeeze())
輸出:
tensor([[[[-0.2320, 0.9513, 1.1613]]], [[[ 0.0901, 0.9613, -0.9344]]]])
tensor([[-0.2320, 0.9513, 1.1613], [ 0.0901, 0.9613, -0.9344]])
expand
擴充套件某個size為1的維度。如(2,2,1)擴充套件為(2,2,3)
import torch x=torch.randn(2,2,1) print(x) y=x.expand(2,2,3) print(y)
輸出:
tensor([[[ 0.0608], [ 2.2106]], [[-1.9287], [ 0.8748]]]) tensor([[[ 0.0608, 0.0608, 0.0608], [ 2.2106, 2.2106, 2.2106]], [[-1.9287, -1.9287, -1.9287], [ 0.8748, 0.8748, 0.8748]]])
參考:https://blog.csdn.net/hbu_pig/article/details/81454503
sum
size為(m,n,d)的張量,dim=1時,輸出為size為(m,d)的張量
import torch a=torch.tensor([[[1,2,3],[4,8,12]],[[1,2,3],[4,8,12]]]) print(a.sum()) print(a.sum(dim=1))
輸出:
tensor(60) tensor([[ 5, 10, 15], [ 5, 10, 15]])
contiguous
返回一個記憶體為連續的張量,如本身就是連續的,返回它自己。一般用在view()函式之前,因為view()要求呼叫張量是連續的。可以通過is_contiguous檢視張量記憶體是否連續。
import torch a=torch.tensor([[[1,2,3],[4,8,12]],[[1,2,3],[4,8,12]]]) print(a.is_contiguous) print(a.contiguous().view(4,3))
輸出:
<built-in method is_contiguous of Tensor object at 0x7f4b5e35afa0> tensor([[ 1, 2, 3], [ 4, 8, 12], [ 1, 2, 3], [ 4, 8, 12]])
softmax
假設陣列V有C個元素。對其進行softmax等價於將V的每個元素的指數除以所有元素的指數之和。這會使值落在區間(0,1)上,並且和為1。
import torch import torch.nn.functional as F a=torch.tensor([[1.,1],[2,1],[3,1],[1,2],[1,3]]) b=F.softmax(a,dim=1) print(b)
輸出:
tensor([[ 0.5000, 0.5000], [ 0.7311, 0.2689], [ 0.8808, 0.1192], [ 0.2689, 0.7311], [ 0.1192, 0.8808]])
max
返回最大值,或指定維度的最大值以及index
import torch a=torch.tensor([[.1,.2,.3], [1.1,1.2,1.3], [2.1,2.2,2.3], [3.1,3.2,3.3]]) print(a.max(dim=1)) print(a.max())
輸出:
(tensor([ 0.3000, 1.3000, 2.3000, 3.3000]), tensor([ 2, 2, 2, 2])) tensor(3.3000)
argmax
返回最大值的index
import torch a=torch.tensor([[.1,.2,.3], [1.1,1.2,1.3], [2.1,2.2,2.3], [3.1,3.2,3.3]]) print(a.argmax(dim=1)) print(a.argmax(dim=0)) print(a.argmax())
輸出:
tensor([ 2, 2, 2, 2])
tensor([ 3, 3, 3]) tensor(11)