1. 程式人生 > >Dl4j-fit(DataSetIterator iterator)原始碼閱讀(四)dropout

Dl4j-fit(DataSetIterator iterator)原始碼閱讀(四)dropout

preOut這一部分就是網路模型前向傳播的重點。

public INDArray preOutput(boolean training) {
    applyDropOutIfNecessary(training);
    INDArray b = getParam(DefaultParamInitializer.BIAS_KEY);
    INDArray W = getParam(DefaultParamInitializer.WEIGHT_KEY);

    //Input validation:
    if (input.rank() != 2 || input.columns
() != W.rows()) { if (input.rank() != 2) { throw new DL4JInvalidInputException("Input that is not a matrix; expected matrix (rank 2), got rank " + input.rank() + " array with shape " + Arrays.toString(input.shape())); } throw new DL4JInvalidInputException("Input size ("
+ input.columns() + " columns; shape = " + Arrays.toString(input.shape()) + ") is invalid: does not match layer input size (layer # inputs = " + W.size(0) + ")"); } if (conf.isUseDropConnect() && training && conf.getLayer().getDropOut
() > 0) { W = Dropout.applyDropConnect(this, DefaultParamInitializer.WEIGHT_KEY); } INDArray ret = input.mmul(W).addiRowVector(b); if (maskArray != null) { applyMask(ret); } return ret; }

首先使用applyDropOutIfNecessary(training);函式判斷當前是否使用dropout。

protected void applyDropOutIfNecessary(boolean training) {
    if (conf.getLayer().getDropOut() > 0 && !conf.isUseDropConnect() && training && !dropoutApplied) {
        input = input.dup();
        Dropout.applyDropout(input, conf.getLayer().getDropOut());
        dropoutApplied = true;
    }
}

使用dropout的條件如下:

  1. 當前層設定 dropout > 0
  2. 當前配置沒有使用dropConnect(), 這一配置在卷積神經網路常見。
  3. 當前是訓練過程,也就是training的值為true。 在預測的時候dropout不會被應用
  4. dropout在之前沒有被呼叫。

如果以上條件都滿足,則先對當前的輸入使用dup()函式進行復制(注:dup取自單詞duplicate,複製的意思),然後傳入下一個函式。

/**
 5. Apply dropout to the given input
 6. and return the drop out mask used
 7. @param input the input to do drop out on
 8. @param dropout the drop out probability
 */
public static void applyDropout(INDArray input, double dropout) {
    if (Nd4j.getRandom().getStatePointer() != null) {
        Nd4j.getExecutioner().exec(new DropOutInverted(input, dropout));
    } else {
        Nd4j.getExecutioner().exec(new LegacyDropOutInverted(input, dropout));
    }
}

dropout的實現方式很多,根據這個原始碼閱讀方式發現,dl4j的dropout實現方式是根據截斷當前層的輸入來實現drpout。

/**
 9. This method returns pointer to RNG state structure.
 10. Please note: DefaultRandom implementation returns NULL here, making it impossible to use with RandomOps
 11.  - @return
 */
@Override
public Pointer getStatePointer() {
    return statePointer;
}

這個getStatePointer()的目的從程式碼的註釋情況上來還不是很清楚。接下來檢視兩種實現方式

  1. DropOutInverted
/**
 * Inverted DropOut implementation as Op
 *
 * @author [email protected]
 */
public class DropOutInverted extends BaseRandomOp {

    private double p;

    public DropOutInverted() {

    }

    public DropOutInverted(@NonNull INDArray x, double p) {
        this(x, x, p, x.lengthLong());
    }

    public DropOutInverted(@NonNull INDArray x, @NonNull INDArray z, double p) {
        this(x, z, p, x.lengthLong());
    }

    public DropOutInverted(@NonNull INDArray x, @NonNull INDArray z, double p, long n) {
        this.p = p;
        init(x, null, z, n);
    }

    @Override
    public int opNum() {
        return 2;
    }

    @Override
    public String name() {
        return "dropout_inverted";
    }

    @Override
    public void init(INDArray x, INDArray y, INDArray z, long n) {
        super.init(x, y, z, n);
        this.extraArgs = new Object[] {p};
    }
}
  1. LegacyDropOutInverted
/**
 * Inverted DropOut implementation as Op
 *
 * PLEASE NOTE: This is legacy DropOutInverted implementation, please consider using op with the same name from randomOps
 * @author [email protected]
 */
public class LegacyDropOutInverted extends BaseTransformOp {

    private double p;

