1. 程式人生 > 其它 >8 Torch 中 view() & size() 用法

8 Torch 中 view() & size() 用法

Torch 中 view() & size()

在閱讀論文原始碼過程中,經常會看到如下的命令:

x = x.view(x.size(0), -1)	# 改變 tensor 的形態

下面,本文簡單介紹一下 view() 和 size() 函式的作用:

view()

import torch	# 用法一
a = torch.ones(2, 3, 4)
b = a.view(3, 8)
b

Output:
tensor([[1., 1., 1., 1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1., 1., 1., 1.]])

import torch	# 用法二
a = torch.ones(2, 3, 4)
b = a.view(4, -1)
b

Output:
tensor([[1., 1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1., 1.]])
tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.])	# b = a.view(-1)

以上展示了兩種用法:

1、torch.view(x, y, z, ……)

將原 tensor 以引數設定的維度重排

2、torch.view(x, -1) & torch.view(-1)

將原 tensor 以引數 x 設定第一維度重排,第二維度自動補齊;當沒有引數 x 時,直接重排為一維的 tensor

size()

import torch
a = torch.ones(2, 3, 4)
a.size()
Output:
torch.Size([2, 3, 4])

a.size(0)
Output:
2

a.size(1)
Output:
3

a.size(2)
Output:
4

綜上,torch.size(x) 即返回 tensor第 x 維的長度