DL4J Source Code Reading (7): LSTM Gradient Computation

    The backpropGradientHelper method in the LSTMHelpers class is where the LSTM layer's gradient computation takes place.
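    Before stepping through the method, it helps to keep the assumed layout of recurrentWeights in mind: with peephole connections it has shape [n^L, 4*n^L + 3], i.e. the four gate weight blocks {I, F, O, G} followed by three single peephole columns {FF, OO, GG}. The snippet below is a minimal, self-contained sketch of that slicing pattern; the class name and sizes are made up for illustration, and only the indexing mirrors the real code.

        import org.nd4j.linalg.api.ndarray.INDArray;
        import org.nd4j.linalg.factory.Nd4j;
        import org.nd4j.linalg.indexing.NDArrayIndex;

        public class RecurrentWeightLayoutSketch {
            public static void main(String[] args) {
                int n = 3; // hypothetical hidden layer size n^L
                // With peepholes, recurrentWeights is [n, 4*n + 3]: {I,F,O,G} blocks + {FF,OO,GG} columns
                INDArray recurrentWeights = Nd4j.rand(n, 4 * n + 3);

                // The four gate weight blocks, as wIFOG in backpropGradientHelper
                INDArray wIFOG = recurrentWeights.get(NDArrayIndex.all(), NDArrayIndex.interval(0, 4 * n));

                // The single peephole columns, transposed to row vectors like wFFTranspose etc.
                INDArray wFF = recurrentWeights.get(NDArrayIndex.all(), NDArrayIndex.point(4 * n)).transpose();
                INDArray wOO = recurrentWeights.get(NDArrayIndex.all(), NDArrayIndex.point(4 * n + 1)).transpose();
                INDArray wGG = recurrentWeights.get(NDArrayIndex.all(), NDArrayIndex.point(4 * n + 2)).transpose();

                System.out.println("wIFOG shape: " + java.util.Arrays.toString(wIFOG.shape()));
                System.out.println("wFF shape:   " + java.util.Arrays.toString(wFF.shape()));
            }
        }

    With that layout in mind, here is the annotated body of backpropGradientHelper.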

        // Number of neurons in this layer

        int hiddenLayerSize = recurrentWeights.size(0); //i.e., n^L

        // Number of neurons in the previous layer

        int prevLayerSize = inputWeights.size(0); //n^(L-1)

        // Mini-batch size

        int miniBatchSize = epsilon.size(0);

        boolean is2dInput = epsilon.rank() < 3; //Edge case: T=1 may have shape [miniBatchSize,n^(L+1)], equiv. to [miniBatchSize,n^(L+1),1]

        int timeSeriesLength = (is2dInput ? 1 : epsilon.size(2));

        INDArray wFFTranspose = null;

        INDArray wOOTranspose = null;

        INDArray wGGTranspose = null;

        // Peephole connections

        if (hasPeepholeConnections) {

            // The following three are the weights of the three peephole connections

            wFFTranspose = recurrentWeights.get(NDArrayIndex.all(), point(4 * hiddenLayerSize)).transpose();

            wOOTranspose = recurrentWeights.get(NDArrayIndex.all(), point(4 * hiddenLayerSize + 1)).transpose();

            wGGTranspose = recurrentWeights.get(NDArrayIndex.all(), point(4 * hiddenLayerSize + 2)).transpose();

        }

        // Slice the four gates' weights out of the recurrent weight matrix. For efficiency this statement could be combined with the peephole check above: without peepholes we could simply set wIFOG = recurrentWeights; the slicing is only needed when peepholes are present.

        INDArray wIFOG = recurrentWeights.get(NDArrayIndex.all(), NDArrayIndex.interval(0, 4 * hiddenLayerSize));

        //F order here so that content for time steps is together

        // epsilonNext has the same shape as this layer's input array.

        INDArray epsilonNext = Nd4j.create(new int[] {miniBatchSize, prevLayerSize, timeSeriesLength}, 'f'); //i.e., what would be W^L*(delta^L)^T. Shape: [m,n^(L-1),T]

        INDArray nablaCellStateNext = null;

        // A zero-filled matrix is initialised here; the four matrices below are views into it, so when those four change, deltaifogNext changes with them.

        INDArray deltaifogNext = Nd4j.create(new int[] {miniBatchSize, 4 * hiddenLayerSize}, 'f');

        INDArray deltaiNext = deltaifogNext.get(NDArrayIndex.all(), NDArrayIndex.interval(0, hiddenLayerSize));

        INDArray deltafNext = deltaifogNext.get(NDArrayIndex.all(),

                        NDArrayIndex.interval(hiddenLayerSize, 2 * hiddenLayerSize));

        INDArray deltaoNext = deltaifogNext.get(NDArrayIndex.all(),

                        NDArrayIndex.interval(2 * hiddenLayerSize, 3 * hiddenLayerSize));

        INDArray deltagNext = deltaifogNext.get(NDArrayIndex.all(),

                        NDArrayIndex.interval(3 * hiddenLayerSize, 4 * hiddenLayerSize));

        Level1 l1BLAS = Nd4j.getBlasWrapper().level1();

        int endIdx = 0;

        if (truncatedBPTT) {

            endIdx = Math.max(0, timeSeriesLength - tbpttBackwardLength);

        }
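        // With truncated BPTT the backward pass only walks back tbpttBackwardLength steps,
        // so the loop below stops at endIdx instead of going all the way back to t = 0.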

        // Get the input, recurrent and bias weight gradient views from gradientViews and zero them all

        INDArray iwGradientsOut = gradientViews.get(inputWeightKey);

        INDArray rwGradientsOut = gradientViews.get(recurrentWeightKey); //Order: {I,F,O,G,FF,OO,GG}

        INDArray bGradientsOut = gradientViews.get(biasWeightKey);

        iwGradientsOut.assign(0);

        rwGradientsOut.assign(0);

        bGradientsOut.assign(0);

        // All zeros at this point; as noted above, this could be combined with the peephole check below.

        INDArray rwGradientsIFOG =

                        rwGradientsOut.get(NDArrayIndex.all(), NDArrayIndex.interval(0, 4 * hiddenLayerSize));

        INDArray rwGradientsFF = null;

        INDArray rwGradientsOO = null;

        INDArray rwGradientsGG = null;

        if (hasPeepholeConnections) {

            rwGradientsFF = rwGradientsOut.get(NDArrayIndex.all(), NDArrayIndex.point(4 * hiddenLayerSize));

            rwGradientsOO = rwGradientsOut.get(NDArrayIndex.all(), NDArrayIndex.point(4 * hiddenLayerSize + 1));

            rwGradientsGG = rwGradientsOut.get(NDArrayIndex.all(), NDArrayIndex.point(4 * hiddenLayerSize + 2));

        }

        if (helper != null) {

            Pair<Gradient, INDArray> ret = helper.backpropGradient(conf, gateActivationFn, input, recurrentWeights,

                            inputWeights, epsilon, truncatedBPTT, tbpttBackwardLength, fwdPass, forwards,

                            inputWeightKey, recurrentWeightKey, biasWeightKey, gradientViews, maskArray,

                            hasPeepholeConnections);

            if (ret != null) {

                return ret;

            }

        }

        boolean sigmoidGates = gateActivationFn instanceof ActivationSigmoid;

        IActivation afn = ((org.deeplearning4j.nn.conf.layers.BaseLayer) conf.getLayer()).getActivationFn();

        // Check whether a workspace is defined here. If not, we work without a workspace and skip the internal LSTM one; otherwise we use it.

        MemoryWorkspace workspace = Nd4j.getMemoryManager().getCurrentWorkspace() != null && !Nd4j.getMemoryManager()

                        .getCurrentWorkspace().getId().equals(ComputationGraph.workspaceExternal)

                                        ? Nd4j.getWorkspaceManager().getWorkspaceForCurrentThread(

                                                        ComputationGraph.workspaceConfigurationLSTM,

                                                        ComputationGraph.workspaceLSTM)

                                        : null;

        INDArray timeStepMaskColumn = null;

        // Iterate over the time series, in reverse order.

        for (int iTimeIndex = timeSeriesLength - 1; iTimeIndex >= endIdx; iTimeIndex--) {

            // we're emulating try block here

            if (workspace != null)

                workspace.notifyScopeEntered();

            int time = iTimeIndex;

            int inext = 1;

            if (!forwards) {

                time = timeSeriesLength - iTimeIndex - 1;

                inext = -1;

            }
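            // For a layer processed in the backwards direction (forwards == false, e.g. the reverse
            // half of a bidirectional LSTM), the stored forward-pass arrays are indexed in the
            // opposite direction, so the array index (time) and the step towards the "previous"
            // state (inext) are flipped.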

            //First: calculate the components of nablaCellState that rely on the next time step deltas, so we can overwrite the deltas

            INDArray nablaCellState;

            if (iTimeIndex != timeSeriesLength - 1 && hasPeepholeConnections) {

                nablaCellState = deltafNext.dup('f').muliRowVector(wFFTranspose);

                l1BLAS.axpy(nablaCellState.length(), 1.0, deltagNext.dup('f').muliRowVector(wGGTranspose),

                                nablaCellState);
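                // l1BLAS.axpy(n, a, x, y) is the BLAS update y += a * x over n elements; here it adds the
                // g-gate peephole contribution onto nablaCellState without allocating another array.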

            } else {

                nablaCellState = Nd4j.create(new int[] {miniBatchSize, hiddenLayerSize}, 'f');

            }

            // Memory cell state at the previous time step

            INDArray prevMemCellState = (iTimeIndex == 0 ? fwdPass.prevMemCell : fwdPass.memCellState[time - inext]);

            // Hidden unit output at the previous time step

            INDArray prevHiddenUnitActivation =

                            (iTimeIndex == 0 ? fwdPass.prevAct : fwdPass.fwdPassOutputAsArrays[time - inext]);

            // Memory cell state at the current time step

            INDArray currMemCellState = fwdPass.memCellState[time];

            //LSTM unit output errors (dL/d(a_out)); not to be confused with \delta=dL/d(z_out)

            // Error (epsilon) slice for the current time step

            INDArray epsilonSlice = (is2dInput ? epsilon : epsilon.tensorAlongDimension(time, 1, 0)); //(w^{L+1}*(delta^{(L+1)t})^T)^T or equiv.

            INDArray nablaOut = Shape.toOffsetZeroCopy(epsilonSlice, 'f'); //Shape: [m,n^L]

            if (iTimeIndex != timeSeriesLength - 1) {

                //if t == timeSeriesLength-1 then deltaiNext etc are zeros

                Nd4j.gemm(deltaifogNext, wIFOG, nablaOut, false, true, 1.0, 1.0);
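                //Nd4j.gemm(a, b, c, transA, transB, alpha, beta) computes c = alpha*op(a)*op(b) + beta*c, so this
                //adds deltaifogNext * wIFOG^T (the recurrent contribution from the next time step's deltas)
                //onto nablaOut in place.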

            }

            // Current cell state passed through the activation function (usually tanh)

            INDArray sigmahOfS = fwdPass.memCellActivations[time];

            // Output gate activation (before the multiplication)

            INDArray ao = fwdPass.oa[time];

            //Normally would use zo.dup() in above line, but won't be using zo again (for this time step). Ditto for zf, zg, zi

            INDArray deltao = deltaoNext;

            // Hadamard product of this time step's error matrix and the activated (usually tanh) current cell state; this is the output gate's share of the error (every gate contributes to the layer's error).

            Nd4j.getExecutioner().exec(new MulOp(nablaOut, sigmahOfS, deltao));

            if (sigmoidGates) {

                // TimesOneMinus computes x*(1-x), which is exactly the sigmoid derivative.

                INDArray sigmaoPrimeOfZo = Nd4j.getExecutioner().execAndReturn(new TimesOneMinus(ao.dup('f'))); //Equivalent to sigmoid deriv on zo

                // Multiply the error by the derivative to obtain this component's delta

                deltao.muli(sigmaoPrimeOfZo);
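                // Since ao = sigmoid(zo), d(ao)/d(zo) = ao*(1-ao), so the forward activation can be reused
                // instead of recomputing the sigmoid on zo.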

            } else {

                deltao.assign(gateActivationFn.backprop(fwdPass.oz[time], deltao).getFirst()); //Deltao needs to be modified in-place

                //TODO: optimize (no assign)

            }

            //Memory cell error:

            // Backpropagate the error of the current memory cell state. ao.muli(nablaOut) extracts the current cell state's share of the output error; it mirrors the new MulOp(nablaOut, sigmahOfS, deltao) call above.

            INDArray temp = afn.backprop(currMemCellState.dup('f'), ao.muli(nablaOut)).getFirst(); //TODO activation functions with params

            l1BLAS.axpy(nablaCellState.length(), 1.0, temp, nablaCellState);

            if (hasPeepholeConnections) {

                // Output gate peephole error component

                INDArray deltaMulRowWOO = deltao.dup('f').muliRowVector(wOOTranspose);

                // Add the peephole error to the cell-state error

                l1BLAS.axpy(nablaCellState.length(), 1.0, deltaMulRowWOO, nablaCellState); //nablaCellState.addi(deltao.mulRowVector(wOOTranspose));

            }

            if (iTimeIndex != timeSeriesLength - 1) {

                INDArray nextForgetGateAs = fwdPass.fa[time + inext];

                int length = nablaCellState.length();

                l1BLAS.axpy(length, 1.0, nextForgetGateAs.muli(nablaCellStateNext), nablaCellState); //nablaCellState.addi(nextForgetGateAs.mul(nablaCellStateNext))

            }

            //Store for use in next iteration, and IF we're in workspace, we need to push it out of current workspace

            // Save the cell-state error; it will be used in the next loop iteration

            nablaCellStateNext = workspace == null ? nablaCellState : nablaCellState.leverage();
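            // leverage() moves the array out of the current (LSTM-internal) workspace so that it is still
            // valid after workspace.close() at the end of this iteration.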

            //Forget gate delta:

            // Forget gate activation

            INDArray af = fwdPass.fa[time];

            INDArray deltaf = null;

            if (iTimeIndex > 0 || prevMemCellState != null) { //For time == 0 && no prevMemCellState, equivalent to muli by 0

                //Note that prevMemCellState may be non-null at t=0 for TBPTT

                deltaf = deltafNext;

                if (sigmoidGates) {

                    // Derivative, similar to the output gate above, but with two arguments: the derivative is taken of af and the result is written into deltaf.

                    Nd4j.getExecutioner().exec(new TimesOneMinus(af, deltaf));

                    // Multiply the derivative by the cell-state error

                    deltaf.muli(nablaCellState);

                    // Multiply by the previous memory cell state

                    deltaf.muli(prevMemCellState);

                } else {

                    INDArray temp2 = nablaCellState.mul(prevMemCellState);

                    deltaf.assign(gateActivationFn.backprop(fwdPass.fz[time].dup('f'), temp2).getFirst()); //deltaf needs to be modified in-place

                    //TODO activation functions with params

                }

            }

            //Shape: [m,n^L]

            //Input modulation gate delta:

            // Input modulation gate activation

            INDArray ag = fwdPass.ga[time];

            // Input activation

            INDArray ai = fwdPass.ia[time];

            INDArray deltag = deltagNext;

            if (sigmoidGates) {

                // Derivative, same as for the forget gate.

                Nd4j.getExecutioner().exec(new TimesOneMinus(ag, deltag)); //Equivalent to sigmoid deriv on zg

                // Multiply the derivative by the input activation

                deltag.muli(ai);

                // Then multiply by the cell-state error

                deltag.muli(nablaCellState);

            } else {

                INDArray temp2 = Nd4j.getExecutioner().execAndReturn(

                                new MulOp(ai, nablaCellState, Nd4j.createUninitialized(ai.shape(), 'f')));

                deltag.assign(gateActivationFn.backprop(fwdPass.gz[time], temp2).getFirst());

                //TODO activation functions with params; optimize (no assign)

            }

            //Shape: [m,n^L]

            //Network input delta:

            // Pre-activation input: input times weights, before the activation function is applied

            INDArray zi = fwdPass.iz[time];

            INDArray deltai = deltaiNext;

            // Input modulation gate activation (ag) times the cell-state error

            temp = Nd4j.getExecutioner().execAndReturn(

                            new MulOp(ag, nablaCellState, Nd4j.createUninitialized(deltai.shape(), 'f')));

            // Network input (deltai) error component

            deltai.assign(afn.backprop(zi, temp).getFirst());

            //TODO activation functions with params; also: optimize this (no assign)

            //Shape: [m,n^L]

            //Handle masking

            if (maskArray != null) {

                //Mask array is present: bidirectional RNN -> need to zero out these errors to avoid using errors from a masked time step

                // to calculate the parameter gradients.  Mask array has shape [minibatch, timeSeriesLength] -> get column(this time step)

                timeStepMaskColumn = maskArray.getColumn(time);

                deltaifogNext.muliColumnVector(timeStepMaskColumn);
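                // The mask column has shape [miniBatchSize, 1]; multiplying column-wise zeroes the deltas of
                // all four gates for every example that is masked at this time step.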

                //Later, the deltaifogNext is used to calculate: input weight gradients, recurrent weight gradients, bias gradients

            }

            // Slice of the previous layer's activations, i.e. the input slice to this layer

            INDArray prevLayerActivationSlice =

                            Shape.toMmulCompatible(is2dInput ? input : input.tensorAlongDimension(time, 1, 0));

            if (iTimeIndex > 0 || prevHiddenUnitActivation != null) { //For time == 0 && no prevMemCellState, equivalent to muli by 0

                //Note that prevHiddenUnitActivations may be non-null at t=0 for TBPTT

                //Again, deltaifog_current == deltaifogNext at this point... same array

                // Transpose of the input times the deltas gives the input weight gradients
                Nd4j.gemm(prevLayerActivationSlice, deltaifogNext, iwGradientsOut, true, false, 1.0, 1.0);

            } else {

                INDArray iwGradients_i =

                                iwGradientsOut.get(NDArrayIndex.all(), NDArrayIndex.interval(0, hiddenLayerSize));

                Nd4j.gemm(prevLayerActivationSlice, deltai, iwGradients_i, true, false, 1.0, 1.0);

                INDArray iwGradients_og = iwGradientsOut.get(NDArrayIndex.all(),

                                NDArrayIndex.interval(2 * hiddenLayerSize, 4 * hiddenLayerSize));

                INDArray deltaog = deltaifogNext.get(NDArrayIndex.all(),

                                NDArrayIndex.interval(2 * hiddenLayerSize, 4 * hiddenLayerSize));

                Nd4j.gemm(prevLayerActivationSlice, deltaog, iwGradients_og, true, false, 1.0, 1.0);

            }

            if (iTimeIndex > 0 || prevHiddenUnitActivation != null) {

                //If t==0 and prevHiddenUnitActivation==null, equiv. to zeros(n^L,n^L), so dL/dW for recurrent weights

                // will end up as 0 anyway

                //At this point: deltaifog and deltaifogNext are the same thing...

                //So what we are actually doing here is sum of (prevAct^transpose * deltaifog_current)

                // Transpose of the previous time step's hidden unit activations times the deltas gives the recurrent weight gradients

                Nd4j.gemm(prevHiddenUnitActivation, deltaifogNext, rwGradientsIFOG, true, false, 1.0, 1.0);
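                // beta == 1.0 makes gemm accumulate into rwGradientsIFOG, so the recurrent weight gradients
                // are summed over all time steps of the backward pass.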

                //Shape: [1,n^L]. sum(0) is sum over examples in mini-batch.

                //Can use axpy here because result of sum and rwGradients[4 to 6] have order Nd4j.order(), via Nd4j.create()

                if (hasPeepholeConnections) {

                    // The forget gate and input modulation gate peepholes take the previous memory cell state as their input

                    INDArray dLdwFF = deltaf.dup('f').muli(prevMemCellState).sum(0); //mul not mmul because these weights are from unit j->j only (whereas other recurrent weights are i->j for all i,j)

                    l1BLAS.axpy(hiddenLayerSize, 1.0, dLdwFF, rwGradientsFF); //rwGradients[4].addi(dLdwFF);    //dL/dw_{FF}

                    INDArray dLdwGG = deltag.dup('f').muli(prevMemCellState).sum(0);

                    l1BLAS.axpy(hiddenLayerSize, 1.0, dLdwGG, rwGradientsGG); //rwGradients[6].addi(dLdwGG);

                }

            }

            if (hasPeepholeConnections) {

                // The output gate peephole takes the current memory cell state as its input

                INDArray dLdwOO = deltao.dup('f').muli(currMemCellState).sum(0); //Expected shape: [n^L,1]. sum(0) is sum over examples in mini-batch.

                l1BLAS.axpy(hiddenLayerSize, 1.0, dLdwOO, rwGradientsOO); //rwGradients[5].addi(dLdwOO);    //dL/dw_{OOxy}

            }

            if (iTimeIndex > 0 || prevHiddenUnitActivation != null) { //For time == 0 && no prevMemCellState, equivalent to muli by 0

                //Note that prevHiddenUnitActivation may be non-null at t=0 for TBPTT

                // Compute the bias gradients

                l1BLAS.axpy(4 * hiddenLayerSize, 1.0, deltaifogNext.sum(0), bGradientsOut);
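                // deltaifogNext.sum(0) sums the gate deltas over the mini-batch; axpy then adds this time
                // step's contribution onto the bias gradient view.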

            } else {

                l1BLAS.axpy(hiddenLayerSize, 1.0, deltai.sum(0), bGradientsOut); //Sneaky way to do bGradients_i += deltai.sum(0)

                INDArray ogBiasToAdd = deltaifogNext.get(NDArrayIndex.all(),

                                NDArrayIndex.interval(2 * hiddenLayerSize, 4 * hiddenLayerSize)).sum(0);

                INDArray ogBiasGrad = bGradientsOut.get(NDArrayIndex.point(0),

                                NDArrayIndex.interval(2 * hiddenLayerSize, 4 * hiddenLayerSize));

                l1BLAS.axpy(2 * hiddenLayerSize, 1.0, ogBiasToAdd, ogBiasGrad);

            }

            //Calculate epsilonNext - i.e., equiv. to what would be (w^L*(d^(Lt))^T)^T in a normal network

            //But here, need to add 4 weights * deltas for the IFOG gates

            INDArray epsilonNextSlice = epsilonNext.tensorAlongDimension(time, 1, 0); //This slice: f order and contiguous, due to epsilonNext being defined as f order.

            if (iTimeIndex > 0 || prevHiddenUnitActivation != null) {

                //Note that prevHiddenUnitActivation may be non-null at t=0 for TBPTT

                // This layer's deltas times the transposed input weights give the error component passed to the layer below

                Nd4j.gemm(deltaifogNext, inputWeights, epsilonNextSlice, false, true, 1.0, 1.0);
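                // deltaifogNext * inputWeights^T has shape [miniBatchSize, n^(L-1)]: the error this layer
                // hands down to the layer below for this time step.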

            } else {

                //No contribution from forget gate at t=0

                INDArray wi = inputWeights.get(NDArrayIndex.all(), NDArrayIndex.interval(0, hiddenLayerSize));

                Nd4j.gemm(deltai, wi, epsilonNextSlice, false, true, 1.0, 1.0);

                INDArray deltaog = deltaifogNext.get(NDArrayIndex.all(),

                                NDArrayIndex.interval(2 * hiddenLayerSize, 4 * hiddenLayerSize));

                INDArray wog = inputWeights.get(NDArrayIndex.all(),

                                NDArrayIndex.interval(2 * hiddenLayerSize, 4 * hiddenLayerSize));

                Nd4j.gemm(deltaog, wog, epsilonNextSlice, false, true, 1.0, 1.0); //epsilonNextSlice.addi(deltao.mmul(woTranspose)).addi(deltag.mmul(wgTranspose));

            }

            if (maskArray != null) {

                //Mask array is present: bidirectional RNN -> need to zero out these errors to avoid sending anything

                // but 0s to the layer below at this time step (for the given example)

                epsilonNextSlice.muliColumnVector(timeStepMaskColumn);

            }

            if (workspace != null)

                workspace.close();

        }

        Gradient retGradient = new DefaultGradient();

        retGradient.gradientForVariable().put(inputWeightKey, iwGradientsOut);

        retGradient.gradientForVariable().put(recurrentWeightKey, rwGradientsOut);

        retGradient.gradientForVariable().put(biasWeightKey, bGradientsOut);

        return new Pair<>(retGradient, epsilonNext);
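    One detail worth calling out: deltaiNext, deltafNext, deltaoNext and deltagNext are views into deltaifogNext rather than copies, so writing into them inside the loop is enough to update the combined matrix consumed by the gemm calls. Below is a minimal, self-contained sketch of that view behaviour; the class name and sizes are made up, and only the get/interval pattern mirrors the real code.

        import org.nd4j.linalg.api.ndarray.INDArray;
        import org.nd4j.linalg.factory.Nd4j;
        import org.nd4j.linalg.indexing.NDArrayIndex;

        public class ViewSharingSketch {
            public static void main(String[] args) {
                int miniBatch = 2, hidden = 3; // hypothetical sizes
                // Combined delta matrix, f-ordered like deltaifogNext in backpropGradientHelper
                INDArray deltaifog = Nd4j.create(new int[] {miniBatch, 4 * hidden}, 'f');
                // A view of the first gate block; no data is copied
                INDArray deltai = deltaifog.get(NDArrayIndex.all(), NDArrayIndex.interval(0, hidden));

                deltai.assign(1.0);            // writing through the view...
                System.out.println(deltaifog); // ...is visible in the parent matrix
            }
        }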