Dl4j-fit(DataSetIterator iterator)原始碼閱讀(四)dropout
preOut這一部分就是網路模型前向傳播的重點。
public INDArray preOutput(boolean training) {
applyDropOutIfNecessary(training);
INDArray b = getParam(DefaultParamInitializer.BIAS_KEY);
INDArray W = getParam(DefaultParamInitializer.WEIGHT_KEY);
//Input validation:
if (input.rank() != 2 || input.columns () != W.rows()) {
if (input.rank() != 2) {
throw new DL4JInvalidInputException("Input that is not a matrix; expected matrix (rank 2), got rank "
+ input.rank() + " array with shape " + Arrays.toString(input.shape()));
}
throw new DL4JInvalidInputException("Input size (" + input.columns() + " columns; shape = "
+ Arrays.toString(input.shape())
+ ") is invalid: does not match layer input size (layer # inputs = " + W.size(0) + ")");
}
if (conf.isUseDropConnect() && training && conf.getLayer().getDropOut () > 0) {
W = Dropout.applyDropConnect(this, DefaultParamInitializer.WEIGHT_KEY);
}
INDArray ret = input.mmul(W).addiRowVector(b);
if (maskArray != null) {
applyMask(ret);
}
return ret;
}
首先使用applyDropOutIfNecessary(training);
函式判斷當前是否使用dropout。
protected void applyDropOutIfNecessary(boolean training) {
if (conf.getLayer().getDropOut() > 0 && !conf.isUseDropConnect() && training && !dropoutApplied) {
input = input.dup();
Dropout.applyDropout(input, conf.getLayer().getDropOut());
dropoutApplied = true;
}
}
使用dropout的條件如下:
- 當前層設定 dropout > 0
- 當前配置沒有使用dropConnect(), 這一配置在卷積神經網路常見。
- 當前是訓練過程,也就是training的值為true。 在預測的時候dropout不會被應用
- dropout在之前沒有被呼叫。
如果以上條件都滿足,則先對當前的輸入使用dup()
函式進行復制(注:dup取自單詞duplicate,複製的意思),然後傳入下一個函式。
/**
5. Apply dropout to the given input
6. and return the drop out mask used
7. @param input the input to do drop out on
8. @param dropout the drop out probability
*/
public static void applyDropout(INDArray input, double dropout) {
if (Nd4j.getRandom().getStatePointer() != null) {
Nd4j.getExecutioner().exec(new DropOutInverted(input, dropout));
} else {
Nd4j.getExecutioner().exec(new LegacyDropOutInverted(input, dropout));
}
}
dropout的實現方式很多,根據這個原始碼閱讀方式發現,dl4j的dropout實現方式是根據截斷當前層的輸入來實現drpout。
/**
9. This method returns pointer to RNG state structure.
10. Please note: DefaultRandom implementation returns NULL here, making it impossible to use with RandomOps
11. - @return
*/
@Override
public Pointer getStatePointer() {
return statePointer;
}
這個getStatePointer()的目的從程式碼的註釋情況上來還不是很清楚。接下來檢視兩種實現方式
- DropOutInverted
/**
* Inverted DropOut implementation as Op
*
* @author [email protected]
*/
public class DropOutInverted extends BaseRandomOp {
private double p;
public DropOutInverted() {
}
public DropOutInverted(@NonNull INDArray x, double p) {
this(x, x, p, x.lengthLong());
}
public DropOutInverted(@NonNull INDArray x, @NonNull INDArray z, double p) {
this(x, z, p, x.lengthLong());
}
public DropOutInverted(@NonNull INDArray x, @NonNull INDArray z, double p, long n) {
this.p = p;
init(x, null, z, n);
}
@Override
public int opNum() {
return 2;
}
@Override
public String name() {
return "dropout_inverted";
}
@Override
public void init(INDArray x, INDArray y, INDArray z, long n) {
super.init(x, y, z, n);
this.extraArgs = new Object[] {p};
}
}
- LegacyDropOutInverted
/**
* Inverted DropOut implementation as Op
*
* PLEASE NOTE: This is legacy DropOutInverted implementation, please consider using op with the same name from randomOps
* @author [email protected]
*/
public class LegacyDropOutInverted extends BaseTransformOp {
private double p;
public LegacyDropOutInverted() {
}
public LegacyDropOutInverted(INDArray x, double p) {
super(x);
this.p = p;
init(x, null, x, x.length());
}
public LegacyDropOutInverted(INDArray x, INDArray z, double p) {
super(x, z);
this.p = p;
init(x, null, z, x.length());
}
public LegacyDropOutInverted(INDArray x, INDArray z, double p, long n) {
super(x, z, n);
this.p = p;
init(x, null, z, n);
}
@Override
public int opNum() {
return 44;
}
@Override
public String name() {
return "legacy_dropout_inverted";
}
@Override
public IComplexNumber op(IComplexNumber origin, double other) {
return null;
}
@Override
public IComplexNumber op(IComplexNumber origin, float other) {
return null;
}
@Override
public IComplexNumber op(IComplexNumber origin, IComplexNumber other) {
return null;
}
@Override
public float op(float origin, float other) {
return 0;
}
@Override
public double op(double origin, double other) {
return 0;
}
@Override
public double op(double origin) {
return 0;
}
@Override
public float op(float origin) {
return 0;
}
@Override
public IComplexNumber op(IComplexNumber origin) {
return null;
}
@Override
public Op opForDimension(int index, int dimension) {
INDArray xAlongDimension = x.vectorAlongDimension(index, dimension);
if (y() != null)
return new LegacyDropOutInverted(xAlongDimension, z.vectorAlongDimension(index, dimension), p,
xAlongDimension.length());
else
return new LegacyDropOutInverted(xAlongDimension, z.vectorAlongDimension(index, dimension), p,
xAlongDimension.length());
}
@Override
public Op opForDimension(int index, int... dimension) {
INDArray xAlongDimension = x.tensorAlongDimension(index, dimension);
if (y() != null)
return new LegacyDropOutInverted(xAlongDimension, z.tensorAlongDimension(index, dimension), p,
xAlongDimension.length());
else
return new LegacyDropOutInverted(xAlongDimension, z.tensorAlongDimension(index, dimension), p,
xAlongDimension.length());
}
@Override
public void init(INDArray x, INDArray y, INDArray z, long n) {
super.init(x, y, z, n);
this.extraArgs = new Object[] {p, (double) n};
}
}
這個dropout有些難以理解,這裡用單步的除錯資訊來檢視計算流程來嘗試理解:
當前程式執行的dropout的型別為DropOutInverted
。此時呼叫的函式如下:
public DropOutInverted(@NonNull INDArray x, double p) {
this(x, x, p, x.lengthLong());
}
當前輸入的x的值為:
[-10.0,-9.99,-9.98,-9.97,-9.96,-9.95,-9.94,-9.93,-9.92,-9.91,-9.9,-9.89,-9.88,-9.87,-9.86,-9.85,-9.84,-9.83,-9.82,-9.81]
它的shape為[20, 1],也就是一個 20 x 1的列向量。其中呼叫的x.lengthLong()
的值也為20。當前的p值也有改變,p值變為當前層的dropout值,即當前p = 0.5
。之後呼叫this執行到另外一個建構函式中:
public DropOutInverted(@NonNull INDArray x, @NonNull INDArray z, double p, long n) {
this.p = p;
init(x, null, z, n);
}
在呼叫到當前建構函式的時候,呼叫init函式,此時的 z和x是相同的值。
@Override
public void init(INDArray x, INDArray y, INDArray z, long n) {
super.init(x, y, z, n);
this.extraArgs = new Object[] {p};
}
執行到當前步,各項引數如下:
x = [-10.00, -9.99, -9.98, -9.97, -9.96, -9.95, -9.94, -9.93, -9.92, -9.91, -9.90, -9.89, -9.88, -9.87, -9.86, -9.85, -9.84, -9.83, -9.82, -9.81]
y = null
z = [-10.00, -9.99, -9.98, -9.97, -9.96, -9.95, -9.94, -9.93, -9.92, -9.91, -9.90, -9.89, -9.88, -9.87, -9.86, -9.85, -9.84, -9.83, -9.82, -9.81]
n = 20
p = 0.5
之後就會跳轉到父類的init()方法:
@Override
public void init(INDArray x, INDArray y, INDArray z, long n) {
this.x = x;
this.y = y;
this.z = z;
this.n = n;
}
父類方法只是對成員變數進行簡單賦值。
在以上變數初始化完成之後,繼續執行Nd4j.getExecutioner().exec(new DropOutInverted(input, dropout));
方法。
/**
* This method executes specified RandomOp using default RNG available via Nd4j.getRandom()
*
* @param op
*/
@Override
public INDArray exec(RandomOp op) {
return exec(op, Nd4j.getRandom());
}
根據註釋,兩個dropout類是特殊的RandomOp。之後繼續呼叫下一個exec()
方法。
/**
* This method executes specific
* RandomOp against specified RNG
*
* @param op
* @param rng
*/
@Override
public INDArray exec(RandomOp op, Random rng) {
if (rng.getStateBuffer() == null)
throw new IllegalStateException(
"You should use one of NativeRandom classes for NativeOperations execution");
long st = profilingHookIn(op);
validateDataType(Nd4j.dataType(), op);
if (op.x() != null && op.y() != null && op.z() != null) {
// triple arg call
if (Nd4j.dataType() == DataBuffer.Type.FLOAT) {
loop.execRandomFloat(null, op.opNum(), rng.getStatePointer(), // rng state ptr
(FloatPointer) op.x().data().addressPointer(),
(IntPointer) op.x().shapeInfoDataBuffer().addressPointer(),
(FloatPointer) op.y().data().addressPointer(),
(IntPointer) op.y().shapeInfoDataBuffer().addressPointer(),
(FloatPointer) op.z().data().addressPointer(),
(IntPointer) op.z().shapeInfoDataBuffer().addressPointer(),
(FloatPointer) op.extraArgsDataBuff().addressPointer());
} else if (Nd4j.dataType() == DataBuffer.Type.DOUBLE) {
loop.execRandomDouble(null, op.opNum(), rng.getStatePointer(), // rng state ptr
(DoublePointer) op.x().data().addressPointer(),
(IntPointer) op.x().shapeInfoDataBuffer().addressPointer(),
(DoublePointer) op.y().data().addressPointer(),
(IntPointer) op.y().shapeInfoDataBuffer().addressPointer(),
(DoublePointer) op.z().data().addressPointer(),
(IntPointer) op.z().shapeInfoDataBuffer().addressPointer(),
(DoublePointer) op.extraArgsDataBuff().addressPointer());
}
} else if (op.x() != null && op.z() != null) {
//double arg call
if (Nd4j.dataType() == DataBuffer.Type.FLOAT) {
loop.execRandomFloat(null, op.opNum(), rng.getStatePointer(), // rng state ptr
(FloatPointer) op.x().data().addressPointer(),
(IntPointer) op.x().shapeInfoDataBuffer().addressPointer(),
(FloatPointer) op.z().data().addressPointer(),
(IntPointer) op.z().shapeInfoDataBuffer().addressPointer(),
(FloatPointer) op.extraArgsDataBuff().addressPointer());
} else if (Nd4j.dataType() == DataBuffer.Type.DOUBLE) {
loop.execRandomDouble(null, op.opNum(), rng.getStatePointer(), // rng state ptr
(DoublePointer) op.x().data().addressPointer(),
(IntPointer) op.x().shapeInfoDataBuffer().addressPointer(),
(DoublePointer) op.z().data().addressPointer(),
(IntPointer) op.z().shapeInfoDataBuffer().addressPointer(),
(DoublePointer) op.extraArgsDataBuff().addressPointer());
}
} else {
// single arg call
if (Nd4j.dataType() == DataBuffer.Type.FLOAT) {
loop.execRandomFloat(null, op.opNum(), rng.getStatePointer(), // rng state ptr
(FloatPointer) op.z().data().addressPointer(),
(IntPointer) op.z().shapeInfoDataBuffer().addressPointer(),
(FloatPointer) op.extraArgsDataBuff().addressPointer());
} else if (Nd4j.dataType() == DataBuffer.Type.DOUBLE) {
loop.execRandomDouble(null, op.opNum(), rng.getStatePointer(), // rng state ptr
(DoublePointer) op.z().data().addressPointer(),
(IntPointer) op.z().shapeInfoDataBuffer().addressPointer(),
(DoublePointer) op.extraArgsDataBuff().addressPointer());
}
}
profilingHookOut(op, st);
return op.z();
}
這個函式首先使用validateDataType(Nd4j.dataType(), op);
用於檢驗當前資料型別的合法性。然後根據傳入的op的三個成員變數x, y, z來判斷進入哪一分支。在上面的debug資訊我們可以看到,我們的x和z是兩個非空變數,因此進入第二個分支,並且我們當前的Nd4j.dataType()
為DataBuffer.Type.FLOAT
。為此在當前環境下會執行以下語句:
loop.execRandomFloat(null, op.opNum(), rng.getStatePointer(), // rng state ptr
(FloatPointer) op.x().data().addressPointer(),
(IntPointer) op.x().shapeInfoDataBuffer().addressPointer(),
(FloatPointer) op.z().data().addressPointer(),
(IntPointer) op.z().shapeInfoDataBuffer().addressPointer(),
(FloatPointer) op.extraArgsDataBuff().addressPointer());
然後這部分的具體實現應該是JNI呼叫的底層
public native void execRandomFloat(@Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, @Cast("Nd4jPointer") Pointer state, FloatPointer x, IntPointer xShapeBuffer, FloatPointer z, IntPointer zShapeBuffer, FloatPointer extraArguments);
經過如上方法的的執行之後,返回z值,這時候通過debug資訊看到的z值為:
[-20.00, -19.98, -19.96, 0.00, 0.00, -19.90, -19.88, 0.00, -19.84, -19.82, 0.00, 0.00, -19.76, -19.74, -19.72, -19.70, -19.68, -19.66, -19.64, 0.00]
因為在前面輸入的時候z其實和x是等同的。在執行以上方法之後,相當於對x做了一個變幻。使得x變為如上的數值。到這裡就使得dl4j的applyDropOutIfNecessary(training)
方法部分完成,繼續回到preOutput()
方法體內繼續執行。(這裡猜測實現的方式是部分位置隨機置0,然後再所有的資料除以dropout的值)
相關推薦
Dl4j-fit(DataSetIterator iterator)原始碼閱讀(四)dropout
preOut這一部分就是網路模型前向傳播的重點。 public INDArray preOutput(boolean training) { applyDropOutIfNecessary(training); INDArray b = g
DL4J原始碼閱讀(四):梯度計算
computeGradientAndScore方法呼叫backprop()做梯度計算和誤差反傳。 backprop()呼叫calcBackpropGradients()方法。calcBackpropGradients()方法再呼叫initGradie
Redis原始碼閱讀(四)叢集-請求分配
叢集搭建好之後,使用者傳送的命令請求可以被分配到不同的節點去處理。那Redis對命令請求分配的依據是什麼?如果節點數量有變動,命令又是如何重新分配的,重分配的過程是否會阻塞對外提供的服務?接下來會從這兩個問題入手,分析Redis3.0的原始碼實現。 1. 分配依據—
jQuery原始碼閱讀(四)--正則表示式
在jQuery原始碼中,運用了大量的正則表示式,一開始在看的時候真的是一頭霧水,儘管已經看過了JS高程裡面的正則表示式。 今天,看了一篇深入理解正則表示式的文章,對正則表示式有了更深的認識,下面做一個回顧和總結。 正則表示式基礎 JS正則表示式用來匹配
Horizon 原始碼閱讀(四)—— 呼叫Novaclient流程
一、寫在前面 這篇文章主要介紹一下OpenStack(Kilo)關於horizon 呼叫NovaClient的一個分析。 如果轉載,請保留作者資訊。 原文地址:http://blog.csdn.net/u011521019/a
XSStrike原始碼閱讀(2)——四種模式
1.bruteforcer模式 功能介紹 根據使用者提供的payloads檔案去暴力測試每一個引數,以此來確定是否存在xss漏洞(說起來也就是一個兩層迴圈)。 具體實現 XSStrike3.0 bruteforcer.py原始碼如下: import copy from
DL4J原始碼閱讀(七):LSTM梯度計算
LSTMHelpers類中的backpropGradientHelper方法是梯度計算過程。 // 本層神經元個數 int hiddenLayerSize = recurrentWeights.size(0); //i.e., n^
LevelDB的源碼閱讀(四) Compaction操作
left 維護 efault smallest item app apply() body roc leveldb的數據存儲采用LSM的思想,將隨機寫入變為順序寫入,記錄寫入操作日誌,一旦日誌被以追加寫的形式寫入硬盤,就返回寫入成功,由後臺線程將寫入日誌作用於原有的磁盤文件
Flume NG原始碼分析(四)使用ExecSource從本地日誌檔案中收集日誌
常見的日誌收集方式有兩種,一種是經由本地日誌檔案做媒介,非同步地傳送到遠端日誌倉庫,一種是基於RPC方式的同步日誌收集,直接傳送到遠端日誌倉庫。這篇講講Flume NG如何從本地日誌檔案中收集日誌。 ExecSource是用來執行本地shell命令,並把本地日誌檔案中的資料封裝成Event
OpenCV學習筆記(30)KAZE 演算法原理與原始碼分析(四)KAZE特徵的效能分析與比較
KAZE系列筆記: 1. OpenCV學習筆記(27)KAZE 演算法原理與原始碼分析(一)非線性擴散濾波 2. OpenCV學習筆記(28)KAZE 演算法原理與原始碼分析(二)非線性尺度空間構
GCC原始碼分析(四)——優化
原文連結:http://blog.csdn.net/sonicling/article/details/7916931 一、前言 本篇只介紹一下框架,就不具體介紹每個步驟了。 二、Pass框架 上一篇已經講了gcc的中間語言的表現形式。gcc 對中間語言
Spring原始碼解析(四)——元件註冊4
/** * 給容器中註冊元件; * 1)、包掃描+元件標註註解(@Controller/@Service/@Repository/@Component)[自己寫的類] * 2)、@Bean[匯入的第三方包裡面的元件] * 3)、@Import[快速給容器中匯入一個
【筆記】ThreadPoolExecutor原始碼閱讀(三)
執行緒數量的維護 執行緒池的大小有兩個重要的引數,一個是corePoolSize(核心執行緒池大小),另一個是maximumPoolSize(最大執行緒大小)。執行緒池主要根據這兩個引數對執行緒池中執行緒的數量進行維護。 需要注意的是,執行緒池建立之初是沒有任何可用執行緒的。只有在有任務到達後,才開始建立
YOLOv2原始碼分析(四)
文章全部YOLOv2原始碼分析 0x01 backward_convolutional_layer void backward_convolutional_layer(convolutional_layer l, network
## Zookeeper原始碼閱讀(六) Watcher
前言 好久沒有更新部落格了,最近這段時間過得很壓抑,終於開始踏上為換工作準備的正軌了,工作又真的很忙而且很瑣碎,讓自己有點煩惱,希望能早點結束這種狀態。 繼上次分析了ZK的ACL相關程式碼後,ZK裡非常重要的另一個特性就是Watcher機制了。其實在我看來,就ZK的使用而言,Watche機制是最核心的特性
Zookeeper原始碼閱讀(七) Server端Watcher
前言 前面一篇主要介紹了Watcher介面相關的介面和實體類,但是主要是zk客戶端相關的程式碼,如前一篇開頭所說,client需要把watcher註冊到server端,這一篇分析下server端的watcher。 主要分析Watchmanager類。 Watchmanager 這是WatchMan
Zookeeper原始碼閱讀(五) ACL基礎
前言 之前看程式碼的時候也同步看了看一些關於zk原始碼的部落格,有一兩篇講到了ZK裡ACL的基礎的結構,我自己這邊也看了看相關的程式碼,在這裡分享一下! ACL和ID ACL和ID都是有Jute生成的實體類,分別代表了ZK裡ACL和不同ACL模式下的具體實體。 ACL: public class A
Redis原始碼閱讀(六)叢集-故障遷移(下)
Redis原始碼閱讀(六)叢集-故障遷移(下) 最近私人的事情比較多,沒有抽出時間來整理部落格。書接上文,上一篇裡總結了Redis故障遷移的幾個關鍵點,以及Redis中故障檢測的實現。本篇主要介紹叢集檢測到某主節點下線後,是如何選舉新的主節點的。注意到Redis叢集是無中心的,那麼使用分散式一
AFNetWorking(3.0)原始碼分析(四)——AFHTTPSessionManager(2)
在上一篇部落格中,我們分析了AFHTTPSessionManager,以及它是如何實現GET/HEAD/PATCH/DELETE相關介面的。 我們還剩下POST相關介面沒有分析,在這篇部落格裡面,我們就來分析一下POST相關介面是如何實現的。 multipart/form-data請
SpringBoot2.0原始碼分析(四):spring-data-jpa分析
SpringBoot具體整合rabbitMQ可參考:SpringBoot2.0應用(四):SpringBoot2.0之spring-data-jpa JpaRepositories自動注入 當專案中存在org.springframework.data.jpa.repository.JpaRepositor