
Understanding the difference between stateful and stateless LSTM usage in Keras, with examples

This article trains an LSTM to learn the alphabet and predict the next letter. For more details, see:

https://blog.csdn.net/zwqjoy/article/details/80493341

https://machinelearningmastery.com/understanding-stateful-lstm-recurrent-neural-networks-python-keras/

1. Predicting the next letter in stateful mode

# Stateful LSTM to learn one-char to one-char mapping
import numpy
from keras.models import Sequential
from keras.layers import Dense
from keras.layers import LSTM
from keras.utils import np_utils
# fix random seed for reproducibility
numpy.random.seed(7)
# define the raw dataset
alphabet = "ABCDEFGHIJKLMNOPQRSTUVWXYZ"
# create mapping of characters to integers (0-25) and the reverse
char_to_int = dict((c, i) for i, c in enumerate(alphabet))
int_to_char = dict((i, c) for i, c in enumerate(alphabet))
# prepare the dataset of input to output pairs encoded as integers
seq_length = 1
dataX = []
dataY = []
for i in range(0, len(alphabet) - seq_length, 1):
    seq_in = alphabet[i:i + seq_length]
    seq_out = alphabet[i + seq_length]
    dataX.append([char_to_int[char] for char in seq_in])
    dataY.append(char_to_int[seq_out])
    print(seq_in, '->', seq_out)
# reshape X to be [samples, time steps, features]
X = numpy.reshape(dataX, (len(dataX), seq_length, 1))
# normalize
X = X / float(len(alphabet))
# one hot encode the output variable
y = np_utils.to_categorical(dataY)
# create and fit the model
batch_size = 1
model = Sequential()
model.add(LSTM(16, batch_input_shape=(batch_size, X.shape[1], X.shape[2]), stateful=True))
model.add(Dense(y.shape[1], activation='softmax'))
model.compile(loss='categorical_crossentropy', optimizer='adam', metrics=['accuracy'])
for i in range(300):
    model.fit(X, y, epochs=1, batch_size=batch_size, verbose=2, shuffle=False)
    model.reset_states()
# summarize performance of the model
scores = model.evaluate(X, y, batch_size=batch_size, verbose=0)
model.reset_states()
print("Model Accuracy: %.2f%%" % (scores[1]*100))

OUT:

Model Accuracy: 100.00%
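For contrast, a stateless model (the Keras default) keeps no state between batches and can only rely on the current input window. A minimal sketch of how only the model definition above would change, assuming the same X and y (the rest of this part sticks with the stateful version):

# Stateless variant (sketch): no stateful=True, no batch_input_shape,
# and no manual reset_states() loop; Keras resets state after every batch.
stateless = Sequential()
stateless.add(LSTM(16, input_shape=(X.shape[1], X.shape[2])))
stateless.add(Dense(y.shape[1], activation='softmax'))
stateless.compile(loss='categorical_crossentropy', optimizer='adam', metrics=['accuracy'])
stateless.fit(X, y, epochs=300, batch_size=1, verbose=2)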

After training, let's make some predictions:

model.reset_states()  # reset the state first, so prediction starts from the beginning of the alphabet
# demonstrate some model predictions
seed = [char_to_int[alphabet[0]]]
for i in range(0, len(alphabet)-1):
    x = numpy.reshape(seed, (1, len(seed), 1))
    x = x / float(len(alphabet))
    prediction = model.predict(x, verbose=0)
    index = numpy.argmax(prediction)
    print (int_to_char[seed[0]], "->", int_to_char[index])
    seed = [index]

OUT:

A -> B
B -> C
C -> D
D -> E
E -> F
F -> G
G -> H
H -> I
I -> J
J -> K
K -> L
L -> M
M -> N
N -> O
O -> P
P -> Q
Q -> R
R -> S
S -> T
T -> U
U -> V
V -> W
W -> X
X -> Y
Y -> Z

What if we start predicting from a letter in the middle of the alphabet?
model.reset_states()  # again, reset the state first
# demonstrate a random starting point
letter = "K"
seed = [char_to_int[letter]]
print ("New start: ", letter)
for i in range(0, 5):
    x = numpy.reshape(seed, (1, len(seed), 1))
    x = x / float(len(alphabet))
    prediction = model.predict(x, verbose=0)
    index = numpy.argmax(prediction)
    print (int_to_char[seed[0]], "->", int_to_char[index])
    seed = [index]

OUT:

New start:  K
K -> B
B -> C
C -> D
D -> E
E -> F
We can see that after resetting the state, even though prediction starts from the middle letter K, the model still outputs B, exactly as if we had started from the beginning of the alphabet. This shows that the carried-over state C(t-1) from the previous step has a larger influence on the prediction than the current input x(t).
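As a reminder of where that previous state enters the computation, the standard LSTM cell update (generic formulation, not specific to this example) is:

c_t = f_t \odot c_{t-1} + i_t \odot \tilde{c}_t, \qquad h_t = o_t \odot \tanh(c_t)

In stateful mode, h_{t-1} and c_{t-1} are carried over from the previous fit/predict call rather than being reset to zero, so a single new input x_t may not be enough to override them.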
What if we do not reset the state and instead start directly from a middle letter?
# demonstrate a random starting point
letter = "K"
seed = [char_to_int[letter]]
print ("New start: ", letter)
for i in range(0, 5):
    x = numpy.reshape(seed, (1, len(seed), 1))
    x = x / float(len(alphabet))
    prediction = model.predict(x, verbose=0)
    index = numpy.argmax(prediction)
    print (int_to_char[seed[0]], "->", int_to_char[index])
    seed = [index]

OUT:

New start:  K
K -> Z
Z -> Z
Z -> Z
Z -> Z
Z -> Z
We can see that without resetting the state, prediction picks up the state left over from the final output of the previous run, so every step is predicted as Z. This again shows that the carried-over state outweighs the current input.
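To see this carry-over directly, the hidden and cell states of a stateful layer can be inspected between predict() calls. A minimal sketch, assuming the older Keras API used above (the layer's states attribute and keras.backend access are assumptions here, and x is the last input left over from the loop above):

from keras import backend as K

h_state, c_state = model.layers[0].states  # state variables of the stateful LSTM layer
print(K.get_value(h_state)[0, :4])         # current hidden state (first few units)
model.predict(x, verbose=0)                # one more prediction updates the state in place
print(K.get_value(h_state)[0, :4])         # values have changed
model.reset_states()                       # zero out both h and c again
print(K.get_value(h_state)[0, :4])         # back to all zeros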