clone() and detach() in PyTorch

detach() and clone()

For speed, PyTorch does not copy a tensor on assignment: assigning a vector or matrix just makes the new name point at the same memory.
If you need fresh storage rather than a reference, use clone() to make a deep copy.
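For example, here is a minimal sketch (variable names are illustrative) showing that plain assignment merely aliases the same memory, while clone() allocates new storage:

import torch

a = torch.ones(3)
b = a          # plain assignment: b is just another name for a's storage
c = a.clone()  # clone(): the values are copied into new storage
a[0] = 9.
print(b[0].item())  # 9.0 -- b sees the in-place change
print(c[0].item())  # 1.0 -- c is unaffected; it owns its memory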

a = torch.tensor([[1., 2., 3.], [4., 5., 6.]], requires_grad=True)
print(a)
b = a.detach()  # b shares a's storage but is cut out of the computation graph
print(b)
"""
tensor([[1., 2., 3.],
        [4., 5., 6.]], requires_grad=True)
tensor([[1., 2., 3.],
        [4., 5., 6.]])
"""

The tensor returned by detach() shares its data memory with the original tensor: when the original tensor's values are later updated in place (for example by an optimizer step after backpropagation), the values seen through the detached tensor change as well.
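A minimal sketch of that sharing (the in-place mul_ stands in for whatever update a training step would apply; torch.no_grad() is required because a is a leaf that requires grad):

a = torch.tensor([1., 2., 3.], requires_grad=True)
b = a.detach()
with torch.no_grad():  # in-place edits of a leaf that requires grad must bypass autograd
    a.mul_(2)
print(b)  # tensor([2., 4., 6.]) -- b changed along with a; they share storage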

a = torch.tensor([[1., 2., 3.], [4., 5., 6.]], requires_grad=True)
print(a)
b = a.clone()  # b gets its own storage and stays connected to the graph
print(b)
"""
tensor([[1., 2., 3.],
        [4., 5., 6.]], requires_grad=True)
tensor([[1., 2., 3.],
        [4., 5., 6.]], grad_fn=<CloneBackward>)
"""

grad_fn=<CloneBackward> shows that the tensor returned by clone() is an intermediate node of the computation graph, so gradients can flow back through it to the original tensor.
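A small sketch of that backward path: a gradient computed from the clone reaches the original tensor.

a = torch.tensor([1., 2.], requires_grad=True)
b = a.clone()
b.sum().backward()
print(a.grad)     # tensor([1., 1.]) -- the gradient flowed through CloneBackward into a
print(b.is_leaf)  # False -- b is an intermediate node, so b.grad is not retained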

a = torch.tensor([[1., 2., 3.], [4., 5., 6.]], requires_grad=True)
print(a)
b = a.detach().clone()  # new storage, detached from the graph
print(b)
"""
tensor([[1., 2., 3.],
        [4., 5., 6.]], requires_grad=True)
tensor([[1., 2., 3.],
        [4., 5., 6.]])
"""

a = torch.tensor([[1., 2., 3.], [4., 5., 6.]], requires_grad=True)
print(a)
b = a.detach().clone().requires_grad_(True)  # fresh leaf tensor that tracks gradients
print(b)
"""
tensor([[1., 2., 3.],
        [4., 5., 6.]], requires_grad=True)
tensor([[1., 2., 3.],
        [4., 5., 6.]], requires_grad=True)
"""

In short: clone() preserves requires_grad (True here, because the source tensor requires grad), while detach() always returns a tensor with requires_grad=False.
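The combinations can be checked in one sketch; data_ptr() reveals whether two tensors share storage:

x = torch.ones(2, requires_grad=True)
print(x.clone().requires_grad)                        # True  -- still in the graph
print(x.detach().requires_grad)                       # False -- cut from the graph
print(x.detach().data_ptr() == x.data_ptr())          # True  -- detach() shares storage
print(x.detach().clone().data_ptr() == x.data_ptr())  # False -- clone() allocates anew
print(x.clone().detach().requires_grad)               # False -- same end result as detach().clone()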

import torch
torch.manual_seed(0)

x = torch.tensor([1., 2.], requires_grad=True)
clone_x = x.clone()                   # non-leaf tensor, still attached to the graph
detach_x = x.detach()                 # shares x's memory, cut from the graph
clone_detach_x = x.clone().detach()   # own memory, cut from the graph

f = torch.nn.Linear(2, 1)
y = f(x)
y.backward()

print(x.grad)                         # dy/dx, i.e. the weights of f
print(clone_x.requires_grad)          # True
print(clone_x.grad)                   # None: non-leaf, and never used to compute y
print(detach_x.requires_grad)         # False
print(clone_detach_x.requires_grad)   # False
'''
Output:
tensor([-0.0053,  0.3793])
True
None
False
False
'''
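Note that clone_x.requires_grad is True, yet clone_x.grad is None, for two reasons: clone_x is a non-leaf tensor (autograd does not retain gradients of intermediates by default), and it was never used to compute y in the first place. As a sketch of how to obtain a gradient on a clone, one can feed the clone itself through the layer and call retain_grad() on it:

import torch
torch.manual_seed(0)

x = torch.tensor([1., 2.], requires_grad=True)
clone_x = x.clone()
clone_x.retain_grad()   # ask autograd to keep the gradient of this non-leaf tensor

f = torch.nn.Linear(2, 1)
y = f(clone_x)          # this time the clone participates in the forward pass
y.backward()

print(clone_x.grad)     # populated now (equal to f's weight row)
print(x.grad)           # same values: the gradient continues through CloneBackward into x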