1. 程式人生 > 其它 >tensor複製與儲存

tensor複製與儲存

1 tensor.clone()

clone(memory_format=torch.preserve_format)→ Tensor

返回tensor的拷貝,返回的新tensor和原來的tensor具有同樣的大小和資料型別。

  • 原tensor的requires_grad=True

clone()返回的tensor是中間節點,梯度會流向原tensor,即返回的tensor的梯度會疊加在原tensor上

>>> import torch
>>> a = torch.tensor(1.0, requires_grad=True)
>>> b = a.clone()
>>> id(a), id(b)  # a和b不是同一個物件
(140191154302240, 140191145593424)
>>> a.data_ptr(), b.data_ptr()  # 也不指向同一塊記憶體地址
(94724518544960, 94724519185792)
>>> a.requires_grad, b.requires_grad  # 但b的requires_grad屬性和a的一樣,同樣是True
(True, True)
>>> c = a * 2
>>> c.backward()
>>> a.grad
tensor(2.)
>>> d = b * 3
>>> d.backward()
>>> b.grad  # b的梯度值為None,因為是中間節點,梯度值不會被儲存
>>> a.grad  # b的梯度疊加在a上
tensor(5.)
  • 原tensor的requires_grad=False
>>> import torch
>>> a = torch.tensor(1.0)
>>> b = a.clone()
>>> id(a), id(b)  # a和b不是同一個物件
(140191169099168, 140191154762208)
>>> a.data_ptr(), b.data_ptr()  # 也不指向同一塊記憶體地址
(94724519502912, 94724519533952)
>>> a.requires_grad, b.requires_grad  # 但b的requires_grad屬性和a的一樣,同樣是False
(False, False)
>>> b.requires_grad_()
>>> c = b * 2
>>> c.backward()
>>> b.grad
tensor(2.)
>>> a.grad  # None

2 tensor.detach()

detach()

從計算圖中脫離出來。

返回一個新的tensor,新的tensor和原來的tensor共享資料記憶體,但不涉及梯度計算,即requires_grad=False。修改其中一個tensor的值,另一個也會改變,因為是共享同一塊記憶體,但如果對其中一個tensor執行某些內建操作,則會報錯,例如resize_、resize_as_、set_、transpose_。

>>> import torch
>>> a = torch.rand((3, 4), requires_grad=True)
>>> b = a.detach()
>>> id(a), id(b)  # a和b不是同一個物件了
(140191157657504, 140191161442944)
>>> a.data_ptr(), b.data_ptr()  # 但指向同一塊記憶體地址
(94724518609856, 94724518609856)
>>> a.requires_grad, b.requires_grad  # b的requires_grad為False
(True, False)
>>> b[0][0] = 1
>>> a[0][0]  # 修改b的值,a的值也會改變
tensor(1., grad_fn=<SelectBackward>)
>>> b.resize_((4, 3))  # 報錯
RuntimeError: set_sizes_contiguous is not allowed on a Tensor created from .data or .detach().

3. tensor.clone().detach() 還是 tensor.detach().clone()

兩者的結果是一樣的,即返回的tensor和原tensor在梯度上或者資料上沒有任何關係,一般用前者。