1. 程式人生 > 程式設計 >PyTorch中clone()、detach()及相關擴充套件詳解

PyTorch中clone()、detach()及相關擴充套件詳解

clone() 與 detach() 對比

Torch 為了提高速度,向量或是矩陣的賦值是指向同一記憶體的,這不同於 Matlab。如果需要儲存舊的tensor即需要開闢新的儲存地址而不是引用,可以用 clone() 進行深拷貝,

首先我們來打印出來clone()操作後的資料型別定義變化:

(1). 簡單列印型別

import torch

a = torch.tensor(1.0,requires_grad=True)
b = a.clone()
c = a.detach()
a.data *= 3
b += 1

print(a) # tensor(3.,requires_grad=True)
print(b)
print(c)

'''
輸出結果:
tensor(3.,requires_grad=True)
tensor(2.,grad_fn=<AddBackward0>)
tensor(3.) # detach()後的值隨著a的變化出現變化
'''

grad_fn=<CloneBackward>,表示clone後的返回值是個中間變數,因此支援梯度的回溯。clone操作在一定程度上可以視為是一個identity-mapping函式。

detach()操作後的tensor與原始tensor共享資料記憶體,當原始tensor在計算圖中數值發生反向傳播等更新之後,detach()的tensor值也發生了改變。

注意: 在pytorch中我們不要直接使用id是否相等來判斷tensor是否共享記憶體,這只是充分條件,因為也許底層共享資料記憶體,但是仍然是新的tensor,比如detach(),如果我們直接列印id會出現以下情況。

import torch as t
a = t.tensor([1.0,2.0],requires_grad=True)
b = a.detach()
#c[:] = a.detach()
print(id(a))
print(id(b))
#140568935450520
140570337203616

顯然直接打印出來的id不等,我們可以通過簡單的賦值後觀察資料變化進行判斷。

(2). clone()的梯度回傳

detach()函式可以返回一個完全相同的tensor,與舊的tensor共享記憶體,脫離計算圖,不會牽扯梯度計算。

而clone充當中間變數,會將梯度傳給源張量進行疊加,但是本身不儲存其grad,即值為None

import torch
a = torch.tensor(1.0,requires_grad=True)
a_ = a.clone()
y = a**2
z = a ** 2+a_ * 3
y.backward()
print(a.grad) # 2
z.backward()
print(a_.grad)   # None. 中間variable,無grad
print(a.grad) 
'''
輸出:
tensor(2.) 
None
tensor(7.) # 2*2+3=7
'''

使用torch.clone()獲得的新tensor和原來的資料不再共享記憶體,但仍保留在計算圖中,clone操作在不共享資料記憶體的同時支援梯度梯度傳遞與疊加,所以常用在神經網路中某個單元需要重複使用的場景下。

通常如果原tensor的requires_grad=True,則:

  • clone()操作後的tensor requires_grad=True
  • detach()操作後的tensor requires_grad=False。
import torch
torch.manual_seed(0)

x= torch.tensor([1.,2.],requires_grad=True)
clone_x = x.clone() 
detach_x = x.detach()
clone_detach_x = x.clone().detach() 

f = torch.nn.Linear(2,1)
y = f(x)
y.backward()

print(x.grad)
print(clone_x.requires_grad)
print(clone_x.grad)
print(detach_x.requires_grad)
print(clone_detach_x.requires_grad)
'''
輸出結果如下:
tensor([-0.0053,0.3793])
True
None
False
False
'''

另一個比較特殊的是當源張量的 require_grad=False,clone後的張量 require_grad=True,此時不存在張量回傳現象,可以得到clone後的張量求導。

如下:

import torch
a = torch.tensor(1.0)
a_ = a.clone()
a_.requires_grad_() #require_grad=True
y = a_ ** 2
y.backward()
print(a.grad) # None
print(a_.grad) 
'''
輸出:
None
tensor(2.)
'''

瞭解了兩者的區別後我們常與其他函式進行搭配使用,實現資料拷貝後的其他需要。

比如我們經常使用view()函式對tensor進行reshape操作。返回的新Tensor與源Tensor可能有不同的size,但是是共享data的,即其中的一個發生變化,另外一個也會跟著改變。

需要注意的是view返回的Tensor與源Tensor是共享data的,但是依然是一個新的Tensor(因為Tensor除了包含data外還有一些其他屬性),兩者id(記憶體地址)並不一致。

x = torch.rand(2,2)
y = x.view(4)
x += 1
print(x)
print(y) # 也加了1

view() 僅僅是改變了對這個張量的觀察角度,內部資料並未改變。這時候想返回一個真正新的副本(即不共享data記憶體)該怎麼辦呢?Pytorch還提供了一個reshape()可以改變形狀,但是此函式並不能保證返回的是其拷貝,所以不推薦使用。推薦先用clone創造一個副本然後再使用view。參考此處

x = torch.rand(2,2)
x_cp = x.clone().view(4)
x += 1
print(id(x))
print(id(x_cp))
print(x)
print(x_cp)
'''
140568935036464
140568935035816
tensor([[0.4963,0.7682],[0.1320,0.3074]])
tensor([[1.4963,1.7682,1.1320,1.3074]]) 
'''

另外使用clone()會被記錄在計算圖中,即梯度回傳到副本時也會傳到源Tensor。在上一篇中有總結。

總結:

  • torch.detach() — 新的tensor會脫離計算圖,不會牽扯梯度計算
  • torch.clone() — 新的tensor充當中間變數,會保留在計算圖中,參與梯度計算(回傳疊加),但是一般不會保留自身梯度。
    原地操作(in-place,such as resize_ / resize_as_ / set_ / transpose_) 在上面兩者中執行都會引發錯誤或者警告。
  • 共享資料記憶體是底層設計,並不能簡單的通過直接列印tensor的id地址進行判斷,需要在進行賦值或運算操作後列印比較資料的變化進行判斷。
  • 複製操作可以根據實際需要進行結合使用。

引用官方文件的話:如果你使用了in-place operation而沒有報錯的話,那麼你可以確定你的梯度計算是正確的。另外儘量避免in-place的使用。

像y = x + y這樣的運算會新開記憶體,然後將y指向新記憶體。我們可以使用Python自帶的id函式進行驗證:如果兩個例項的ID相同,則它們所對應的記憶體地址相同。

到此這篇關於PyTorch中clone()、detach()及相關擴充套件詳解的文章就介紹到這了,更多相關PyTorch中clone()、detach()及相關擴充套件內容請搜尋我們以前的文章或繼續瀏覽下面的相關文章希望大家以後多多支援我們!