1. 程式人生 > >ND4J自動微分

ND4J自動微分

一、前言

    ND4J從beta2開始就開始支援自動微分,不過直到beta4版本為止,自動微分還只支援CPU,GPU版本將在後續版本中實現。

    本篇部落格中,我們將用ND4J來構建一個函式,利用ND4J SameDiff構建函式求函式值和求函式每個變數的偏微分值。

二、構建函式

    構建函式和分別手動求偏導數

    

    給定一個點(2,3)手動求函式值和偏導,計算如下:

    f=2+3*4+3=17,f對x的偏導:1+2*2*3=13,f對y的偏導:4+1=5

三、通過ND4J自動微分來求

    完整程式碼

package org.nd4j.samediff;

import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.linalg.factory.Nd4j;

/**
 * 
 * x+y*x2+y
 *
 */
public class Function {

	public static void main(String[] args) {
		//構建SameDiff例項
		SameDiff sd=SameDiff.create();
		//建立變數x、y
		SDVariable x= sd.var("x");
		SDVariable y=sd.var("y");
		
		//定義函式
		SDVariable f=x.add(y.mul(sd.math().pow(x, 2)));
		f.add("addY",y);
		
		//給變數x、y繫結具體值
		x.setArray(Nd4j.create(new double[]{2}));
		y.setArray(Nd4j.create(new double[]{3}));
		//前向計算函式的值
		System.out.println(sd.exec(null, "addY").get("addY"));
		//後向計算求梯度
		sd.execBackwards(null);
		//列印x在(2,3)處的導數
		System.out.println(sd.getGradForVariable("x").getArr());
		//x.getGradient().getArr()和sd.getGradForVariable("x").getArr()等效
		System.out.println(x.getGradient().getArr());
		//列印y在(2,3)處的導數
		System.out.println(sd.getGradForVariable("y").getArr());
	}
}

    四、執行結果

o.n.l.f.Nd4jBackend - Loaded [CpuBackend] backend
o.n.n.NativeOpsHolder - Number of threads used for NativeOps: 4
o.n.n.Nd4jBlas - Number of threads used for BLAS: 4
o.n.l.a.o.e.DefaultOpExecutioner - Backend used: [CPU]; OS: [Windows 10]
o.n.l.a.o.e.DefaultOpExecutioner - Cores: [8]; Memory: [3.2GB];
o.n.l.a.o.e.DefaultOpExecutioner - Blas vendor: [MKL]
17.0000
o.n.a.s.SameDiff - Inferring output "addY" as loss variable as none were previously set. Use SameDiff.setLossVariables() to override
13.0000
13.0000
5.0000

    結果為17、13、5和手動求出的結果完全一致。

    自動微分遮蔽了deeplearning在求微分過程中的很多細節,特別是矩陣求導、矩陣範數求導等等,是非常麻煩的,用自動微分,可以輕鬆實現各式各樣的網路結構。

 

快樂源於分享。

   此部落格乃作者原創,