    public LegacyDropOutInverted() {

    }

    public LegacyDropOutInverted(INDArray x, double p) {
        super(x);
        this.p = p;
        init(x, null, x, x.length());
    }

    public LegacyDropOutInverted(INDArray x, INDArray z, double p) {
        super(x, z);
        this.p = p;
        init(x, null, z, x.length());
    }

    public LegacyDropOutInverted(INDArray x, INDArray z, double p, long n) {
        super(x, z, n);
        this.p = p;
        init(x, null, z, n);
    }

    @Override
    public int opNum() {
        return 44;
    }

    @Override
    public String name() {
        return "legacy_dropout_inverted";
    }

    @Override
    public IComplexNumber op(IComplexNumber origin, double other) {
        return null;
    }

    @Override
    public IComplexNumber op(IComplexNumber origin, float other) {
        return null;
    }

    @Override
    public IComplexNumber op(IComplexNumber origin, IComplexNumber other) {
        return null;
    }

    @Override
    public float op(float origin, float other) {
        return 0;
    }

    @Override
    public double op(double origin, double other) {
        return 0;
    }

    @Override
    public double op(double origin) {
        return 0;
    }

    @Override
    public float op(float origin) {
        return 0;
    }

    @Override
    public IComplexNumber op(IComplexNumber origin) {
        return null;

    }

    @Override
    public Op opForDimension(int index, int dimension) {
        INDArray xAlongDimension = x.vectorAlongDimension(index, dimension);

        if (y() != null)
            return new LegacyDropOutInverted(xAlongDimension, z.vectorAlongDimension(index, dimension), p,
                            xAlongDimension.length());
        else
            return new LegacyDropOutInverted(xAlongDimension, z.vectorAlongDimension(index, dimension), p,
                            xAlongDimension.length());
    }

    @Override
    public Op opForDimension(int index, int... dimension) {
        INDArray xAlongDimension = x.tensorAlongDimension(index, dimension);

        if (y() != null)
            return new LegacyDropOutInverted(xAlongDimension, z.tensorAlongDimension(index, dimension), p,
                            xAlongDimension.length());
        else
            return new LegacyDropOutInverted(xAlongDimension, z.tensorAlongDimension(index, dimension), p,
                            xAlongDimension.length());

    }

    @Override
    public void init(INDArray x, INDArray y, INDArray z, long n) {
        super.init(x, y, z, n);
        this.extraArgs = new Object[] {p, (double) n};
    }
}

這個dropout有些難以理解,這裡用單步的除錯資訊來檢視計算流程來嘗試理解:
當前程式執行的dropout的型別為DropOutInverted。此時呼叫的函式如下:

public DropOutInverted(@NonNull INDArray x, double p) {
    this(x, x, p, x.lengthLong());
}

當前輸入的x的值為:
[-10.0,-9.99,-9.98,-9.97,-9.96,-9.95,-9.94,-9.93,-9.92,-9.91,-9.9,-9.89,-9.88,-9.87,-9.86,-9.85,-9.84,-9.83,-9.82,-9.81]
它的shape為[20, 1],也就是一個 20 x 1的列向量。其中呼叫的x.lengthLong()的值也為20。當前的p值也有改變,p值變為當前層的dropout值,即當前p = 0.5。之後呼叫this執行到另外一個建構函式中:

public DropOutInverted(@NonNull INDArray x, @NonNull INDArray z, double p, long n) {
    this.p = p;
    init(x, null, z, n);
}

在呼叫到當前建構函式的時候,呼叫init函式,此時的 z和x是相同的值。

@Override
public void init(INDArray x, INDArray y, INDArray z, long n) {
    super.init(x, y, z, n);
    this.extraArgs = new Object[] {p};
}

執行到當前步,各項引數如下:

x = [-10.00, -9.99, -9.98, -9.97, -9.96, -9.95, -9.94, -9.93, -9.92, -9.91, -9.90, -9.89, -9.88, -9.87, -9.86, -9.85, -9.84, -9.83, -9.82, -9.81]
y = null
z = [-10.00, -9.99, -9.98, -9.97, -9.96, -9.95, -9.94, -9.93, -9.92, -9.91, -9.90, -9.89, -9.88, -9.87, -9.86, -9.85, -9.84, -9.83, -9.82, -9.81]
n = 20
p = 0.5

之後就會跳轉到父類的init()方法:

@Override
public void init(INDArray x, INDArray y, INDArray z, long n) {
    this.x = x;
    this.y = y;
    this.z = z;
    this.n = n;
}

