Reading the DL4J Source (7): LSTM Gradient Computation
The backpropGradientHelper method of the LSTMHelpers class implements the gradient computation.
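Before walking through the code, it helps to keep in mind the equations the method implements. The following is only a rough sketch in the notation of the in-code comments below (a_* for activations, z_* for pre-activations, c_t for the memory cell state, sigma_g for the gate activation, sigma_h for the cell activation afn, ⊙ for the element-wise product), not an exact transcription of the code:

nablaOut = dL/dh_t = epsilon_t + deltaifog_(t+1) * wIFOG^T
deltao = nablaOut ⊙ sigma_h(c_t) ⊙ sigma_g'(z_o)
dL/dc_t = nablaOut ⊙ a_o ⊙ sigma_h'(c_t) + w_OO ⊙ deltao + w_FF ⊙ deltaf_(t+1) + w_GG ⊙ deltag_(t+1) + a_f_(t+1) ⊙ dL/dc_(t+1)
deltaf = dL/dc_t ⊙ c_(t-1) ⊙ sigma_g'(z_f)
deltag = dL/dc_t ⊙ a_i ⊙ sigma_g'(z_g)
deltai = dL/dc_t ⊙ a_g ⊙ sigma_h'(z_i)

The peephole terms (w_FF, w_OO, w_GG) only appear when peephole connections are present, and the t+1 terms only for time steps before the last one. Per time step, the parameter gradients then accumulate x_t^T * deltaifog (input weights), h_(t-1)^T * deltaifog (recurrent weights), the mini-batch sum of deltaifog (biases) and, for the peepholes, the mini-batch sums of deltaf ⊙ c_(t-1), deltag ⊙ c_(t-1) and deltao ⊙ c_t.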
// Number of units in this layer
int hiddenLayerSize = recurrentWeights.size(0); //i.e., n^L
// Number of units in the previous layer
int prevLayerSize = inputWeights.size(0); //n^(L-1)
// Mini-batch size
int miniBatchSize = epsilon.size(0);
boolean is2dInput = epsilon.rank() < 3; //Edge case: T=1 may have shape [miniBatchSize,n^(L+1)], equiv. to [miniBatchSize,n^(L+1),1]
int timeSeriesLength = (is2dInput ? 1 : epsilon.size(2));
INDArray wFFTranspose = null;
INDArray wOOTranspose = null;
INDArray wGGTranspose = null;
if (hasPeepholeConnections) {
// The weights for the three peephole connections
wFFTranspose = recurrentWeights.get(NDArrayIndex.all(), point(4 * hiddenLayerSize)).transpose();
wOOTranspose = recurrentWeights.get(NDArrayIndex.all(), point(4 * hiddenLayerSize + 1)).transpose();
wGGTranspose = recurrentWeights.get(NDArrayIndex.all(), point(4 * hiddenLayerSize + 2)).transpose();
}
// Extract the gate weights from the recurrent weight matrix. For efficiency, this statement should arguably be combined with the peephole check above: without peephole connections, wIFOG could simply be set to recurrentWeights directly; the sub-view is only needed when peepholes are present.
INDArray wIFOG = recurrentWeights.get(NDArrayIndex.all(), NDArrayIndex.interval(0, 4 * hiddenLayerSize));
//F order here so that content for time steps are together
// epsilonNext has the same shape as this layer's input.
INDArray epsilonNext = Nd4j.create(new int[] {miniBatchSize, prevLayerSize, timeSeriesLength}, 'f'); //i.e., what would be W^L*(delta^L)^T. Shape: [m,n^(L-1),T]
INDArray nablaCellStateNext = null;
// A single zero matrix is allocated here; the four delta matrices below are views into it, so when any of them changes, deltaifogNext changes accordingly.
INDArray deltaifogNext = Nd4j.create(new int[] {miniBatchSize, 4 * hiddenLayerSize}, 'f');
INDArray deltaiNext = deltaifogNext.get(NDArrayIndex.all(), NDArrayIndex.interval(0, hiddenLayerSize));
INDArray deltafNext = deltaifogNext.get(NDArrayIndex.all(),
NDArrayIndex.interval(hiddenLayerSize, 2 * hiddenLayerSize));
INDArray deltaoNext = deltaifogNext.get(NDArrayIndex.all(),
NDArrayIndex.interval(2 * hiddenLayerSize, 3 * hiddenLayerSize));
INDArray deltagNext = deltaifogNext.get(NDArrayIndex.all(),
NDArrayIndex.interval(3 * hiddenLayerSize, 4 * hiddenLayerSize));
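// For intuition (a hypothetical, minimal illustration of Nd4j view semantics, not part of this method): if
// parent = Nd4j.zeros(2, 4) and view = parent.get(NDArrayIndex.all(), NDArrayIndex.interval(0, 2)), then
// view.assign(1.0) also sets the first two columns of parent to 1, because the view shares the parent's
// underlying buffer. The same applies to deltaifogNext and the four delta views above.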
Level1 l1BLAS = Nd4j.getBlasWrapper().level1();
int endIdx = 0;
if (truncatedBPTT) {
endIdx = Math.max(0, timeSeriesLength - tbpttBackwardLength);
}
// Get the input, recurrent and bias weight gradients from the gradient views, and zero them all.
INDArray iwGradientsOut = gradientViews.get(inputWeightKey);
INDArray rwGradientsOut = gradientViews.get(recurrentWeightKey); //Order: {I,F,O,G,FF,OO,GG}
INDArray bGradientsOut = gradientViews.get(biasWeightKey);
iwGradientsOut.assign(0);
rwGradientsOut.assign(0);
bGradientsOut.assign(0);
// These are all zeros as well. As noted above, this should arguably be combined with the peephole check below.
INDArray rwGradientsIFOG =
rwGradientsOut.get(NDArrayIndex.all(), NDArrayIndex.interval(0, 4 * hiddenLayerSize));
INDArray rwGradientsFF = null;
INDArray rwGradientsOO = null;
INDArray rwGradientsGG = null;
if (hasPeepholeConnections) {
rwGradientsFF = rwGradientsOut.get(NDArrayIndex.all(), NDArrayIndex.point(4 * hiddenLayerSize));
rwGradientsOO = rwGradientsOut.get(NDArrayIndex.all(), NDArrayIndex.point(4 * hiddenLayerSize + 1));
rwGradientsGG = rwGradientsOut.get(NDArrayIndex.all(), NDArrayIndex.point(4 * hiddenLayerSize + 2));
}
if (helper != null) {
Pair<Gradient, INDArray> ret = helper.backpropGradient(conf, gateActivationFn, input, recurrentWeights,
inputWeights, epsilon, truncatedBPTT, tbpttBackwardLength, fwdPass, forwards,
inputWeightKey, recurrentWeightKey, biasWeightKey, gradientViews, maskArray,
hasPeepholeConnections);
if (ret != null) {
return ret;
}
}
boolean sigmoidGates = gateActivationFn instanceof ActivationSigmoid;
IActivation afn = ((org.deeplearning4j.nn.conf.layers.BaseLayer) conf.getLayer()).getActivationFn();
// We check whether a workspace is defined here. If not, we work without a workspace and skip the internal LSTM one; otherwise we use it.
MemoryWorkspace workspace = Nd4j.getMemoryManager().getCurrentWorkspace() != null && !Nd4j.getMemoryManager()
.getCurrentWorkspace().getId().equals(ComputationGraph.workspaceExternal)
? Nd4j.getWorkspaceManager().getWorkspaceForCurrentThread(
ComputationGraph.workspaceConfigurationLSTM,
ComputationGraph.workspaceLSTM)
: null;
INDArray timeStepMaskColumn = null;
// Loop over the time series length, in reverse order.
for (int iTimeIndex = timeSeriesLength - 1; iTimeIndex >= endIdx; iTimeIndex--) {
// we're emulating try block here
if (workspace != null)
workspace.notifyScopeEntered();
int time = iTimeIndex;
int inext = 1;
if (!forwards) {
time = timeSeriesLength - iTimeIndex - 1;
inext = -1;
}
//First: calculate the components of nablaCellState that rely on the next time step deltas, so we can overwrite the deltas
INDArray nablaCellState;
if (iTimeIndex != timeSeriesLength - 1 && hasPeepholeConnections) {
nablaCellState = deltafNext.dup('f').muliRowVector(wFFTranspose);
l1BLAS.axpy(nablaCellState.length(), 1.0, deltagNext.dup('f').muliRowVector(wGGTranspose),
nablaCellState);
} else {
nablaCellState = Nd4j.create(new int[] {miniBatchSize, hiddenLayerSize}, 'f');
}
// Memory cell state from the previous time step
INDArray prevMemCellState = (iTimeIndex == 0 ? fwdPass.prevMemCell : fwdPass.memCellState[time - inext]);
// Hidden unit output (activations) from the previous time step
INDArray prevHiddenUnitActivation =
(iTimeIndex == 0 ? fwdPass.prevAct : fwdPass.fwdPassOutputAsArrays[time - inext]);
// Memory cell state at the current time step
INDArray currMemCellState = fwdPass.memCellState[time];
//LSTM unit output errors (dL/d(a_out)); not to be confused with \delta=dL/d(z_out)
// Error (epsilon) slice for this time step
INDArray epsilonSlice = (is2dInput ? epsilon : epsilon.tensorAlongDimension(time, 1, 0)); //(w^{L+1}*(delta^{(L+1)t})^T)^T or equiv.
INDArray nablaOut = Shape.toOffsetZeroCopy(epsilonSlice, 'f'); //Shape: [m,n^L]
if (iTimeIndex != timeSeriesLength - 1) {
//if t == timeSeriesLength-1 then deltaiNext etc are zeros
Nd4j.gemm(deltaifogNext, wIFOG, nablaOut, false, true, 1.0, 1.0);
}
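// At this point nablaOut holds dL/dh_t for this time step: the epsilon slice from the layer above plus,
// for t < T-1, deltaifogNext * wIFOG^T from the recurrent connections. Shape: [m,n^L]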
// Current cell state passed through the activation function (usually tanh)
INDArray sigmahOfS = fwdPass.memCellActivations[time];
// Output gate activation (before the element-wise multiplication with the activated cell state)
INDArray ao = fwdPass.oa[time];
//Normally would use zo.dup() in above line, but won't be using zo again (for this time step). Ditto for zf, zg, zi
INDArray deltao = deltaoNext;
// Hadamard product of this time step's error matrix and the activated (usually tanh) cell state. This gives the output gate's error component (each gate contributes to the layer's error).
Nd4j.getExecutioner().exec(new MulOp(nablaOut, sigmahOfS, deltao));
if (sigmoidGates) {
// TimesOneMinus computes x*(1-x), which is exactly the sigmoid derivative expressed in terms of the activation.
INDArray sigmaoPrimeOfZo = Nd4j.getExecutioner().execAndReturn(new TimesOneMinus(ao.dup('f'))); //Equivalent to sigmoid deriv on zo
// Multiply the error by the derivative to get this gate's delta
deltao.muli(sigmaoPrimeOfZo);
} else {
deltao.assign(gateActivationFn.backprop(fwdPass.oz[time], deltao).getFirst()); //Deltao needs to be modified in-place
//TODO: optimize (no assign)
}
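// In equation form: deltao = dL/dh_t ⊙ sigma_h(c_t) ⊙ sigma_g'(z_o), where sigma_g is the gate activation (usually sigmoid)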
//Memory cell error:
// Backpropagate the error into the current memory cell state. ao.muli(nablaOut) extracts the cell-state component of the output error; it is the counterpart of new MulOp(nablaOut, sigmahOfS, deltao) above.
INDArray temp = afn.backprop(currMemCellState.dup('f'), ao.muli(nablaOut)).getFirst(); //TODO activation functions with params
l1BLAS.axpy(nablaCellState.length(), 1.0, temp, nablaCellState);
if (hasPeepholeConnections) {
// Error component from the output gate peephole
INDArray deltaMulRowWOO = deltao.dup('f').muliRowVector(wOOTranspose);
// Add the peephole error to the cell state error
l1BLAS.axpy(nablaCellState.length(), 1.0, deltaMulRowWOO, nablaCellState); //nablaCellState.addi(deltao.mulRowVector(wOOTranspose));
}
if (iTimeIndex != timeSeriesLength - 1) {
INDArray nextForgetGateAs = fwdPass.fa[time + inext];
int length = nablaCellState.length();
l1BLAS.axpy(length, 1.0, nextForgetGateAs.muli(nablaCellStateNext), nablaCellState); //nablaCellState.addi(nextForgetGateAs.mul(nablaCellStateNext))
}
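// nablaCellState now holds dL/dc_t = dL/dh_t ⊙ a_o ⊙ sigma_h'(c_t) + w_OO ⊙ deltao
//   + w_FF ⊙ deltaf_(t+1) + w_GG ⊙ deltag_(t+1) + a_f_(t+1) ⊙ dL/dc_(t+1)
// (the w_** peephole terms only when peepholes are present, the t+1 terms only for t < T-1)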
//Store for use in next iteration, and IF we're in workspace, we need to push it out of current workspace
// Store the cell state error for use in the next iteration
nablaCellStateNext = workspace == null ? nablaCellState : nablaCellState.leverage();
//Forget gate delta:
// Forget gate activation
INDArray af = fwdPass.fa[time];
INDArray deltaf = null;
if (iTimeIndex > 0 || prevMemCellState != null) { //For time == 0 && no prevMemCellState, equivalent to muli by 0
//Note that prevMemCellState may be non-null at t=0 for TBPTT
deltaf = deltafNext;
if (sigmoidGates) {
// Take the derivative, as for the output gate above, but here the two-argument form writes the derivative of af into deltaf.
Nd4j.getExecutioner().exec(new TimesOneMinus(af, deltaf));
// Multiply the derivative by the cell state error
deltaf.muli(nablaCellState);
// Multiply by the previous memory cell state
deltaf.muli(prevMemCellState);
} else {
INDArray temp2 = nablaCellState.mul(prevMemCellState);
deltaf.assign(gateActivationFn.backprop(fwdPass.fz[time].dup('f'), temp2).getFirst()); //deltaf needs to be modified in-place
//TODO activation functions with params
}
}
//Shape: [m,n^L]
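// In equation form: deltaf = dL/dc_t ⊙ c_(t-1) ⊙ sigma_g'(z_f)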
//Input modulation gate delta:
// Input modulation gate activation
INDArray ag = fwdPass.ga[time];
// Cell input (candidate) activation
INDArray ai = fwdPass.ia[time];
INDArray deltag = deltagNext;
if (sigmoidGates) {
// Take the derivative, same as for the forget gate.
Nd4j.getExecutioner().exec(new TimesOneMinus(ag, deltag)); //Equivalent to sigmoid deriv on zg
// Multiply the derivative by the cell input activation
deltag.muli(ai);
// Then multiply by the cell state error
deltag.muli(nablaCellState);
} else {
INDArray temp2 = Nd4j.getExecutioner().execAndReturn(
new MulOp(ai, nablaCellState, Nd4j.createUninitialized(ai.shape(), 'f')));
deltag.assign(gateActivationFn.backprop(fwdPass.gz[time], temp2).getFirst());
//TODO activation functions with params; optimize (no assign)
}
//Shape: [m,n^L]
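// In equation form: deltag = dL/dc_t ⊙ a_i ⊙ sigma_g'(z_g)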
//Network input delta:
// Pre-activation of the cell input: input times weights, before the activation function is applied
INDArray zi = fwdPass.iz[time];
INDArray deltai = deltaiNext;
// Input modulation gate activation times the cell state error
temp = Nd4j.getExecutioner().execAndReturn(
new MulOp(ag, nablaCellState, Nd4j.createUninitialized(deltai.shape(), 'f')));
// Cell input error component
deltai.assign(afn.backprop(zi, temp).getFirst());
//TODO activation functions with params; also: optimize this (no assign)
//Shape: [m,n^L]
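// In equation form: deltai = dL/dc_t ⊙ a_g ⊙ afn'(z_i)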
//Handle masking
if (maskArray != null) {
//Mask array is present: bidirectional RNN -> need to zero out these errors to avoid using errors from a masked time step
// to calculate the parameter gradients. Mask array has shape [minibatch, timeSeriesLength] -> get column(this time step)
timeStepMaskColumn = maskArray.getColumn(time);
deltaifogNext.muliColumnVector(timeStepMaskColumn);
//Later, the deltaifogNext is used to calculate: input weight gradients, recurrent weight gradients, bias gradients
}
// Slice of the previous layer's activations, i.e. this layer's input slice
INDArray prevLayerActivationSlice =
Shape.toMmulCompatible(is2dInput ? input : input.tensorAlongDimension(time, 1, 0));
if (iTimeIndex > 0 || prevHiddenUnitActivation != null) { //For time == 0 && no prevMemCellState, equivalent to muli by 0
//Note that prevHiddenUnitActivations may be non-null at t=0 for TBPTT
//Again, deltaifog_current == deltaifogNext at this point... same array
// The transposed input times the deltas gives the input weight gradients
Nd4j.gemm(prevLayerActivationSlice, deltaifogNext, iwGradientsOut, true, false, 1.0, 1.0);
} else {
INDArray iwGradients_i =
iwGradientsOut.get(NDArrayIndex.all(), NDArrayIndex.interval(0, hiddenLayerSize));
Nd4j.gemm(prevLayerActivationSlice, deltai, iwGradients_i, true, false, 1.0, 1.0);
INDArray iwGradients_og = iwGradientsOut.get(NDArrayIndex.all(),
NDArrayIndex.interval(2 * hiddenLayerSize, 4 * hiddenLayerSize));
INDArray deltaog = deltaifogNext.get(NDArrayIndex.all(),
NDArrayIndex.interval(2 * hiddenLayerSize, 4 * hiddenLayerSize));
Nd4j.gemm(prevLayerActivationSlice, deltaog, iwGradients_og, true, false, 1.0, 1.0);
}
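// Both branches accumulate prevLayerActivationSlice^T * deltaifog into the input weight gradients, i.e.
// dL/dW_input += x_t^T * deltaifog; at t = 0 with no previous state the forget-gate columns are skipped,
// since deltaf would be zero there anyway.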
if (iTimeIndex > 0 || prevHiddenUnitActivation != null) {
//If t==0 and prevHiddenUnitActivation==null, equiv. to zeros(n^L,n^L), so dL/dW for recurrent weights
// will end up as 0 anyway
//At this point: deltaifog and deltaifogNext are the same thing...
//So what we are actually doing here is sum of (prevAct^transpose * deltaifog_current)
// The previous time step's hidden activations, transposed, times the deltas gives the recurrent weight gradients
Nd4j.gemm(prevHiddenUnitActivation, deltaifogNext, rwGradientsIFOG, true, false, 1.0, 1.0);
//Shape: [1,n^L]. sum(0) is sum over examples in mini-batch.
//Can use axpy here because result of sum and rwGradients[4 to 6] have order Nd4j.order(), via Nd4j.create()
if (hasPeepholeConnections) {
// The forget gate and input modulation gate peepholes take the previous memory cell state as input
INDArray dLdwFF = deltaf.dup('f').muli(prevMemCellState).sum(0); //mul not mmul because these weights are from unit j->j only (whereas other recurrent weights are i->j for all i,j)
l1BLAS.axpy(hiddenLayerSize, 1.0, dLdwFF, rwGradientsFF); //rwGradients[4].addi(dLdwFF); //dL/dw_{FF}
INDArray dLdwGG = deltag.dup('f').muli(prevMemCellState).sum(0);
l1BLAS.axpy(hiddenLayerSize, 1.0, dLdwGG, rwGradientsGG); //rwGradients[6].addi(dLdwGG);
}
}
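// In equation form: dL/dW_recurrent(IFOG) += h_(t-1)^T * deltaifog, and for the peepholes
// dL/dw_FF += sum over the mini-batch of deltaf ⊙ c_(t-1), dL/dw_GG += sum over the mini-batch of deltag ⊙ c_(t-1)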
if (hasPeepholeConnections) {
// The output gate peephole takes the current memory cell state as input
INDArray dLdwOO = deltao.dup('f').muli(currMemCellState).sum(0); //Expected shape: [n^L,1]. sum(0) is sum over examples in mini-batch.
l1BLAS.axpy(hiddenLayerSize, 1.0, dLdwOO, rwGradientsOO); //rwGradients[5].addi(dLdwOO); //dL/dw_{OOxy}
}
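// In equation form: dL/dw_OO += sum over the mini-batch of deltao ⊙ c_t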
if (iTimeIndex > 0 || prevHiddenUnitActivation != null) { //For time == 0 && no prevMemCellState, equivalent to muli by 0
//Note that prevHiddenUnitActivation may be non-null at t=0 for TBPTT
// Compute the bias gradients
l1BLAS.axpy(4 * hiddenLayerSize, 1.0, deltaifogNext.sum(0), bGradientsOut);
} else {
l1BLAS.axpy(hiddenLayerSize, 1.0, deltai.sum(0), bGradientsOut); //Sneaky way to do bGradients_i += deltai.sum(0)
INDArray ogBiasToAdd = deltaifogNext.get(NDArrayIndex.all(),
NDArrayIndex.interval(2 * hiddenLayerSize, 4 * hiddenLayerSize)).sum(0);
INDArray ogBiasGrad = bGradientsOut.get(NDArrayIndex.point(0),
NDArrayIndex.interval(2 * hiddenLayerSize, 4 * hiddenLayerSize));
l1BLAS.axpy(2 * hiddenLayerSize, 1.0, ogBiasToAdd, ogBiasGrad);
}
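// In equation form: dL/db += sum over the mini-batch of deltaifog (only the i and o,g parts at t = 0 with no previous state)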
//Calculate epsilonNext - i.e., equiv. to what would be (w^L*(d^(Lt))^T)^T in a normal network
//But here, need to add 4 weights * deltas for the IFOG gates
INDArray epsilonNextSlice = epsilonNext.tensorAlongDimension(time, 1, 0); //This slice: f order and contiguous, due to epsilonNext being defined as f order.
if (iTimeIndex > 0 || prevHiddenUnitActivation != null) {
//Note that prevHiddenUnitActivation may be non-null at t=0 for TBPTT
// This layer's deltas times the transposed input weights gives the error component for the layer below
Nd4j.gemm(deltaifogNext, inputWeights, epsilonNextSlice, false, true, 1.0, 1.0);
} else {
//No contribution from forget gate at t=0
INDArray wi = inputWeights.get(NDArrayIndex.all(), NDArrayIndex.interval(0, hiddenLayerSize));
Nd4j.gemm(deltai, wi, epsilonNextSlice, false, true, 1.0, 1.0);
INDArray deltaog = deltaifogNext.get(NDArrayIndex.all(),
NDArrayIndex.interval(2 * hiddenLayerSize, 4 * hiddenLayerSize));
INDArray wog = inputWeights.get(NDArrayIndex.all(),
NDArrayIndex.interval(2 * hiddenLayerSize, 4 * hiddenLayerSize));
Nd4j.gemm(deltaog, wog, epsilonNextSlice, false, true, 1.0, 1.0); //epsilonNextSlice.addi(deltao.mmul(woTranspose)).addi(deltag.mmul(wgTranspose));
}
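// In equation form: epsilonNextSlice += deltaifog * W_input^T, i.e. the error passed to the layer below for this time step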
if (maskArray != null) {
//Mask array is present: bidirectional RNN -> need to zero out these errors to avoid sending anything
// but 0s to the layer below at this time step (for the given example)
epsilonNextSlice.muliColumnVector(timeStepMaskColumn);
}
if (workspace != null)
workspace.close();
}
Gradient retGradient = new DefaultGradient();
retGradient.gradientForVariable().put(inputWeightKey, iwGradientsOut);
retGradient.gradientForVariable().put(recurrentWeightKey, rwGradientsOut);
retGradient.gradientForVariable().put(biasWeightKey, bGradientsOut);
return new Pair<>(retGradient, epsilonNext);