在pytorch中對非葉節點的變數計算梯度例項
阿新 • • 發佈:2020-01-13
在pytorch中一般只對葉節點進行梯度計算,也就是下圖中的d,e節點,而對非葉節點,也即是c,b節點則沒有顯式地去保留其中間計算過程中的梯度(因為一般來說只有葉節點才需要去更新),這樣可以節省很大部分的視訊記憶體,但是在除錯過程中,有時候我們需要對中間變數梯度進行監控,以確保網路的有效性,這個時候我們需要打印出非葉節點的梯度,為了實現這個目的,我們可以通過兩種手段進行。
註冊hook函式
Tensor.register_hook[2] 可以註冊一個反向梯度傳導時的hook函式,這個hook函式將會在每次計算 關於該張量 的時候 被呼叫,經常用於除錯的時候打印出非葉節點梯度。當然,通過這個手段,你也可以自定義某一層的梯度更新方法。[3] 具體到這裡的列印非葉節點的梯度,程式碼如:
def hook_y(grad): print(grad) x = Variable(torch.ones(2,2),requires_grad=True) y = x + 2 z = y * y * 3 y.register_hook(hook_y) out = z.mean() out.backward()
輸出如:
tensor([[4.5000,4.5000],[4.5000,4.5000]])
retain_grad()
Tensor.retain_grad()顯式地儲存非葉節點的梯度,當然代價就是會增加視訊記憶體的消耗,而用hook函式的方法則是在反向計算時直接列印,因此不會增加視訊記憶體消耗,但是使用起來retain_grad()要比hook函式方便一些。程式碼如:
x = Variable(torch.ones(2,requires_grad=True) y = x + 2 y.retain_grad() z = y * y * 3 out = z.mean() out.backward() print(y.grad)
輸出如:
tensor([[4.5000,4.5000]])
以上這篇在pytorch中對非葉節點的變數計算梯度例項就是小編分享給大家的全部內容了,希望能給大家一個參考,也希望大家多多支援我們。