父類方法只是對成員變數進行簡單賦值。
在以上變數初始化完成之後,繼續執行Nd4j.getExecutioner().exec(new DropOutInverted(input, dropout));方法。

/**
 * This method executes specified RandomOp using default RNG available via Nd4j.getRandom()
 *
 * @param op
 */
@Override
public INDArray exec(RandomOp op) {
    return exec(op, Nd4j.getRandom());
}

根據註釋,兩個dropout類是特殊的RandomOp。之後繼續呼叫下一個exec()方法。

/**
 * This method executes specific
 * RandomOp against specified RNG
 *
 * @param op
 * @param rng
 */
@Override
public INDArray exec(RandomOp op, Random rng) {
    if (rng.getStateBuffer() == null)
        throw new IllegalStateException(
                        "You should use one of NativeRandom classes for NativeOperations execution");

    long st = profilingHookIn(op);

    validateDataType(Nd4j.dataType(), op);

    if (op.x() != null && op.y() != null && op.z() != null) {
        // triple arg call
        if (Nd4j.dataType() == DataBuffer.Type.FLOAT) {
            loop.execRandomFloat(null, op.opNum(), rng.getStatePointer(), // rng state ptr
                            (FloatPointer) op.x().data().addressPointer(),
                            (IntPointer) op.x().shapeInfoDataBuffer().addressPointer(),
                            (FloatPointer) op.y().data().addressPointer(),
                            (IntPointer) op.y().shapeInfoDataBuffer().addressPointer(),
                            (FloatPointer) op.z().data().addressPointer(),
                            (IntPointer) op.z().shapeInfoDataBuffer().addressPointer(),
                            (FloatPointer) op.extraArgsDataBuff().addressPointer());
        } else if (Nd4j.dataType() == DataBuffer.Type.DOUBLE) {
            loop.execRandomDouble(null, op.opNum(), rng.getStatePointer(), // rng state ptr
                            (DoublePointer) op.x().data().addressPointer(),
                            (IntPointer) op.x().shapeInfoDataBuffer().addressPointer(),
                            (DoublePointer) op.y().data().addressPointer(),
                            (IntPointer) op.y().shapeInfoDataBuffer().addressPointer(),
                            (DoublePointer) op.z().data().addressPointer(),
                            (IntPointer) op.z().shapeInfoDataBuffer().addressPointer(),
                            (DoublePointer) op.extraArgsDataBuff().addressPointer());
        }
    } else if (op.x() != null && op.z() != null) {
        //double arg call
        if (Nd4j.dataType() == DataBuffer.Type.FLOAT) {
            loop.execRandomFloat(null, op.opNum(), rng.getStatePointer(), // rng state ptr
                            (FloatPointer) op.x().data().addressPointer(),
                            (IntPointer) op.x().shapeInfoDataBuffer().addressPointer(),
                            (FloatPointer) op.z().data().addressPointer(),
                            (IntPointer) op.z().shapeInfoDataBuffer().addressPointer(),
                            (FloatPointer) op.extraArgsDataBuff().addressPointer());
        } else if (Nd4j.dataType() == DataBuffer.Type.DOUBLE) {
            loop.execRandomDouble(null, op.opNum(), rng.getStatePointer(), // rng state ptr
                            (DoublePointer) op.x().data().addressPointer(),
                            (IntPointer) op.x().shapeInfoDataBuffer().addressPointer(),
                            (DoublePointer) op.z().data().addressPointer(),
                            (IntPointer) op.z().shapeInfoDataBuffer().addressPointer(),
                            (DoublePointer) op.extraArgsDataBuff().addressPointer());
        }

    } else {
        // single arg call

        if (Nd4j.dataType() == DataBuffer.Type.FLOAT) {
            loop.execRandomFloat(null, op.opNum(), rng.getStatePointer(), // rng state ptr
                            (FloatPointer) op.z().data().addressPointer(),
                            (IntPointer) op.z().shapeInfoDataBuffer().addressPointer(),
                            (FloatPointer) op.extraArgsDataBuff().addressPointer());
        } else if (Nd4j.dataType() == DataBuffer.Type.DOUBLE) {
            loop.execRandomDouble(null, op.opNum(), rng.getStatePointer(), // rng state ptr
                            (DoublePointer) op.z().data().addressPointer(),
                            (IntPointer) op.z().shapeInfoDataBuffer().addressPointer(),
                            (DoublePointer) op.extraArgsDataBuff().addressPointer());
        }
    }

    profilingHookOut(op, st);

    return op.z();
}

