clone() and detach() in PyTorch
detach() and clone()
For speed, Torch makes tensor (vector or matrix) assignment reference the same underlying memory rather than copy it.
If you need a tensor with its own, newly allocated storage instead of a reference, use clone() to make a deep copy.
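A minimal sketch of this difference (the variable names are just for illustration):
import torch

a = torch.tensor([1., 2., 3.])
b = a          # plain assignment: b is another name for the same storage
c = a.clone()  # clone(): a new, independent copy of the storage
a[0] = 100.
print(b)       # follows the change to a
print(c)       # unaffected
"""
tensor([100.,   2.,   3.])
tensor([1., 2., 3.])
"""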
import torch

a = torch.tensor([[1., 2., 3.], [4., 5., 6.]], requires_grad=True)
print(a)
b = a.detach()
print(b)
"""
tensor([[1., 2., 3.],
[4., 5., 6.]], requires_grad=True)
tensor([[1., 2., 3.],
[4., 5., 6.]])
"""
The tensor returned by detach() shares its data memory with the original tensor: when the original tensor's values are later updated in place (for example by an optimizer step after backpropagation), the values seen through the detach()-ed tensor change as well.
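A small sketch of this sharing (the in-place update must happen under torch.no_grad(), because a is a leaf tensor that requires grad):
import torch

a = torch.tensor([[1., 2., 3.], [4., 5., 6.]], requires_grad=True)
b = a.detach()
with torch.no_grad():
    a.mul_(2)  # in-place update of a's storage
print(b)       # b reflects the change: it shares memory with a
"""
tensor([[ 2.,  4.,  6.],
        [ 8., 10., 12.]])
"""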
a = torch.tensor([[1., 2., 3.], [4., 5., 6.]], requires_grad=True)
print(a)
b = a.clone()
print(b)
"""
tensor([[1., 2., 3.],
[4., 5., 6.]], requires_grad=True)
tensor([[1., 2., 3.],
[4., 5., 6.]], grad_fn=<CloneBackward>)
"""
grad_fn=<CloneBackward> indicates that the result of clone() is an intermediate node in the computation graph, so gradients can propagate back through it to the original tensor.
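For example, gradients computed on the clone flow back to the original leaf (a minimal sketch):
import torch

a = torch.tensor([[1., 2., 3.], [4., 5., 6.]], requires_grad=True)
b = a.clone()
b.sum().backward()  # backprop runs through CloneBackward
print(a.grad)       # the original leaf receives the gradient
print(b.grad)       # None: b is an intermediate (non-leaf) tensor
"""
tensor([[1., 1., 1.],
        [1., 1., 1.]])
None
"""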
a = torch.tensor([[1., 2., 3.], [4., 5., 6.]], requires_grad=True)
print(a)
b = a.detach().clone()
print(b)
"""
tensor([[1., 2., 3.],
[4., 5., 6.]], requires_grad=True)
tensor([[1., 2., 3.],
[4., 5., 6.]])
"""
a = torch.tensor([[1., 2., 3.], [4., 5., 6.]], requires_grad=True)
print(a)
b = a.detach().clone().requires_grad_(True)
print(b)
"""
tensor([[1., 2., 3.],
        [4., 5., 6.]], requires_grad=True)
tensor([[1., 2., 3.],
        [4., 5., 6.]], requires_grad=True)
"""
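detach().clone() therefore produces a copy that shares neither the computation graph nor the storage of the original; this matches the idiom the PyTorch warning message itself suggests (sourceTensor.clone().detach()) for copying an existing tensor. A quick check of the storage independence:
import torch

a = torch.tensor([[1., 2., 3.], [4., 5., 6.]], requires_grad=True)
b = a.detach().clone()
with torch.no_grad():
    a.add_(1)  # in-place update of a
print(b)       # unchanged: clone() copied the storage
"""
tensor([[1., 2., 3.],
        [4., 5., 6.]])
"""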
A tensor produced by clone() has requires_grad=True (provided the source tensor requires grad).
A tensor produced by detach() has requires_grad=False.
The following example checks these properties together:
import torch

torch.manual_seed(0)
x = torch.tensor([1., 2.], requires_grad=True)
clone_x = x.clone()                  # in the graph, but not a leaf
detach_x = x.detach()                # out of the graph
clone_detach_x = x.clone().detach()  # independent copy, out of the graph
f = torch.nn.Linear(2, 1)
y = f(x)
y.backward()
print(x.grad)                        # dy/dx equals f's weights
print(clone_x.requires_grad)
print(clone_x.grad)                  # None: clone_x is not a leaf
print(detach_x.requires_grad)
print(clone_detach_x.requires_grad)
'''
Output:
tensor([-0.0053, 0.3793])
True
None
False
False
'''
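As a practical illustration (a hypothetical snapshotting scenario, not from the original examples above): when saving a tensor's current value during training, detach() alone gives a view that keeps following in-place updates, while detach().clone() freezes the value.
import torch

w = torch.tensor([1., 2.], requires_grad=True)
view_w = w.detach()          # shares storage with w
snap_w = w.detach().clone()  # independent, frozen copy
with torch.no_grad():
    w -= 0.1                 # simulated in-place optimizer step
print(view_w)  # follows w
print(snap_w)  # keeps the old value
"""
tensor([0.9000, 1.9000])
tensor([1., 2.])
"""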