1. 程式人生 > >Pytorch學習之LSTM.md

Pytorch學習之LSTM.md

Pytorch學習之LSTM

看了理解LSTM這篇博文,在這裡寫寫自己對LSTM網路的一些認識!。

  • RNN
  • 網路計算過程

Recurrent Neural Networks

人類並不是每時每刻都從一片空白的大腦開始他們的思考。在你閱讀這篇文章時候,你都是基於自己已經擁有的對先前所見詞的理解來推斷當前詞的真實含義。我們不會將所有的東西都全部丟棄,然後用空白的大腦進行思考。我們的思想擁有永續性。 傳統的神經網路並不能做到這點,看起來也像是一種巨大的弊端。例如,假設你希望對電影中的每個時間點的時間型別進行分類。傳統的神經網路應該很難來處理這個問題——使用電影中先前的事件推斷後續的事件。 RNN 解決了這個問題。RNN 是包含迴圈的網路,允許資訊的持久化

在這裡插入圖片描述 這是一個經典的RNN的流程圖。

1. LSTM網路

經典的LSTM的流程圖:

在這裡插入圖片描述

相信大家都看過這個圖(盜用別人的圖)。 再來一段公式,就是下面的,公式來自Pytorch。 hth_t is the hidden state at time tt , ctc_t is the cell state at time tt , xtx_t is the input at time tt, h(t1)h_{(t-1)} is the hidden state of the previous layer at time t1t-1 or the initial hidden state at time 0

0 , and iti_t , ftf_t , gtg_t , oto_t are the input, forget, cell, and output gates, respectively. σ\sigma is the sigmoid function.

2. 內部計算分析

rnn = nn.LSTM(10, 20, 2)
input = torch.randn(5, 3, 10)
h0 = torch.randn(2, 3, 20)
c0 = torch.randn(2, 3, 20)
output, (hn, cn) = rnn(input, (h0, c0))

在這裡插入圖片描述

可以看到引數的大小變成了(4*20,10),是標準RNN的四倍。原因是這裡它包括了四個引數矩陣W

iiW_{ii}WifW_{if}WigW_{ig}WioW_{io},它們每一個都是(20×10),輸入的維度大小是(10×1), 這樣iti_t , ftf_t , gtg_t , oto_t 的維度都是(20×1),公式(5)(6)的運算應該是叉積(元素積),這樣得到的ctc_thth_t的維度才能是20。

在這裡插入圖片描述 如上圖所示hn和cn的最後一維都是20。注意這裡的LSTM網路是單向,雙向的要*2。蟹蟹!