這個函式首先使用validateDataType(Nd4j.dataType(), op);用於檢驗當前資料型別的合法性。然後根據傳入的op的三個成員變數x, y, z來判斷進入哪一分支。在上面的debug資訊我們可以看到,我們的x和z是兩個非空變數,因此進入第二個分支,並且我們當前的Nd4j.dataType()DataBuffer.Type.FLOAT。為此在當前環境下會執行以下語句:

loop.execRandomFloat(null, op.opNum(), rng.getStatePointer(), // rng state ptr
                                (FloatPointer) op.x().data().addressPointer(),
                                (IntPointer) op.x().shapeInfoDataBuffer().addressPointer(),
                                (FloatPointer) op.z().data().addressPointer(),
                                (IntPointer) op.z().shapeInfoDataBuffer().addressPointer(),
                                (FloatPointer) op.extraArgsDataBuff().addressPointer());

然後這部分的具體實現應該是JNI呼叫的底層

public native void execRandomFloat(@Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, @Cast("Nd4jPointer") Pointer state, FloatPointer x, IntPointer xShapeBuffer, FloatPointer z, IntPointer zShapeBuffer, FloatPointer extraArguments);

經過如上方法的的執行之後,返回z值,這時候通過debug資訊看到的z值為:

[-20.00, -19.98, -19.96, 0.00, 0.00, -19.90, -19.88, 0.00, -19.84, -19.82, 0.00, 0.00, -19.76, -19.74, -19.72, -19.70, -19.68, -19.66, -19.64, 0.00]

因為在前面輸入的時候z其實和x是等同的。在執行以上方法之後,相當於對x做了一個變幻。使得x變為如上的數值。到這裡就使得dl4j的applyDropOutIfNecessary(training)方法部分完成,繼續回到preOutput()方法體內繼續執行。(這裡猜測實現的方式是部分位置隨機置0,然後再所有的資料除以dropout的值)

相關推薦

Dl4j-fit(DataSetIterator iterator)原始碼閱讀dropout

