『PyTorch』第五彈_深入理解autograd_下:函數擴展&高階導數
一、封裝新的PyTorch函數
繼承Function類
forward:輸入Variable->中間計算Tensor->輸出Variable
backward:均使用Variable
線性映射
from torch.autograd import Function class MultiplyAdd(Function): # <----- 類需要繼承Function類 @staticmethod # <-----forward和backward都是靜態方法 def forward(ctx, w, x, b): # <-----ctx作為內部參數在前向反向傳播中協調 print(‘type in forward‘,type(x)) ctx.save_for_backward(w,x) # <-----ctx保存參數 output = w * x + b return output # <-----forward輸入參數和backward輸出參數必須一一對應 @staticmethod # <-----forward和backward都是靜態方法 def backward(ctx, grad_output): # <-----ctx作為內部參數在前向反向傳播中協調 w,x = ctx.saved_variables # <-----ctx讀取參數 print(‘type in backward‘,type(x)) grad_w = grad_output * x grad_x = grad_output * w grad_b = grad_output * 1 return grad_w, grad_x, grad_b # <-----backward輸入參數和forward輸出參數必須一一對應
調用方法一
類名.apply(參數)
輸出變量.backward()
import torch as t from torch.autograd import Variable as V x = V(t.ones(1)) w = V(t.rand(1), requires_grad = True) b = V(t.rand(1), requires_grad = True) print(‘開始前向傳播‘) z=MultiplyAdd.apply(w, x, b) # <-----forward print(‘開始反向傳播‘) z.backward() # 等效 # <-----backward # x不需要求導,中間過程還是會計算它的導數,但隨後被清空 print(x.grad, w.grad, b.grad)
開始前向傳播 type in forward <class ‘torch.FloatTensor‘> 開始反向傳播 type in backward <class ‘torch.autograd.variable.Variable‘>(None,
Variable containing: 1 [torch.FloatTensor of size 1],
Variable containing: 1 [torch.FloatTensor of size 1])
調用方法二
類名.apply(參數)
輸出變量.grad_fn.apply()
x = V(t.ones(1)) w = V(t.rand(1), requires_grad = True) b = V(t.rand(1), requires_grad = True) print(‘開始前向傳播‘) z=MultiplyAdd.apply(w,x,b) # <-----forward print(‘開始反向傳播‘) # 調用MultiplyAdd.backward # 會自動輸出grad_w, grad_x, grad_b z.grad_fn.apply(V(t.ones(1))) # <-----backward,在計算中間輸出,buffer並未清空,所以x的梯度不是None
開始前向傳播 type in forward <class ‘torch.FloatTensor‘> 開始反向傳播 type in backward <class ‘torch.autograd.variable.Variable‘>(Variable containing: 1 [torch.FloatTensor of size 1], Variable containing: 0.7655 [torch.FloatTensor of size 1], Variable containing: 1 [torch.FloatTensor of size 1])
之所以forward函數的輸入是tensor,而backward函數的輸入是variable,是為了實現高階求導。backward函數的輸入輸出雖然是variable,但在實際使用時autograd.Function會將輸入variable提取為tensor,並將計算結果的tensor封裝成variable返回。在backward函數中,之所以也要對variable進行操作,是為了能夠計算梯度的梯度(backward of backward)。下面舉例說明,有關torch.autograd.grad的更詳細使用請參照文檔。
二、高階導數
grad_x =t.autograd.grad(y, x, create_graph=True)
grad_grad_x = t.autograd.grad(grad_x[0],x)
x = V(t.Tensor([5]), requires_grad=True) y = x ** 2 grad_x = t.autograd.grad(y, x, create_graph=True) print(grad_x) # dy/dx = 2 * x grad_grad_x = t.autograd.grad(grad_x[0],x) print(grad_grad_x) # 二階導數 d(2x)/dx = 2
(Variable containing: 10 [torch.FloatTensor of size 1],)(Variable containing: 2 [torch.FloatTensor of size 1],)
三、梯度檢查
t.autograd.gradcheck(Sigmoid.apply, (test_input,), eps=1e-3)
此外在實現了自己的Function之後,還可以使用gradcheck
函數來檢測實現是否正確。gradcheck
通過數值逼近來計算梯度,可能具有一定的誤差,通過控制eps
的大小可以控制容忍的誤差。
class Sigmoid(Function): @staticmethod def forward(ctx, x): output = 1 / (1 + t.exp(-x)) ctx.save_for_backward(output) return output @staticmethod def backward(ctx, grad_output): output, = ctx.saved_variables grad_x = output * (1 - output) * grad_output return grad_x # 采用數值逼近方式檢驗計算梯度的公式對不對 test_input = V(t.randn(3,4), requires_grad=True) t.autograd.gradcheck(Sigmoid.apply, (test_input,), eps=1e-3)
True
測試效率,
def f_sigmoid(x): y = Sigmoid.apply(x) y.backward(t.ones(x.size())) def f_naive(x): y = 1/(1 + t.exp(-x)) y.backward(t.ones(x.size())) def f_th(x): y = t.sigmoid(x) y.backward(t.ones(x.size())) x=V(t.randn(100, 100), requires_grad=True) %timeit -n 100 f_sigmoid(x) %timeit -n 100 f_naive(x) %timeit -n 100 f_th(x)
實際測試結果,
245 μs ± 70.1 μs per loop (mean ± std. dev. of 7 runs, 100 loops each) 211 μs ± 23.3 μs per loop (mean ± std. dev. of 7 runs, 100 loops each) 219 μs ± 36.6 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)
書中說的結果,
100 loops, best of 3: 320 μs per loop 100 loops, best of 3: 588 μs per loop 100 loops, best of 3: 271 μs per loop
很奇怪,我的結果竟然是:簡單堆砌<官方封裝<自己封裝……不過還是引用一下書中的結論吧:
顯然
f_sigmoid
要比單純利用autograd
加減和乘方操作實現的函數快不少,因為f_sigmoid的backward優化了反向傳播的過程。另外可以看出系統實現的buildin接口(t.sigmoid)更快。
『PyTorch』第五彈_深入理解autograd_下:函數擴展&高階導數