機器學習(ML)九之GRU、LSTM、深度神經網路、雙向迴圈神經網路
阿新 • • 發佈:2020-02-15
門控迴圈單元(GRU)
迴圈神經網路中的梯度計算方法。當時間步數較大或者時間步較小時,迴圈神經網路的梯度較容易出現衰減或爆炸。雖然裁剪梯度可以應對梯度爆炸,但無法解決梯度衰減的問題。通常由於這個原因,迴圈神經網路在實際中較難捕捉時間序列中時間步距離較大的依賴關係。
門控迴圈神經網路(gated recurrent neural network)的提出,正是為了更好地捕捉時間序列中時間步距離較大的依賴關係。它通過可以學習的門來控制資訊的流動。其中,門控迴圈單元(gated recurrent unit,GRU)是一種常用的門控迴圈神經網路。
門控迴圈單元
門控迴圈單元的設計。它引入了重置門(reset gate)和更新門(update gate)的概念,從而修改了迴圈神經網路中隱藏狀態的計算方式。
重置門和更新門
門控迴圈單元中的重置門和更新門的輸入均為當前時間步輸入Xt與上一時間步隱藏狀態Ht−1,輸出由啟用函式為sigmoid函式的全連線層計算得到。
候選隱藏狀態
隱藏狀態
程式碼實現
1 #!/usr/bin/env python 2 # coding: utf-8 3 4 # In[10]: 5 6 7 import d2lzh as d2l 8 from mxnet import nd 9 from mxnet.gluon import rnn 10 import zipfile 11 12 13 # In[11]: 14 15 16 def load_data_jay_lyrics(file): 17 """Load the Jay Chou lyric data set (available in the Chinese book).""" 18 with zipfile.ZipFile(file) as zin: 19 with zin.open('jaychou_lyrics.txt') as f: 20 corpus_chars = f.read().decode('utf-8') 21 corpus_chars = corpus_chars.replace('\n', ' ').replace('\r', ' ') 22 corpus_chars = corpus_chars[0:10000] 23 idx_to_char = list(set(corpus_chars)) 24 char_to_idx = dict([(char, i) for i, char in enumerate(idx_to_char)]) 25 vocab_size = len(char_to_idx) 26 corpus_indices = [char_to_idx[char] for char in corpus_chars] 27 return corpus_indices, char_to_idx, idx_to_char, vocab_size 28 29 30 # In[12]: 31 32 33 file ='/Users/James/Documents/dev/test/data/jaychou_lyrics.txt.zip' 34 (corpus_indices, char_to_idx, idx_to_char, vocab_size) = load_data_jay_lyrics(file) 35 36 37 # In[13]: 38 39 40 num_inputs, num_hiddens, num_outputs = vocab_size, 256, vocab_size 41 ctx = d2l.try_gpu() 42 43 def get_params(): 44 def _one(shape): 45 return nd.random.normal(scale=0.01, shape=shape, ctx=ctx) 46 47 def _three(): 48 return (_one((num_inputs, num_hiddens)), 49 _one((num_hiddens, num_hiddens)), 50 nd.zeros(num_hiddens, ctx=ctx)) 51 52 W_xz, W_hz, b_z = _three() # 更新門引數 53 W_xr, W_hr, b_r = _three() # 重置門引數 54 W_xh, W_hh, b_h = _three() # 候選隱藏狀態引數 55 # 輸出層引數 56 W_hq = _one((num_hiddens, num_outputs)) 57 b_q = nd.zeros(num_outputs, ctx=ctx) 58 # 附上梯度 59 params = [W_xz, W_hz, b_z, W_xr, W_hr, b_r, W_xh, W_hh, b_h, W_hq, b_q] 60 for param in params: 61 param.attach_grad() 62 return params 63 64 65 # In[14]: 66 67 68 def init_gru_state(batch_size, num_hiddens, ctx): 69 return (nd.zeros(shape=(batch_size, num_hiddens), ctx=ctx), ) 70 71 72 # In[15]: 73 74 75 def gru(inputs, state, params): 76 W_xz, W_hz, b_z, W_xr, W_hr, b_r, W_xh, W_hh, b_h, W_hq, b_q = params 77 H, = state 78 outputs = [] 79 for X in inputs: 80 Z = nd.sigmoid(nd.dot(X, W_xz) + nd.dot(H, W_hz) + b_z) 81 R = nd.sigmoid(nd.dot(X, W_xr) + nd.dot(H, W_hr) + b_r) 82 H_tilda = nd.tanh(nd.dot(X, W_xh) + nd.dot(R * H, W_hh) + b_h) 83 H = Z * H + (1 - Z) * H_tilda 84 Y = nd.dot(H, W_hq) + b_q 85 outputs.append(Y) 86 return outputs, (H,) 87 88 89 # In[16]: 90 91 92 num_epochs, num_steps, batch_size, lr, clipping_theta = 160, 35, 32, 1e2, 1e-2 93 pred_period, pred_len, prefixes = 40, 50, ['分開', '不分開'] 94 95 96 # In[ ]: 97 98 99 d2l.train_and_predict_rnn(gru, get_params, init_gru_state, num_hiddens, 100 vocab_size, ctx, corpus_indices, idx_to_char, 101 char_to_idx, False, num_epochs, num_steps, lr, 102 clipping_theta, batch_size, pred_period, pred_len, 103 prefixes)
長短期記憶(LSTM)
常用的門控迴圈神經網路:長短期記憶(long short-term memory,LSTM)。它比門控迴圈單元的結構稍微複雜一點。
長短期記憶
LSTM 中引入了3個門,即輸入門(input gate)、遺忘門(forget gate)和輸出門(output gate),以及與隱藏狀態形狀相同的記憶細胞(某些文獻把記憶細胞當成一種特殊的隱藏狀態),從而記錄額外的資訊。
輸入門、遺忘門和輸出門
候選記憶細胞
記憶細胞
隱藏狀態
程式碼實現
1 #LSTM 初始化引數 2 num_inputs, num_hiddens, num_outputs = vocab_size, 256, vocab_size 3 ctx = d2l.try_gpu() 4 5 def get_params(): 6 def _one(shape): 7 return nd.random.normal(scale=0.01, shape=shape, ctx=ctx) 8 9 def _three(): 10 return (_one((num_inputs, num_hiddens)), 11 _one((num_hiddens, num_hiddens)), 12 nd.zeros(num_hiddens, ctx=ctx)) 13 14 W_xi, W_hi, b_i = _three() # 輸入門引數 15 W_xf, W_hf, b_f = _three() # 遺忘門引數 16 W_xo, W_ho, b_o = _three() # 輸出門引數 17 W_xc, W_hc, b_c = _three() # 候選記憶細胞引數 18 # 輸出層引數 19 W_hq = _one((num_hiddens, num_outputs)) 20 b_q = nd.zeros(num_outputs, ctx=ctx) 21 # 附上梯度 22 params = [W_xi, W_hi, b_i, W_xf, W_hf, b_f, W_xo, W_ho, b_o, W_xc, W_hc, 23 b_c, W_hq, b_q] 24 for param in params: 25 param.attach_grad() 26 return params 27 28 29 # In[19]: 30 31 32 def init_lstm_state(batch_size, num_hiddens, ctx): 33 return (nd.zeros(shape=(batch_size, num_hiddens), ctx=ctx), 34 nd.zeros(shape=(batch_size, num_hiddens), ctx=ctx))