8 Torch 中 view() & size() 用法
阿新 • • 發佈:2021-10-21
Torch 中 view() & size()
在閱讀論文原始碼過程中,經常會看到如下的命令:
x = x.view(x.size(0), -1) # 改變 tensor 的形態
下面,本文簡單介紹一下 view() 和 size() 函式的作用:
view()
import torch # 用法一 a = torch.ones(2, 3, 4) b = a.view(3, 8) b Output: tensor([[1., 1., 1., 1., 1., 1., 1., 1.], [1., 1., 1., 1., 1., 1., 1., 1.], [1., 1., 1., 1., 1., 1., 1., 1.]]) import torch # 用法二 a = torch.ones(2, 3, 4) b = a.view(4, -1) b Output: tensor([[1., 1., 1., 1., 1., 1.], [1., 1., 1., 1., 1., 1.], [1., 1., 1., 1., 1., 1.], [1., 1., 1., 1., 1., 1.]]) tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.]) # b = a.view(-1)
以上展示了兩種用法:
1、torch.view(x, y, z, ……)
將原 tensor 以引數設定的維度重排
2、torch.view(x, -1) & torch.view(-1)
將原 tensor 以引數 x 設定第一維度重排,第二維度自動補齊;當沒有引數 x 時,直接重排為一維的 tensor
size()
import torch a = torch.ones(2, 3, 4) a.size() Output: torch.Size([2, 3, 4]) a.size(0) Output: 2 a.size(1) Output: 3 a.size(2) Output: 4
綜上,torch.size(x) 即返回 tensor第 x 維的長度