Pytorch學習之LSTM.md
Pytorch學習之LSTM
看了理解LSTM這篇博文,在這裡寫寫自己對LSTM網路的一些認識!。
- RNN
- 網路計算過程
Recurrent Neural Networks
人類並不是每時每刻都從一片空白的大腦開始他們的思考。在你閱讀這篇文章時候,你都是基於自己已經擁有的對先前所見詞的理解來推斷當前詞的真實含義。我們不會將所有的東西都全部丟棄,然後用空白的大腦進行思考。我們的思想擁有永續性。 傳統的神經網路並不能做到這點,看起來也像是一種巨大的弊端。例如,假設你希望對電影中的每個時間點的時間型別進行分類。傳統的神經網路應該很難來處理這個問題——使用電影中先前的事件推斷後續的事件。 RNN 解決了這個問題。RNN 是包含迴圈的網路,允許資訊的持久化
這是一個經典的RNN的流程圖。
1. LSTM網路
經典的LSTM的流程圖:
相信大家都看過這個圖(盜用別人的圖)。 再來一段公式,就是下面的,公式來自Pytorch。 is the hidden state at time , is the cell state at time , is the input at time , is the hidden state of the previous layer at time or the initial hidden state at time , and , , , are the input, forget, cell, and output gates, respectively. is the sigmoid function.
2. 內部計算分析
rnn = nn.LSTM(10, 20, 2)
input = torch.randn(5, 3, 10)
h0 = torch.randn(2, 3, 20)
c0 = torch.randn(2, 3, 20)
output, (hn, cn) = rnn(input, (h0, c0))
可以看到引數的大小變成了(4*20,10),是標準RNN的四倍。原因是這裡它包括了四個引數矩陣、、、,它們每一個都是(20×10),輸入的維度大小是(10×1), 這樣 , , , 的維度都是(20×1),公式(5)(6)的運算應該是叉積(元素積),這樣得到的和的維度才能是20。
如上圖所示hn和cn的最後一維都是20。注意這裡的LSTM網路是單向,雙向的要*2。蟹蟹!