Classifying the IMDB dataset with Keras RNN and LSTM
By 阿新 · Published 2018-11-30
This article shows how to classify the IMDB movie-review dataset in Keras, first with a SimpleRNN and then with an LSTM.
Example code:
from keras.models import Sequential
from keras.layers import Embedding, SimpleRNN

# A single SimpleRNN layer; by default it returns only the last output.
model = Sequential()
model.add(Embedding(10000, 32))
model.add(SimpleRNN(32))
print(model.summary())

# return_sequences=True makes the layer output the full sequence of hidden states.
model = Sequential()
model.add(Embedding(10000, 32))
model.add(SimpleRNN(32, return_sequences=True))
print(model.summary())

# Stacking recurrent layers: every intermediate layer must return sequences.
model = Sequential()
model.add(Embedding(10000, 32))
model.add(SimpleRNN(32, return_sequences=True))
model.add(SimpleRNN(32, return_sequences=True))
model.add(SimpleRNN(32, return_sequences=True))
model.add(SimpleRNN(32))
print(model.summary())

# Load the IMDB data, keeping only the 10,000 most frequent words,
# and pad/truncate every review to 500 tokens.
from keras.datasets import imdb
from keras.preprocessing import sequence

max_features = 10000
maxlen = 500
batch_size = 32

print('Loading data......')
(input_train, y_train), (input_test, y_test) = imdb.load_data(num_words=max_features)
print(len(input_train), 'train sequences')
print(len(input_test), 'test sequences')

print('Pad sequences (samples x time)')
input_train = sequence.pad_sequences(input_train, maxlen=maxlen)
input_test = sequence.pad_sequences(input_test, maxlen=maxlen)
print('input_train shape:', input_train.shape)
print('input_test shape:', input_test.shape)

# Train a SimpleRNN classifier on the padded reviews.
from keras.layers import Dense

model = Sequential()
model.add(Embedding(max_features, 32))
model.add(SimpleRNN(32))
model.add(Dense(1, activation='sigmoid'))
model.compile(optimizer='rmsprop', loss='binary_crossentropy', metrics=['acc'])
hist = model.fit(input_train, y_train, epochs=10, batch_size=128, validation_split=0.2)

# Plot the training/validation accuracy and loss curves.
import matplotlib.pyplot as plt

acc = hist.history['acc']
val_acc = hist.history['val_acc']
loss = hist.history['loss']
val_loss = hist.history['val_loss']
epochs = range(len(acc))

plt.plot(epochs, acc, 'bo', label='Training acc')
plt.plot(epochs, val_acc, 'b', label='Validation acc')
plt.title('Training and validation accuracy')
plt.legend()
plt.figure()
plt.plot(epochs, loss, 'bo', label='Training loss')
plt.plot(epochs, val_loss, 'b', label='Validation loss')
plt.title('Training and validation loss')
plt.legend()
plt.show()

# Same pipeline with an LSTM layer in place of the SimpleRNN.
from keras.layers import LSTM

model = Sequential()
model.add(Embedding(max_features, 32))
model.add(LSTM(32))
model.add(Dense(1, activation='sigmoid'))
model.compile(optimizer='rmsprop', loss='binary_crossentropy', metrics=['acc'])
hist = model.fit(input_train, y_train, epochs=10, batch_size=128, validation_split=0.2)

acc = hist.history['acc']
val_acc = hist.history['val_acc']
loss = hist.history['loss']
val_loss = hist.history['val_loss']
epochs = range(len(acc))

plt.plot(epochs, acc, 'bo', label='Training acc')
plt.plot(epochs, val_acc, 'b', label='Validation acc')
plt.title('Training and validation accuracy')
plt.legend()
plt.figure()
plt.plot(epochs, loss, 'bo', label='Training loss')
plt.plot(epochs, val_loss, 'b', label='Validation loss')
plt.title('Training and validation loss')
plt.legend()
plt.show()
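Note that the script loads input_test and y_test but never scores the model on them. A minimal sketch of evaluating the most recently trained model on the held-out test set (assuming model, input_test, and y_test from the code above are still in scope):

# Evaluate the trained classifier on the padded test reviews.
test_loss, test_acc = model.evaluate(input_test, y_test, batch_size=128)
print('Test loss:', test_loss)
print('Test accuracy:', test_acc)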
Test results (training log truncated to the final epoch):
...
20000/20000 [==============================] - 29s 1ms/step - loss: 0.0194 - acc: 0.9952 - val_loss: 0.6177 - val_acc: 0.8292
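By the final epoch the model reaches about 99.5% training accuracy but only about 82.9% validation accuracy, a clear sign of overfitting. One common remedy, which is not part of the original script, is dropout on the recurrent layer. A minimal sketch under that assumption, reusing max_features, input_train, and y_train from above:

from keras.models import Sequential
from keras.layers import Embedding, LSTM, Dense

# Sketch only: the same LSTM classifier with dropout applied to the
# layer inputs and to the recurrent connections to curb overfitting.
model = Sequential()
model.add(Embedding(max_features, 32))
model.add(LSTM(32, dropout=0.2, recurrent_dropout=0.2))
model.add(Dense(1, activation='sigmoid'))
model.compile(optimizer='rmsprop', loss='binary_crossentropy', metrics=['acc'])
hist = model.fit(input_train, y_train, epochs=10, batch_size=128, validation_split=0.2)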