1. 程式人生 > 其它 >GRU(門控迴圈單元)原理+ 程式碼實現

GRU(門控迴圈單元)原理+ 程式碼實現

GRU說白了就是加了兩個門,這兩個門控制最終隱藏狀態的輸出,其中還是那一套換湯不換藥。

R是重置門,決定上一個時間步\(h_{t-1}\)是否要被重置,如果R元素全為0,很顯然我們就丟掉了上一個時間步的h資訊。

S是更新門,決定了這個時刻的候選隱藏狀態\(h_{t}^{\prime}\)應該怎麼輸出。

注意,因為這是兩個閥門,閥門控制肯定取值只有(0~1),所以這個的啟用函式是sigmod函式。

公式:

\[\begin{aligned} \mathbf{R}_{t} &=\sigma\left(\mathbf{X}_{t} \mathbf{W}_{x r}+\mathbf{H}_{t-1} \mathbf{W}_{h r}+\mathbf{b}_{r}\right) \\ \mathbf{Z}_{t} &=\sigma\left(\mathbf{X}_{t} \mathbf{W}_{x z}+\mathbf{H}_{t-1} \mathbf{W}_{h z}+\mathbf{b}_{z}\right) \end{aligned} \]
候選隱藏狀態
\[\tilde{\mathbf{H}}_{t}=\tanh \left(\mathbf{X}_{t} \mathbf{W}_{x h}+\left(\mathbf{R}_{t} \odot \mathbf{H}_{t-1}\right) \mathbf{W}_{h h}+\mathbf{b}_{h}\right) \]

值得注意的是,這裡因為R和Z都是起到了閥門的作用,所有很顯然它是直接做哈達瑪乘積的,即對應元素相乘。

可以看到,通過重置門,我們得到了候選隱藏狀態,這個做的好處是可以減少一萬狀態的影響。

更新隱藏狀態
\[\mathbf{H}_{t}=\mathbf{Z}_{t} \odot \mathbf{H}_{t-1}+\left(1-\mathbf{Z}_{t}\right) \odot \tilde{\mathbf{H}}_{t} \]

通過更新門實現了對隱藏狀態的更新。

如果Z接近1,那麼\(h_{t-1}\)就會被保留,而如果整個子序列的所有時間步的更新門,也就是 Z 都接近1,那麼我們可以保留從序列起始時間步開始的所有隱藏狀態。

重置門有利於捕獲序列中的短期依賴關係。

更新門有助於補貨序列中的長期依賴關係。

從零開始實現

import torch 
import torch.nn as nn
import torch.nn.functional as F
import math
class GRU(nn.Module):
    def __init__(self,indim, hidim, outdim):
        super(GRU, self).__init__()
        self.indim = indim
        self.hidim = hidim
        self.outdim = outdim
        self.W_zh, self.W_zx, self.b_z = self.get_three_parameters()
        self.W_rh, self.W_rx, self.b_r = self.get_three_parameters()
        self.W_hh, self.W_hx, self.b_h = self.get_three_parameters()
        self.Linear = nn.Linear(hidim, outdim) # 全連線層做輸出
        self.reset()

    def forward(self, input, state):
        input = input.type(torch.float32)
        if torch.cuda.is_available():
            input = input.cuda()
        Y = []   
        h = state
        h = h.cuda()
        for x in input:
            z = F.sigmoid(h @ self.W_zh + x @ self.W_zx + self.b_z)
            r = F.sigmoid(h @ self.W_rh + x @ self.W_rx + self.b_r)
            ht = F.tanh((h * r) @ self.W_hh + x @ self.W_hx + self.b_h)
            h = (1 - z) * h + z * ht
            y = self.Linear(h)
            Y.append(y)
        return torch.cat(Y, dim=0), h
          
    def get_three_parameters(self):
        indim, hidim, outdim = self.indim, self.hidim, self.outdim               
        return nn.Parameter(torch.FloatTensor(hidim, hidim)), \
            nn.Parameter(torch.FloatTensor(indim, hidim)), \
                nn.Parameter(torch.FloatTensor(hidim))
                
    def reset(self):
        stdv = 1.0 / math.sqrt(self.hidim)
        for param in self.parameters():
            nn.init.uniform_(param, -stdv, stdv)

就是按公式原原本本寫了一遍,沒什麼特點,就像搭積木一樣。

框架實現

import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from RNN import *
setup_seed(916)

class GRU(nn.Module):
    def __init__(self, indim, hidim, outdim):
        super(GRU, self).__init__()
        self.GRU = nn.GRU(indim, hidim)
        self.Linear = nn.Linear(hidim, outdim)
    
    def forward(self, input, state):
        input = input.type(torch.float32)
        h = state.unsqueeze(0)
        if torch.cuda.is_available():
            input = input.cuda()
            h = h.cuda()
        y, state = self.GRU(input, h)
        output = self.Linear(y.reshape(-1, y.shape[-1]))
        return output, state     

這裡有個值得注意的點,由於框架它實際上是可以定義多層RNN的,所以它輸入和輸出張量的維度不一樣。

gru(input, h_0)

輸入 input 就是(time_step, batch_size, feature_dim),在模型初始化nn.GRU()傳入引數batch_first,那麼input的shape就是(batch_size, time_step,feature_dim),這一點需要當心。

對於h_0,它的shape簡單(D* num_layers, N, H_{hidim}) , 這裡的D是看我們在初始化的時候是否設定了bidirectional, 如果true,代表我們要用雙向的rnn,於是D就為2.不過大部分情況下我們都只用單向的rnn,於是一般來說它的shape就是(num_layers, N, H_{hidim}) ,如果不顯式地給出h0,框架會自動用全0來構造這個h0,如果只是訓練的話,是沒必要自己初始化一個h0的,當然預測肯定要傳入h0。

對於這個輸出的結果,我們需要更加註意。

(output,h_n) = ouputs

其中output的shape為(L, N, D * H_{out}),如果設定了batch_first = True, 就顛倒一下,一個重要的點:如果我們設定了雙向的RNN,那麼我們最後是將兩個隱藏層結果concat起來了,所以,最後一維就是D * H_{out}這是一個需要留心的點。

h_n也需要注意,他是最後一個時間步的每一層的隱藏狀態(D * num_layers, N, H_{out}),如果我們設定層數為1,並且不使用雙向的rnn,那麼輸出的結果就是(1, N, h_out)

這些維度挺繞的,所以一定要留心一點。