NDArray自動求導
阿新 • • 發佈:2017-11-04
def .cn pan turn logs int scala src blank
NDArray可以很方便的求解導數,比如下面的例子:(代碼主要參考自https://zh.gluon.ai/chapter_crashcourse/autograd.html)
用代碼實現如下:
1 import mxnet.ndarray as nd 2 import mxnet.autograd as ag 3 x = nd.array([[1,2],[3,4]]) 4 print(x) 5 x.attach_grad() #附加導數存放的空間 6 with ag.record(): 7 y = 2*x**2 8 y.backward() #求導9 z = x.grad #將導數結果(也是一個矩陣)賦值給z 10 print(z) #打印結果
[[ 1. 2.] [ 3. 4.]] <NDArray 2x2 @cpu(0)> [[ 4. 8.] [ 12. 16.]] <NDArray 2x2 @cpu(0)>
對控制流求導
NDArray還能對諸如if的控制分支進行求導,比如下面這段代碼:
1 def f(a): 2 if nd.sum(a).asscalar()<15: #如果矩陣a的元數和<15 3 b = a*2 #則所有元素*2 4 else: 5 b = a 6 return b
數學公式等價於:
這樣就轉換成本文最開頭示例一樣,變成單一函數求導,顯然導數值就是x前的常數項,驗證一下:
import mxnet.ndarray as nd import mxnet.autograd as ag def f(a): if nd.sum(a).asscalar()<15: #如果矩陣a的元數和<15 b = a*2 #則所有元素平方 else: b = a returnb #註:1+2+3+4<15,所以進入b=a*2的分支 x = nd.array([[1,2],[3,4]]) print("x1=") print(x) x.attach_grad() with ag.record(): y = f(x) print("y1=") print(y) y.backward() #dy/dx = y/x 即:2 print("x1.grad=") print(x.grad) x = x*2 print("x2=") print(x) x.attach_grad() with ag.record(): y = f(x) print("y2=") print(y) y.backward() print("x2.grad=") print(x.grad)
x1= [[ 1. 2.] [ 3. 4.]] <NDArray 2x2 @cpu(0)> y1= [[ 2. 4.] [ 6. 8.]] <NDArray 2x2 @cpu(0)> x1.grad= [[ 2. 2.] [ 2. 2.]] <NDArray 2x2 @cpu(0)> x2= [[ 2. 4.] [ 6. 8.]] <NDArray 2x2 @cpu(0)> y2= [[ 2. 4.] [ 6. 8.]] <NDArray 2x2 @cpu(0)> x2.grad= [[ 1. 1.] [ 1. 1.]] <NDArray 2x2 @cpu(0)>
NDArray自動求導