1. 程式人生 > >記憶網路RNN、LSTM與GRU

記憶網路RNN、LSTM與GRU

這裡寫圖片描述

一般的神經網路輸入和輸出的維度大小都是固定的,針對序列型別(尤其是變長的序列)的輸入或輸出資料束手無策。RNN通過採用具有記憶的隱含層單元解決了序列資料的訓練問題。LSTM、GRU屬於RNN的改進,解決了RNN中梯度消失爆炸的問題,屬於序列資料訓練的常用方案。

RNN

結構

傳統的神經網路的輸入和輸出都是確定的,RNN的輸入和輸出都是不確定的sequence資料。其結構如下:

這裡寫圖片描述
這裡寫圖片描述

具體地,RNN有隱含層,隱含層也是記憶層,其狀態(權值)會傳遞到下一個狀態中。

htyt=σ(xtWxh+ht1Whh)=σ(htWhy)

訓練

訓練步驟如下:

  1. 構建損失函式
  2. 求損失函式對權值的梯度
  3. 採用梯度下降法更新權值引數

關於損失函式,根據需要選擇構建即可,下面提供兩種常見的損失函式:

CC=12n=1N||ynŷ n||2=12n=1Nlogynrn

關於梯度下降,採用BPTT(Backpropagation through time)演算法,該演算法的核心是對每一個時間戳,計算該時間戳中權重的梯度,然後更新權重。需要注意的是,不同時間戳同樣權重的梯度可能是不一樣的,如下圖所示都減去,相當於更新同一塊記憶體區域中的權重。

這裡寫圖片描述
這裡寫圖片描述

應用

  • 多對多:詞性標註pos tagging、語音識別、name entity recognition(區分poeple、organizations、places、information extration(區分place of departure、destination、time of departure、time of arrival, other)、機器翻譯
  • 多對一:情感分析
  • 一對多:caption generation

這裡寫圖片描述

這裡寫圖片描述
這裡寫圖片描述

這裡寫圖片描述

這裡寫圖片描述

RNN Variants

RNN的變種大致包含下面3個思路:

  • 增加隱含層的輸入引數:例如除了ht1,xt,還可以包含yt1作為輸入。
  • 增加隱含層的深度
  • 雙向RNN

這裡寫圖片描述
這裡寫圖片描述
這裡寫圖片描述

LSTM

結構

  • 單個時間戳,RNN輸入1個x,輸出1個y
  • 單個時間戳,LSTM輸入4個x,輸出1個y

相比RNN,LSTM的輸入多了3個x,對應3個gate,這3個gate分別是:

  • input gate:控制輸入
  • forget gate:控制cell
  • output gate:控制輸出

涉及到的啟用函式共5個,其中3個控制gate的(通常用sigmoid函式,模擬gate的開閉狀態),1個作用於輸入上,一個作用於cell的輸出上。

這裡寫圖片描述

LSTM單個時間戳的具體執行如下:

  • 輸入:4個輸入x,1個cell的狀態c
  • 輸出:1個輸出a,1個更新的cell狀態c
ca=g(z)f(zi)+cf(zf)=h(c)f(zo)

梯度消失及梯度爆炸

首先,要明白RNN中梯度消失與梯度爆炸的原因:在時間戳的更新中,cell的狀態不斷乘以Whh。簡單起見,視Whh為scalar值w,那麼y=xwnyw=nxwn1。根據w的值與1的大小關係,梯度會消失或者爆炸。

接下來,要明白LSTM如何解決RNN中梯度消失與爆炸的問題。

針對梯度消失,RNN中當獲取c的梯度後,因為c=cw,為了backward獲得c的梯度,要將c的梯度乘以w;LSTM中存在梯度的快速通道,獲取c的梯度後,因為c=g(z)f(zi)+cf(zf),當forget gate開啟時,c=g(z)f(zi)+cc的梯度可以直接傳遞給c
總結來說,LSTM相比RNN,將c,c的更新關係從乘法變成了加法,因此不用乘以權值係數wc的梯度可以直接傳遞給c,解決了梯度消失的問題。

針對梯度爆炸,即使將c,c的關係由乘法變成了加法,仍然解決不了梯度爆炸。原因便是梯度的路徑不止一條,如下圖所示,紅色的塊仍然可能造成梯度爆炸。LSTM解決這個問題的方法是clip,也就是設定梯度最大值,超過最大值的按最大值計。

這裡寫圖片描述

GRU

結構

GRU相比LSTM的3個gate,只用了兩個gate:

  • update gate:zt
  • reset gate: