
NDArray Automatic Differentiation


NDArray makes it easy to compute derivatives, as in the example below (the code is mainly based on https://zh.gluon.ai/chapter_crashcourse/autograd.html):

$$ y = 2x^2 \quad\Rightarrow\quad \frac{\partial y}{\partial x} = 4x \quad \text{(element-wise)} $$

In code:

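import mxnet.ndarray as nd
import mxnet.autograd as ag

x = nd.array([[1, 2], [3, 4]])
print(x)
x.attach_grad()        # allocate space to store the gradient of x
with ag.record():      # record the operations so they can be differentiated
    y = 2 * x ** 2
y.backward()           # compute the derivative
z = x.grad             # the gradient (also a 2x2 matrix) is stored in x.grad
print(z)               # print the result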
[[ 1.  2.]
 [ 3.  4.]]
<NDArray 2x2 @cpu(0)>

[[  4.   8.]
 [ 12.  16.]]
<NDArray 2x2 @cpu(0)>
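To see where [[4, 8], [12, 16]] comes from, here is a minimal pure-Python sketch (no MXNet required; the helper names loss and numeric_grad are made up for illustration) that checks the result numerically with central differences. Because y is a matrix rather than a scalar, y.backward() uses a head gradient of all ones, so x.grad is the gradient of the sum of y's elements. The sketch therefore differentiates sum(2x²) one element at a time.

```python
# Pure-Python sketch (no MXNet): approximate d(sum(y))/dx for y = 2*x**2
# with central differences, mirroring what x.grad holds after y.backward().

def loss(x):
    """Sum of y = 2 * x**2 over all elements of a nested-list matrix."""
    return sum(2 * v ** 2 for row in x for v in row)


def numeric_grad(fn, x, h=1e-5):
    """Central-difference gradient of a scalar function fn at matrix x."""
    grad = []
    for i, row in enumerate(x):
        grad_row = []
        for j in range(len(row)):
            plus = [r[:] for r in x]   # copies, so x itself is not modified
            minus = [r[:] for r in x]
            plus[i][j] += h
            minus[i][j] -= h
            grad_row.append((fn(plus) - fn(minus)) / (2 * h))
        grad.append(grad_row)
    return grad


x = [[1.0, 2.0], [3.0, 4.0]]
g = numeric_grad(loss, x)
print([[round(v, 4) for v in row] for row in g])  # [[4.0, 8.0], [12.0, 16.0]]
```

For a quadratic, the central difference equals the true derivative up to floating-point rounding, which is why the rounded values match 4x exactly; for other functions expect small discrepancies.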

Differentiating Through Control Flow

NDArray can also differentiate through control-flow branches such as if statements, for example in this code:

import mxnet.ndarray as nd

def f(a):
    if nd.sum(a).asscalar() < 15:  # if the sum of the elements of a is < 15
        b = a * 2                  # then multiply every element by 2
    else:
        b = a
    return b

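Mathematically, this is equivalent to:

$$ f(a) = \begin{cases} 2a, & \text{if } \sum_{i,j} a_{ij} < 15 \\ a, & \text{otherwise} \end{cases} $$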

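This brings us back to the situation of the first example: on each branch f is a single simple function, so the derivative is just the constant coefficient in front of the variable (2 on the first branch, 1 on the second). Let's verify: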

import mxnet.ndarray as nd
import mxnet.autograd as ag

def f(a):
    if nd.sum(a).asscalar() < 15:  # if the sum of the elements of a is < 15
        b = a * 2                  # then multiply every element by 2
    else:
        b = a
    return b

# Note: 1+2+3+4 = 10 < 15, so f takes the b = a*2 branch
x = nd.array([[1, 2], [3, 4]])
print("x1=")
print(x)
x.attach_grad()
with ag.record():
    y = f(x)
print("y1=")
print(y)
y.backward()  # dy/dx = 2 on this branch (y = 2x)
print("x1.grad=")
print(x.grad)

x = x * 2  # now [[2, 4], [6, 8]]; the sum is 20 >= 15, so f takes the b = a branch
print("x2=")
print(x)
x.attach_grad()
with ag.record():
    y = f(x)
print("y2=")
print(y)
y.backward()
print("x2.grad=")
print(x.grad)
x1=

[[ 1.  2.]
 [ 3.  4.]]
<NDArray 2x2 @cpu(0)>
y1=

[[ 2.  4.]
 [ 6.  8.]]
<NDArray 2x2 @cpu(0)>
x1.grad=

[[ 2.  2.]
 [ 2.  2.]]
<NDArray 2x2 @cpu(0)>
x2=

[[ 2.  4.]
 [ 6.  8.]]
<NDArray 2x2 @cpu(0)>
y2=

[[ 2.  4.]
 [ 6.  8.]]
<NDArray 2x2 @cpu(0)>
x2.grad=

[[ 1.  1.]
 [ 1.  1.]]
<NDArray 2x2 @cpu(0)>
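Why doesn't the if get in the way? MXNet's autograd records the operations that actually run inside ag.record(), so only the branch that was taken ends up being differentiated. The following is a toy pure-Python sketch of that idea, not MXNet's implementation; the Var class and the backward and grad_of_f helpers are hypothetical names made up for illustration. Like MXNet's default, it seeds every output with a head gradient of 1, so each input receives the derivative of the sum of the outputs.

```python
# Toy tape-style reverse-mode autodiff (hypothetical, for illustration only).

class Var:
    """A scalar that remembers which values it was computed from."""

    def __init__(self, value, parents=()):
        self.value = value
        self.parents = parents  # tuple of (parent Var, local derivative)
        self.grad = 0.0

    def __mul__(self, k):
        # Only multiplication by a constant is needed for this example.
        return Var(self.value * k, ((self, k),))


def backward(outputs):
    """Seed each output with gradient 1 and propagate back to the inputs."""
    order, seen = [], set()

    def visit(v):
        # Depth-first post-order: parents are appended before children.
        if id(v) in seen:
            return
        seen.add(id(v))
        for p, _ in v.parents:
            visit(p)
        order.append(v)

    for out in outputs:
        visit(out)
    for out in outputs:
        out.grad += 1.0
    for v in reversed(order):  # children first, then their parents
        for p, local in v.parents:
            p.grad += v.grad * local


def f(a):
    # The comparison uses a plain float (like .asscalar() in the article),
    # so it is not differentiated; only ops on the taken branch are recorded.
    if sum(v.value for row in a for v in row) < 15:
        return [[v * 2 for v in row] for row in a]
    return a


def grad_of_f(values):
    x = [[Var(float(v)) for v in row] for row in values]
    y = f(x)
    backward([v for row in y for v in row])
    return [[v.grad for v in row] for row in x]


print(grad_of_f([[1, 2], [3, 4]]))  # [[2.0, 2.0], [2.0, 2.0]]
print(grad_of_f([[2, 4], [6, 8]]))  # [[1.0, 1.0], [1.0, 1.0]]
```

Because the threshold test contributes no derivative of its own, the gradient is piecewise constant: 2 when the element sum is below 15 and 1 otherwise, matching x1.grad and x2.grad above.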