preOut這一部分就是網路模型前向傳播的重點。 public INDArray preOutput(boolean training) { applyDropOutIfNecessary(training); INDArray b = g

DL4J原始碼閱讀:梯度計算

      computeGradientAndScore方法呼叫backprop()做梯度計算和誤差反傳。       backprop()呼叫calcBackpropGradients()方法。calcBackpropGradients()方法再呼叫initGradie

Redis原始碼閱讀叢集-請求分配

    叢集搭建好之後,使用者傳送的命令請求可以被分配到不同的節點去處理。那Redis對命令請求分配的依據是什麼?如果節點數量有變動,命令又是如何重新分配的,重分配的過程是否會阻塞對外提供的服務?接下來會從這兩個問題入手,分析Redis3.0的原始碼實現。 1. 分配依據—

jQuery原始碼閱讀--正則表示式

在jQuery原始碼中,運用了大量的正則表示式,一開始在看的時候真的是一頭霧水,儘管已經看過了JS高程裡面的正則表示式。 今天,看了一篇深入理解正則表示式的文章,對正則表示式有了更深的認識,下面做一個回顧和總結。 正則表示式基礎 JS正則表示式用來匹配

Horizon 原始碼閱讀—— 呼叫Novaclient流程

一、寫在前面 這篇文章主要介紹一下OpenStack(Kilo)關於horizon 呼叫NovaClient的一個分析。         如果轉載,請保留作者資訊。 原文地址:http://blog.csdn.net/u011521019/a

XSStrike原始碼閱讀2——種模式

1.bruteforcer模式 功能介紹 根據使用者提供的payloads檔案去暴力測試每一個引數,以此來確定是否存在xss漏洞(說起來也就是一個兩層迴圈)。 具體實現 XSStrike3.0 bruteforcer.py原始碼如下: import copy from

DL4J原始碼閱讀:LSTM梯度計算

    LSTMHelpers類中的backpropGradientHelper方法是梯度計算過程。         // 本層神經元個數         int hiddenLayerSize = recurrentWeights.size(0); //i.e., n^

LevelDB的源碼閱讀 Compaction操作

left 維護 efault smallest item app apply() body roc leveldb的數據存儲采用LSM的思想,將隨機寫入變為順序寫入,記錄寫入操作日誌,一旦日誌被以追加寫的形式寫入硬盤,就返回寫入成功,由後臺線程將寫入日誌作用於原有的磁盤文件

Flume NG原始碼分析使用ExecSource從本地日誌檔案中收集日誌

常見的日誌收集方式有兩種,一種是經由本地日誌檔案做媒介,非同步地傳送到遠端日誌倉庫,一種是基於RPC方式的同步日誌收集,直接傳送到遠端日誌倉庫。這篇講講Flume NG如何從本地日誌檔案中收集日誌。 ExecSource是用來執行本地shell命令,並把本地日誌檔案中的資料封裝成Event

OpenCV學習筆記30KAZE 演算法原理與原始碼分析KAZE特徵的效能分析與比較

      KAZE系列筆記: 1.  OpenCV學習筆記(27)KAZE 演算法原理與原始碼分析(一)非線性擴散濾波 2.  OpenCV學習筆記(28)KAZE 演算法原理與原始碼分析(二)非線性尺度空間構

GCC原始碼分析——優化

原文連結:http://blog.csdn.net/sonicling/article/details/7916931 一、前言 本篇只介紹一下框架,就不具體介紹每個步驟了。 二、Pass框架 上一篇已經講了gcc的中間語言的表現形式。gcc 對中間語言

Spring原始碼解析——元件註冊4

  /** * 給容器中註冊元件; * 1)、包掃描+元件標註註解(@Controller/@Service/@Repository/@Component)[自己寫的類] * 2)、@Bean[匯入的第三方包裡面的元件] * 3)、@Import[快速給容器中匯入一個

【筆記】ThreadPoolExecutor原始碼閱讀

執行緒數量的維護 執行緒池的大小有兩個重要的引數,一個是corePoolSize(核心執行緒池大小),另一個是maximumPoolSize(最大執行緒大小)。執行緒池主要根據這兩個引數對執行緒池中執行緒的數量進行維護。 需要注意的是,執行緒池建立之初是沒有任何可用執行緒的。只有在有任務到達後,才開始建立

YOLOv2原始碼分析

文章全部YOLOv2原始碼分析 0x01 backward_convolutional_layer void backward_convolutional_layer(convolutional_layer l, network

## Zookeeper原始碼閱讀 Watcher

前言 好久沒有更新部落格了,最近這段時間過得很壓抑,終於開始踏上為換工作準備的正軌了,工作又真的很忙而且很瑣碎,讓自己有點煩惱,希望能早點結束這種狀態。 繼上次分析了ZK的ACL相關程式碼後,ZK裡非常重要的另一個特性就是Watcher機制了。其實在我看來,就ZK的使用而言,Watche機制是最核心的特性

Zookeeper原始碼閱讀 Server端Watcher

前言 前面一篇主要介紹了Watcher介面相關的介面和實體類,但是主要是zk客戶端相關的程式碼,如前一篇開頭所說,client需要把watcher註冊到server端,這一篇分析下server端的watcher。 主要分析Watchmanager類。 Watchmanager 這是WatchMan

Zookeeper原始碼閱讀 ACL基礎

前言 之前看程式碼的時候也同步看了看一些關於zk原始碼的部落格,有一兩篇講到了ZK裡ACL的基礎的結構,我自己這邊也看了看相關的程式碼,在這裡分享一下! ACL和ID ACL和ID都是有Jute生成的實體類,分別代表了ZK裡ACL和不同ACL模式下的具體實體。 ACL: public class A

Redis原始碼閱讀叢集-故障遷移(下)

Redis原始碼閱讀(六)叢集-故障遷移(下)   最近私人的事情比較多,沒有抽出時間來整理部落格。書接上文,上一篇裡總結了Redis故障遷移的幾個關鍵點,以及Redis中故障檢測的實現。本篇主要介紹叢集檢測到某主節點下線後,是如何選舉新的主節點的。注意到Redis叢集是無中心的,那麼使用分散式一

AFNetWorking(3.0)原始碼分析——AFHTTPSessionManager(2)

在上一篇部落格中,我們分析了AFHTTPSessionManager,以及它是如何實現GET/HEAD/PATCH/DELETE相關介面的。 我們還剩下POST相關介面沒有分析,在這篇部落格裡面,我們就來分析一下POST相關介面是如何實現的。 multipart/form-data請

SpringBoot2.0原始碼分析:spring-data-jpa分析

SpringBoot具體整合rabbitMQ可參考:SpringBoot2.0應用(四):SpringBoot2.0之spring-data-jpa JpaRepositories自動注入 當專案中存在org.springframework.data.jpa.repository.JpaRepositor