
Classifying the IMDB Dataset with Keras RNN and LSTM

This article shows how to classify the IMDB dataset with SimpleRNN and LSTM models in Keras. The IMDB dataset ships with Keras: 50,000 movie reviews labeled positive or negative, split evenly into 25,000 training and 25,000 test samples, with each review already encoded as a sequence of word indices.

Example code:

from keras.models import Sequential
from keras.layers import Embedding, SimpleRNN

# A SimpleRNN that returns only its output at the last timestep.
model = Sequential()
model.add(Embedding(10000, 32))
model.add(SimpleRNN(32))
model.summary()

# With return_sequences=True the layer returns its full sequence of outputs.
model = Sequential()
model.add(Embedding(10000, 32))
model.add(SimpleRNN(32, return_sequences=True))
model.summary()

# Stacking recurrent layers: every intermediate layer must return its full
# output sequence so the next layer receives 3D input.
model = Sequential()
model.add(Embedding(10000, 32))
model.add(SimpleRNN(32, return_sequences=True))
model.add(SimpleRNN(32, return_sequences=True))
model.add(SimpleRNN(32, return_sequences=True))
model.add(SimpleRNN(32))  # the last layer only needs the final output
model.summary()
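
The difference shows up directly in the output shapes: SimpleRNN(32) emits a 2D tensor of shape (batch, 32) holding only the last timestep, while return_sequences=True emits the full 3D tensor (batch, timesteps, 32). A minimal check, reusing the imports above:

model = Sequential()
model.add(Embedding(10000, 32))
model.add(SimpleRNN(32))
print(model.output_shape)  # (None, 32): last output only

model = Sequential()
model.add(Embedding(10000, 32))
model.add(SimpleRNN(32, return_sequences=True))
print(model.output_shape)  # (None, None, 32): one 32-dim output per timestep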

from keras.datasets import imdb
from keras.preprocessing import sequence

max_features = 10000  # vocabulary size: keep only the 10,000 most frequent words
maxlen = 500          # cut each review after 500 words
batch_size = 32       # note: the fit() calls below actually use batch_size=128

print('Loading data......')
(input_train, y_train), (input_test, y_test) = imdb.load_data(num_words=max_features)
print(len(input_train), 'train sequences')
print(len(input_test), 'test sequences')

print('Pad sequences (samples x time)')
input_train = sequence.pad_sequences(input_train, maxlen=maxlen)
input_test = sequence.pad_sequences(input_test, maxlen=maxlen)
print('input_train shape:', input_train.shape)
print('input_test shape:', input_test.shape)
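
pad_sequences turns the ragged lists of word indices into one rectangular 2D tensor: by default it zero-pads short sequences at the front and truncates long ones from the front. A tiny illustration:

print(sequence.pad_sequences([[1, 2, 3], [4, 5]], maxlen=4))
# [[0 1 2 3]
#  [0 0 4 5]]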

from keras.layers import Dense

# Baseline classifier: Embedding -> SimpleRNN -> sigmoid output.
model = Sequential()
model.add(Embedding(max_features, 32))
model.add(SimpleRNN(32))
model.add(Dense(1, activation='sigmoid'))  # single probability for binary sentiment

model.compile(optimizer='rmsprop', loss='binary_crossentropy', metrics=['acc'])
hist = model.fit(input_train, y_train,
                 epochs=10,
                 batch_size=128,
                 validation_split=0.2)
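
validation_split=0.2 holds out the last 20% of the training samples for validation, so the test split loaded earlier is never scored in this walkthrough. A sketch of how you would evaluate it afterwards (evaluate returns the loss followed by each compiled metric):

test_loss, test_acc = model.evaluate(input_test, y_test, batch_size=128)
print('test loss: %.4f, test acc: %.4f' % (test_loss, test_acc))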

import matplotlib.pyplot as plt

acc = hist.history['acc']
val_acc = hist.history['val_acc']
loss = hist.history['loss']
val_loss = hist.history['val_loss']

epochs = range(len(acc))

plt.plot(epochs, acc, 'bo', label='Training acc')
plt.plot(epochs, val_acc, 'b', label='Validation acc')
plt.title('Training and validation accuracy')
plt.legend()
plt.figure()

plt.plot(epochs, loss, 'bo', label='Training loss')
plt.plot(epochs, val_loss, 'b', label='Validation loss')
plt.title('Training and validation loss')
plt.legend()

plt.show()


Now the same architecture with an LSTM in place of the SimpleRNN; its gating lets information survive across many more timesteps, which matters for 500-word reviews.

from keras.layers import LSTM

model = Sequential()
model.add(Embedding(max_features, 32))
model.add(LSTM(32))
model.add(Dense(1, activation='sigmoid'))

model.compile(optimizer='rmsprop',
              loss='binary_crossentropy',
              metrics=['acc'])
hist = model.fit(input_train, y_train,
                 epochs=10,
                 batch_size=128,
                 validation_split=0.2)

acc = hist.history['acc']
val_acc = hist.history['val_acc']
loss = hist.history['loss']
val_loss = hist.history['val_loss']

epochs = range(len(acc))

plt.plot(epochs, acc, 'bo', label='Training acc')
plt.plot(epochs, val_acc, 'b', label='Validation acc')
plt.title('Training and validation accuracy')
plt.legend()

plt.figure()

plt.plot(epochs, loss, 'bo', label='Training loss')
plt.plot(epochs, val_loss, 'b', label='Validation loss')
plt.title('Training and validation loss')
plt.legend()

plt.show()

Results (training log from the final epoch):

16128/20000 [=======================>......] - ETA: 5s - loss: 0.0187 - acc: 0.9954
...
19968/20000 [============================>.] - ETA: 0s - loss: 0.0194 - acc: 0.9952
20000/20000 [==============================] - 29s 1ms/step - loss: 0.0194 - acc: 0.9952 - val_loss: 0.6177 - val_acc: 0.8292
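
By the end of training the model fits the training set almost perfectly (acc ≈ 0.995) while validation accuracy stalls around 0.83 with a far higher validation loss, a clear sign of overfitting. A common next step is recurrent dropout; a sketch of the same LSTM model with it enabled (the 0.2 rates are illustrative, not tuned):

model = Sequential()
model.add(Embedding(max_features, 32))
# dropout regularizes the layer's inputs, recurrent_dropout its recurrent state
model.add(LSTM(32, dropout=0.2, recurrent_dropout=0.2))
model.add(Dense(1, activation='sigmoid'))
model.compile(optimizer='rmsprop', loss='binary_crossentropy', metrics=['acc'])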