1. 程式人生 > >RNN與反向傳播演算法(BPTT)的理解

RNN與反向傳播演算法(BPTT)的理解

RNN是序列建模的強大工具。
今天主要搬運兩天來看到的關於RNN的很好的文章:

PS: 第一個連結中的Toy Code做一些說明

圖片名稱

之所以要迴圈8(binary_dim=8)次,是因為輸入是2維的(a和b各輸入一個bit),那麼,每個bit只會影響8個時間戳。因此要注意RNN的訓練,應該以每一個完整的序列(這裡就是a和b兩個八位二進位制數)作為一個training sample,而非以每一次輸入(2 bits)作為一個sample;同樣的在反向傳播時,也同樣遵循這個原則,此處因為每次輸入會影響8個時間戳(或者說每8次輸入為一個完整的training sample),所以要迴圈8次。

再然後,第99行(五角星處)的隱層delta更新法則與上面給出的RNN反向傳播演算法BPTT一文中的下圖正好一致!

這裡寫圖片描述

先寫這麼多。