1. 程式人生 > >RNN系列

RNN系列

import torch
import torch.nn as nn

# Seed PyTorch's global RNG so the random inputs and the randomly initialised
# layer weights created below are identical on every run.
seed = 0
torch.manual_seed(seed)



def compute_diff(t1: torch.Tensor, t2: torch.Tensor) -> float:
    """Return the mean absolute element-wise difference between two tensors.

    The absolute value is taken before averaging so that positive and
    negative deviations cannot cancel each other out: a result close to
    zero then really means the two tensors are close everywhere, which is
    what this helper is used to demonstrate.

    Args:
        t1: First tensor.
        t2: Second tensor; must be broadcastable against ``t1``.

    Returns:
        The mean of ``|t1 - t2|`` as a plain Python float.
    """
    return (t1 - t2).abs().mean().item()






################### Study how nn.GRU computes its output ###################

# Problem dimensions, named once so that the tensors, loops and weight
# splits below cannot drift out of sync.
SEQ_LEN = 5       # sequence length
BATCH_SIZE = 7    # batch size
INPUT_SIZE = 10   # dimension of each input vector x
HIDDEN_SIZE = 20  # dimension of the hidden state h

# Input sequence, laid out as (seq_len, batch, input_size).
x_seq = torch.randn(SEQ_LEN, BATCH_SIZE, INPUT_SIZE)

# Initial hidden state, laid out as (num_layers, batch, hidden_size).
# The leading "number of layers" axis is required even for a single layer.
h_init = torch.randn(1, BATCH_SIZE, HIDDEN_SIZE)

# Define a single-layer GRU.
gru = nn.GRU(input_size=INPUT_SIZE, hidden_size=HIDDEN_SIZE, num_layers=1)

# GRU parameters.
W_hh, b_hh = gru.weight_hh_l0, gru.bias_hh_l0
W_ih, b_ih = gru.weight_ih_l0, gru.bias_ih_l0

# W_hh is the row-wise concatenation ( W_hr | W_hz | W_hn ); splitting it
# into three equal chunks recovers the per-gate blocks of HIDDEN_SIZE rows.
W_hr, W_hz, W_hn = W_hh.chunk(3, dim=0)
b_hr, b_hz, b_hn = b_hh.chunk(3, dim=0)

# W_ih is the row-wise concatenation ( W_ir | W_iz | W_in ).
W_ir, W_iz, W_in = W_ih.chunk(3, dim=0)
b_ir, b_iz, b_in = b_ih.chunk(3, dim=0)


print( '\nStep1: 計算GRU(檢查output_seq[-1]和h_last是否相等)' )
output_seq, h_last = gru(x_seq, h_init)
# output_seq.size() = [5, 7, 20]
# h_last.size()     = [1, 7, 20]

# The GRU's output is simply the sequence of its hidden states: it maps the
# input sequence x to a sequence of hidden vectors and uses each hidden state
# directly as the output. The last element of output_seq is therefore the
# same as h_last.
# Index the layer axis explicitly instead of calling a bare squeeze(), which
# would also drop the batch or hidden axis whenever that axis has size 1.
h_last_layer0 = h_last[0]
print( 'equal =', torch.equal(output_seq[-1], h_last_layer0) )
print( 'diff =', compute_diff(output_seq[-1], h_last_layer0) )


# Re-implement the GRU recurrence by hand.
print( '\nStep2: 自己實現GRU中的計算' )
ht = h_init[0]
my_steps = []
for xt in x_seq:
    h_prev = ht
    # Reset gate.
    rt = torch.sigmoid(xt.mm(W_ir.t()) + b_ir + h_prev.mm(W_hr.t()) + b_hr)
    # Update gate.
    zt = torch.sigmoid(xt.mm(W_iz.t()) + b_iz + h_prev.mm(W_hz.t()) + b_hz)
    # New (candidate) state; the reset gate scales only the hidden-side term.
    nt = torch.tanh(xt.mm(W_in.t()) + b_in + rt * (h_prev.mm(W_hn.t()) + b_hn))
    # Blend the candidate with the previous hidden state.
    ht = (1 - zt) * nt + zt * h_prev
    my_steps.append(ht)
# Stack the per-step hidden states back into (seq_len, batch, hidden_size).
my_output_seq = torch.stack(my_steps)

print( 'equal =', torch.equal(my_output_seq, output_seq) )
print( 'diff =', compute_diff(my_output_seq, output_seq) )
print( '【注】由於浮點數計算的誤差,不可能保證計算結果完全相等,但diff值足夠小' )


print( '\nStep3: 通過RNNCell的方式進行計算' )
gru_cell = nn.GRUCell(INPUT_SIZE, HIDDEN_SIZE)
# Make the cell use the very same parameter tensors as the GRU layer above so
# that both compute the same function.
gru_cell.weight_hh, gru_cell.bias_hh = W_hh, b_hh
gru_cell.weight_ih, gru_cell.bias_ih = W_ih, b_ih

hx = h_init[0]
cell_steps = []
for xt in x_seq:
    # One cell call advances the hidden state by a single time step.
    hx = gru_cell(xt, hx)
    cell_steps.append(hx)
out = torch.stack(cell_steps)

print( 'equal =', torch.equal(out, output_seq) )
print( 'diff =', compute_diff(out, output_seq) )