機器學習-->深度學習-->RNN,LSTM
本篇博文將把我所學習的迴圈神經網路(RNN)和長短時記憶神經網路(LSTM)做一個詳細的總結分析。
RNN
為什麼會有迴圈神經網路
傳統的神經網路(包括CNN),輸入輸出都是互相獨立的 。但是在一些場景中,後續的輸出和之前的內容是相關的。比如一段句子的輸出“我在CSDN上發表了一篇文_”,這裡後續的輸出和前面的文字有著很大的聯絡。傳統的神經網路對這樣的任務無法很好的進行預測。
迴圈神經網路結構
簡單來看,就是把序列按時間展開。
是時間t處的輸入。
是時間t處的“記憶”,,由上面的結構可以看出St包含了St-1,St-2等之前的資訊。f是一些啟用函式,例如sigmoid,tanh,reLu等。
是時間t處的輸出,可能是softmax輸出的屬於每個候選詞的概率,。
可以把隱狀態St視作“記憶體”,捕捉了之前時間點上的的資訊。
輸出Ot由當前時間及之前所有的“記憶”共同計算得到。
實際應用中,St不能保留和利用之前所有的資訊。
不同與CNN,RNN神經網路都共享一組引數(U,V,W),極大減小了需要訓練和預估計的引數數量。
在某些實際場景中,不需要每個時刻都輸出一個Ot,只需要最後output即可。
我覺得可以這樣理解RNN,把多個傳統神經網路的隱藏層神經元連線起來,每個神經網路表示某一時刻輸入輸出的一個模型。這樣把多個連續時刻的模型連在一起,就可以結合上下文(前後時刻的內容)來進行分析預測了。
雙向RNN
有些請看下,序列不只是依賴當前的情況還可能依賴後面的情況。比如補全一段文字裡面的某一個詞。
深層雙向RNN
和雙向RNN的區別是每一步/每一個時間點,我們設定多層結構。
BPTT(BackPropagation Through Time)演算法
在DNN或者CNN中用BP演算法來更新引數,BP演算法的核心就是利用鏈式求導法則計算error對每個引數的偏導。
而BPTT和BP是一個思路,也是損失函式對引數求偏導,只不過既然有step,就和時間t有關係。
其中W,U,V為引數矩陣
在t時刻的交叉熵損失函式:
所有時刻的交叉熵損失函式:
損失函式對引數W求偏導:
例如在t3時刻對引數W求偏導得:
其中S3等於:
明顯S3依賴於S2,根據鏈式求導法則:
於是由 可推得:
這樣就求得了損失函式對引數W的偏導,再由: 來更新引數W。
LSTM
上面介紹的RNN神經網路解決了前後資訊儲存的問題。但是它有一個很要命的缺陷:隨著時間距離的增大,RNN會喪失聯絡距它很遠資訊的能力。
LSTM大體結構
LSTM是RNN的一個變種,其大體結構和RNN相差不大,區別在於他的記憶細胞被改造了。
:可理解為啟用函式層,例如執行sigmoid,tanh等啟用。
:表示對向量或矩陣做元素級操作,或者說是逐點的執行運算,比如加法,減法等運算。
:資訊矩陣的傳送方向。
:向量或矩陣的連線運算。
:複製操作。
細胞狀態
上圖中的實黑線可理解為記憶細胞C在傳送帶上沿著時間t在傳送。h(t-1)表示上一時刻的輸出,我們可以看出,實黑線上記憶細胞C以非線性的形式,儲存了每個當前時刻的上一時刻的輸出資訊,而記憶細胞沿著傳送帶源源不斷的傳送,則最終的記憶細胞儲存了之前所有的資訊(雖然通過非線性的過濾,只能選擇性的保留部分資訊)。
這裡的表示執行sigmoid操作,輸出一個在(0,1)範圍內的概率值,資訊乘以這個概率值則可實現資訊過濾的功能。同理表示執行雙曲正切。
第一步:忘記門
這裡表示將上一時刻的輸出h(t-1)矩陣和當前時刻的輸入Xt矩陣連線成一個矩陣,然後放進sigmoid函式內,得到一個截斷概率,這個階段概率再與C(t-1)相乘,則表示會從上一個傳來的狀態細胞 中丟棄什麼資訊,儲存什麼資訊。故稱為忘記門。
第二步:決定增加哪些新的資訊到“細胞狀態”中
- Sigmoid層決定什麼值需要更新(輸出一個概率值,決定過濾哪些資訊)
- Tanh層建立一個新的候選值 向量
- 上述兩步是為狀態更新做準備。
第三步:更新細胞狀態
這一步是整個LSTM的關鍵。
- 更新C(t-1)為Ct。
- 舊狀態C(t-1)與忘記門得出概率 ft 相乘,決定忘記上一時刻細胞狀態C(t-1)的哪部分資訊,也即是過濾掉細胞狀態裡的部分資訊。
- 加上 ,這就是過濾後的候選值。
- 上面兩項求和,第一項時過濾後的細胞狀態資訊;第二項是候選值資訊。兩項求和得到新的細胞狀態。
第四步:基於細胞狀態得到輸出
第一個公式是用sigmoid函式輸出一個0到1的概率值,也就是確定細胞狀態哪部分將輸出。
Ct是更新後的細胞狀態,用tanh函式處理細胞狀態(得到一個-1到1之間的值),在將他和sigmoid輸出的概率作乘積,輸出我們確定要輸出的那部分資訊。
至此LSTM神經網路內部結構大概就是如此。
我們再把LSTM和RNN做個對比:
在RNN中的記憶細胞資訊是直接通過一些簡單的線性加權然後再線性變換得到,如下圖所示:
那麼有:
而在整個LSTM更新過程,都是先產生一個過濾概率,然後再新增一些資訊,在求和等等,其實質和RNN基本一致,都是去掉舊資訊裡面不重要部分,再加上新資訊裡面重要的部分。只是過濾資訊方式上面略有不同。
總結
LSTM為什麼用sigmoid函式?
在傳統的神經網路(包括CNN)中,啟用函式 一般不用sigmoid函式,因為sigmoid函式很容易產生梯度消失現象。那麼LSTM中為什麼還要用sigmoid函式?
我個人的理解是:在LSTM中,啟用函式並不是僅僅用來啟用,刷選過濾資訊,啟用函式還有一個更重要的目的要實現,就是要使得記憶細胞在沿著傳送帶不斷傳送之前的記憶時,還需要保證資訊不會膨脹,要保證資訊量在一定範圍內。那麼這就必須使用sigmoid函式作為啟用函式,reLu函式做不到這一點,我們可以看看sigmoid函式和reLu函式影象就可明白。
sigmoid函式影象:
reLu函式影象:大於0時,有多少資訊就傳多少資訊。
很明顯sigmoid函式始終是在0到1之間,而relu則不一定,故啟用函式若用reLu函式,那麼記憶資訊在傳遞過程中可能會膨脹的越來越大。
為什麼LSTM比RNN更能解決長時間依賴的問題
我們回顧上面RNN反向更新引數的結論:
在RNN中記憶細胞的正向傳遞公式:
,這裡面f可以假定視為tanh函式,很明顯這是一個複合函式,複合函式求偏導時,是一個連乘的形式,我們結合傳統神經網路裡面BP演算法方向更新引數的過程可知,根據鏈式求導 法則,在最後一層依次向前求引數偏導得(上面有詳細的公式推導):
很顯然,根據這個求偏導的公式可知,距離當前時刻越遠,那麼鏈式法則求偏導,連乘的項越多(注意這裡面全是一項乘另外一項然後再乘以另外一項等等,只要有一項很小甚至接近於0,他對整體梯度的影響很大,並且隨著距離梯度累積的越來越多,其梯度變化就會越來越小或者是越來越大),則就很有可能產生梯度消失或梯度爆炸現象。也即是距離當前越近,越不容易產生梯度消失(學到的越多),越遠越容易產生梯度消失(學到的越少)。所以RNN會喪失儲存距離很遠之前的資訊的能力。
而在LSTM中記憶細胞資訊傳遞公式為:
它並不是一個複合函式(即外面沒有套一層函式),而僅僅是兩項求和的形式,那麼在求偏導時,也是兩項求和的形式,這時根據鏈式求導法則,梯度往前更新時,可以視為沿著兩條軸向回傳,一項等於0接近0,整體的梯度也不會接近0。這裡的思路和殘差神經網路裡面避免梯度消失的思路有點類似。
故LSTM與RNN最大不同之處,就是在方向更新引數時把一個練乘形式變成一個求和的形式,這樣即使在學習距離很遠的資訊時,也不會出現很嚴重的梯度消失問題。
利用pytorch實現RNN分類
#coding:utf-8
import torch
from torch import nn
from torch.autograd import Variable
import torchvision.datasets as dsets
import torchvision.transforms as transforms
#import matplotlib.pyplot as plt
torch.manual_seed(1) # reproducible
# Hyper Parameters
EPOCH = 1 # 訓練整批資料次數,為了節省時間只訓練一次
BATCH_SIZE = 64
TIME_STEP = 28 # rnn 時間步數/圖片高度
INPUT_SIZE = 28 # rnn 每步輸入值 / 圖片寬度
LR = 0.01 # 學習率
DOWNLOAD_MNIST = True # set to True if haven't download the data
# Mnist digital dataset
train_data = dsets.MNIST(
root='./mnist/',
train=True, # this is training data
transform=transforms.ToTensor(), # Converts a PIL.Image or numpy.ndarray to
# torch.FloatTensor of shape (C x H x W) and normalize in the range [0.0, 1.0]
download=DOWNLOAD_MNIST, # download it if you don't have it
)
# 批訓練50 samples,1 channel,28*28,故(50,1,28,28)
train_loader = torch.utils.data.DataLoader(dataset=train_data, batch_size=BATCH_SIZE, shuffle=True)
# 為了節省時間,只測試前2000個樣本
test_data = dsets.MNIST(root='./mnist/', train=False, transform=transforms.ToTensor())
test_x = Variable(test_data.test_data, volatile=True).type(torch.FloatTensor)[:2000]/255. # shape (2000, 28, 28) value in range(0,1)
test_y = test_data.test_labels.numpy().squeeze()[:2000] # covert to numpy array
class RNN(nn.Module):
def __init__(self):
super(RNN, self).__init__()
self.rnn = nn.LSTM( # 如果用單純的RNN,則很難收斂
input_size=INPUT_SIZE, # 輸入特徵數量,這裡是圖片寬度
hidden_size=64, # rnn 隱藏層神經元
num_layers=1, # RNN的層數
batch_first=True, # input會以batch_size為第一維度的特徵集(batch_size,time_step,input_size)
)
self.out = nn.Linear(64, 10) #最後的全連線層
def forward(self, x,h):
# 輸入的批資料x shape (batch, time_step, input_size)(64,28,28)
# 輸出的資料,包括每一步的輸出r_out shape (batch, time_step, output_size)
# h_n shape (n_layers, batch, hidden_size)
# h_c shape (n_layers, batch, hidden_size)
# 每次的輸入包括輸入資料和之前的記憶資料
r_out, new_h = self.rnn(x, h)# 每次輸入到RNN裡面的都是當前的輸入和上一次的狀態資訊,初始的狀態資訊為None
# choose r_out at the last time step
out = self.out(r_out[:, -1, :]) ##只要最後一步的output
return out,new_h
rnn = RNN()
print(rnn)
optimizer = torch.optim.Adam(rnn.parameters(), lr=LR) # optimize all cnn parameters
loss_func = nn.CrossEntropyLoss() # the target label is not one-hotted
h=None
# training and testing
for epoch in range(EPOCH):
for step, (x, y) in enumerate(train_loader): # gives batch data
b_x = Variable(x.view(-1, 28, 28)) #x此時的shape(batch_size,1,28,28),x.size(0)=batch_size,x.view將其變為(batch_size*1,28,28)
b_y = Variable(y) # batch y
output,h_state = rnn(b_x,h) # rnn output
loss = loss_func(output, b_y) # cross entropy loss
optimizer.zero_grad() # clear gradients for this training step
loss.backward() # backpropagation, compute gradients
optimizer.step() # apply gradients
if step % 50 == 0:
test_output,new_h = rnn(test_x,h) #注意shape要和上面儲存一致 (samples, time_step, input_size)
pred_y = torch.max(test_output, 1)[1].data.numpy().squeeze()
accuracy = sum(pred_y == test_y) / float(test_y.size)
print('Epoch: ', epoch, '| train loss: %.4f' % loss.data[0], '| test accuracy: %.2f' % accuracy)
# print 10 predictions from test data
test_output,new_h = rnn(test_x[:10].view(-1, 28, 28),h)
pred_y = torch.max(test_output, 1)[1].data.numpy().squeeze()
print(pred_y, 'prediction number')
print(test_y[:10], 'real number')